mirror of
https://github.com/databasus/databasus.git
synced 2026-04-06 00:32:03 +02:00
Compare commits
141 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1f1d80245f | ||
|
|
16a29cf458 | ||
|
|
43e04500ac | ||
|
|
cee3022f85 | ||
|
|
f46d92c480 | ||
|
|
10677238d7 | ||
|
|
2553203fcf | ||
|
|
7b05bd8000 | ||
|
|
8d45728f73 | ||
|
|
c70ad82c95 | ||
|
|
e4bc34d319 | ||
|
|
257ae85da7 | ||
|
|
b42c820bb2 | ||
|
|
da5c13fb11 | ||
|
|
35180360e5 | ||
|
|
e4f6cd7a5d | ||
|
|
d7b8e6d56a | ||
|
|
6016f23fb2 | ||
|
|
e7c4ee8f6f | ||
|
|
a75702a01b | ||
|
|
81a21eb907 | ||
|
|
33d6bf0147 | ||
|
|
6eb53bb07b | ||
|
|
6ac04270b9 | ||
|
|
b0510d7c21 | ||
|
|
dc5f271882 | ||
|
|
8f718771c9 | ||
|
|
d8eea05dca | ||
|
|
b2a94274d7 | ||
|
|
77c2712ebb | ||
|
|
a9dc29f82c | ||
|
|
c934a45dca | ||
|
|
d4acdf2826 | ||
|
|
49753c4fc0 | ||
|
|
c6aed6b36d | ||
|
|
3060b4266a | ||
|
|
ebeb597f17 | ||
|
|
4783784325 | ||
|
|
bd41433bdb | ||
|
|
a9073787d2 | ||
|
|
0890bf8f09 | ||
|
|
f8c11e8802 | ||
|
|
e798d82fc1 | ||
|
|
81a01585ee | ||
|
|
a8465c1a10 | ||
|
|
a9e5db70f6 | ||
|
|
7a47be6ca6 | ||
|
|
16be3db0c6 | ||
|
|
744e51d1e1 | ||
|
|
b3af75d430 | ||
|
|
6f7320abeb | ||
|
|
a1655d35a6 | ||
|
|
9b6e801184 | ||
|
|
105777ab6f | ||
|
|
3a1a88d5cf | ||
|
|
699ca16814 | ||
|
|
26f3cf233a | ||
|
|
3d8372e9f6 | ||
|
|
b46f11804d | ||
|
|
4676361688 | ||
|
|
de3679cadf | ||
|
|
8f03a30af2 | ||
|
|
356529c58a | ||
|
|
e7eed056f7 | ||
|
|
6084cdc954 | ||
|
|
c50bcc57b1 | ||
|
|
ea76300ed7 | ||
|
|
9b413e4076 | ||
|
|
f91cb260f2 | ||
|
|
8f37a8082f | ||
|
|
5cf7614772 | ||
|
|
ae27f74c2e | ||
|
|
9457516bb9 | ||
|
|
a36fc5bf8c | ||
|
|
03ada5806d | ||
|
|
a6675390e5 | ||
|
|
af2f978876 | ||
|
|
04e7eba5c5 | ||
|
|
520165541d | ||
|
|
5b556bc161 | ||
|
|
0952a15ec5 | ||
|
|
1afb3aa3ff | ||
|
|
19b92e5f74 | ||
|
|
d4763f26b2 | ||
|
|
0e389ba16b | ||
|
|
594a3294c6 | ||
|
|
4e4a323cf1 | ||
|
|
7d9ecf697b | ||
|
|
755c420157 | ||
|
|
ff73627287 | ||
|
|
9c9ab00ace | ||
|
|
7366e21a1a | ||
|
|
a327d1aa57 | ||
|
|
f152b16ea3 | ||
|
|
85dbe80d3d | ||
|
|
edf4028fd1 | ||
|
|
8d85c45a90 | ||
|
|
d9c176d19a | ||
|
|
7a6f72a456 | ||
|
|
9a1471b88b | ||
|
|
386ea1d708 | ||
|
|
a4b23936ee | ||
|
|
b36aa9d48b | ||
|
|
13cb8e5bd2 | ||
|
|
2db4b6e075 | ||
|
|
f2b0b2bf1f | ||
|
|
7142ce295e | ||
|
|
04621b9b2d | ||
|
|
bd329a68cf | ||
|
|
f957abc9db | ||
|
|
c0fd6be1a9 | ||
|
|
c39bd34d5e | ||
|
|
27bec15a29 | ||
|
|
d98baa0656 | ||
|
|
4344f5ea5e | ||
|
|
7c6afa5b88 | ||
|
|
dbac799e1b | ||
|
|
7ee3817089 | ||
|
|
bae6f7f007 | ||
|
|
55dc087ddd | ||
|
|
c94d0db637 | ||
|
|
a1adef2261 | ||
|
|
4602dc3f88 | ||
|
|
cbbfc5ea8f | ||
|
|
dd1072e230 | ||
|
|
a495e5317a | ||
|
|
7eed647038 | ||
|
|
6973241e25 | ||
|
|
ab181f5b81 | ||
|
|
b60a0cc170 | ||
|
|
f319a497b3 | ||
|
|
bc870b3f8e | ||
|
|
15383c59eb | ||
|
|
d14c223a65 | ||
|
|
2c0a294027 | ||
|
|
5d851d73bd | ||
|
|
699913c251 | ||
|
|
a2e3f30a6d | ||
|
|
80f1174ecd | ||
|
|
a47f8d5e2c | ||
|
|
54b9e67656 |
315
.github/workflows/ci-release.yml
vendored
315
.github/workflows/ci-release.yml
vendored
@@ -9,25 +9,26 @@ on:
|
||||
|
||||
jobs:
|
||||
lint-backend:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
container:
|
||||
image: golang:1.24.9
|
||||
volumes:
|
||||
- /runner-cache/go-pkg:/go/pkg/mod
|
||||
- /runner-cache/go-build:/root/.cache/go-build
|
||||
- /runner-cache/golangci-lint:/root/.cache/golangci-lint
|
||||
- /runner-cache/apt-archives:/var/cache/apt/archives
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.24.9"
|
||||
- name: Configure Git for container
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
~/.cache/go-build
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('backend/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-
|
||||
- name: Download Go modules
|
||||
run: |
|
||||
cd backend
|
||||
go mod download
|
||||
|
||||
- name: Install golangci-lint
|
||||
run: |
|
||||
@@ -63,8 +64,6 @@ jobs:
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "20"
|
||||
cache: "npm"
|
||||
cache-dependency-path: frontend/package-lock.json
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
@@ -82,6 +81,11 @@ jobs:
|
||||
cd frontend
|
||||
npm run lint
|
||||
|
||||
- name: Build frontend
|
||||
run: |
|
||||
cd frontend
|
||||
npm run build
|
||||
|
||||
test-frontend:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint-frontend]
|
||||
@@ -93,8 +97,6 @@ jobs:
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "20"
|
||||
cache: "npm"
|
||||
cache-dependency-path: frontend/package-lock.json
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
@@ -107,44 +109,32 @@ jobs:
|
||||
npm run test
|
||||
|
||||
test-backend:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
needs: [lint-backend]
|
||||
container:
|
||||
image: golang:1.24.9
|
||||
options: --privileged -v /var/run/docker.sock:/var/run/docker.sock --add-host=host.docker.internal:host-gateway
|
||||
volumes:
|
||||
- /runner-cache/go-pkg:/go/pkg/mod
|
||||
- /runner-cache/go-build:/root/.cache/go-build
|
||||
- /runner-cache/apt-archives:/var/cache/apt/archives
|
||||
steps:
|
||||
- name: Free up disk space
|
||||
- name: Install Docker CLI
|
||||
run: |
|
||||
echo "Disk space before cleanup:"
|
||||
df -h
|
||||
# Remove unnecessary pre-installed software
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
sudo rm -rf /usr/local/share/boost
|
||||
sudo rm -rf /usr/share/swift
|
||||
# Clean apt cache
|
||||
sudo apt-get clean
|
||||
# Clean docker images (if any pre-installed)
|
||||
docker system prune -af --volumes || true
|
||||
echo "Disk space after cleanup:"
|
||||
df -h
|
||||
apt-get update -qq
|
||||
apt-get install -y -qq docker.io docker-compose netcat-openbsd wget
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.24.9"
|
||||
- name: Configure Git for container
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
~/.cache/go-build
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('backend/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-
|
||||
- name: Download Go modules
|
||||
run: |
|
||||
cd backend
|
||||
go mod download
|
||||
|
||||
- name: Create .env file for testing
|
||||
run: |
|
||||
@@ -156,14 +146,16 @@ jobs:
|
||||
DEV_DB_PASSWORD=Q1234567
|
||||
#app
|
||||
ENV_MODE=development
|
||||
# db
|
||||
DATABASE_DSN=host=localhost user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
|
||||
DATABASE_URL=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disable
|
||||
# db - using 172.17.0.1 to access host from container
|
||||
DATABASE_DSN=host=172.17.0.1 user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
|
||||
DATABASE_URL=postgres://postgres:Q1234567@172.17.0.1:5437/databasus?sslmode=disable
|
||||
# migrations
|
||||
GOOSE_DRIVER=postgres
|
||||
GOOSE_DBSTRING=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disable
|
||||
GOOSE_DBSTRING=postgres://postgres:Q1234567@172.17.0.1:5437/databasus?sslmode=disable
|
||||
GOOSE_MIGRATION_DIR=./migrations
|
||||
# testing
|
||||
# testing
|
||||
TEST_LOCALHOST=172.17.0.1
|
||||
IS_SKIP_EXTERNAL_RESOURCES_TESTS=true
|
||||
# to get Google Drive env variables: add storage in UI and copy data from added storage here
|
||||
TEST_GOOGLE_DRIVE_CLIENT_ID=${{ secrets.TEST_GOOGLE_DRIVE_CLIENT_ID }}
|
||||
TEST_GOOGLE_DRIVE_CLIENT_SECRET=${{ secrets.TEST_GOOGLE_DRIVE_CLIENT_SECRET }}
|
||||
@@ -221,12 +213,14 @@ jobs:
|
||||
TEST_MONGODB_60_PORT=27060
|
||||
TEST_MONGODB_70_PORT=27070
|
||||
TEST_MONGODB_82_PORT=27082
|
||||
# Valkey (cache)
|
||||
VALKEY_HOST=localhost
|
||||
# Valkey (cache) - using 172.17.0.1
|
||||
VALKEY_HOST=172.17.0.1
|
||||
VALKEY_PORT=6379
|
||||
VALKEY_USERNAME=
|
||||
VALKEY_PASSWORD=
|
||||
VALKEY_IS_SSL=false
|
||||
# Host for test databases (container -> host)
|
||||
TEST_DB_HOST=172.17.0.1
|
||||
EOF
|
||||
|
||||
- name: Start test containers
|
||||
@@ -244,25 +238,25 @@ jobs:
|
||||
timeout 60 bash -c 'until docker exec dev-valkey valkey-cli ping 2>/dev/null | grep -q PONG; do sleep 2; done'
|
||||
echo "Valkey is ready!"
|
||||
|
||||
# Wait for test databases
|
||||
timeout 60 bash -c 'until nc -z localhost 5000; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5001; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5002; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5003; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5004; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5005; do sleep 2; done'
|
||||
# Wait for test databases (using 172.17.0.1 from container)
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5000; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5001; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5002; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5003; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5004; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5005; do sleep 2; done'
|
||||
|
||||
# Wait for MinIO
|
||||
timeout 60 bash -c 'until nc -z localhost 9000; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 9000; do sleep 2; done'
|
||||
|
||||
# Wait for Azurite
|
||||
timeout 60 bash -c 'until nc -z localhost 10000; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 10000; do sleep 2; done'
|
||||
|
||||
# Wait for FTP
|
||||
timeout 60 bash -c 'until nc -z localhost 7007; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 7007; do sleep 2; done'
|
||||
|
||||
# Wait for SFTP
|
||||
timeout 60 bash -c 'until nc -z localhost 7008; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 7008; do sleep 2; done'
|
||||
|
||||
# Wait for MySQL containers
|
||||
echo "Waiting for MySQL 5.7..."
|
||||
@@ -321,67 +315,66 @@ jobs:
|
||||
mkdir -p databasus-data/backups
|
||||
mkdir -p databasus-data/temp
|
||||
|
||||
- name: Cache PostgreSQL client tools
|
||||
id: cache-postgres
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: /usr/lib/postgresql
|
||||
key: postgres-clients-12-18-v1
|
||||
|
||||
- name: Cache MySQL client tools
|
||||
id: cache-mysql
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: backend/tools/mysql
|
||||
key: mysql-clients-57-80-84-9-v1
|
||||
|
||||
- name: Cache MariaDB client tools
|
||||
id: cache-mariadb
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: backend/tools/mariadb
|
||||
key: mariadb-clients-106-121-v1
|
||||
|
||||
- name: Cache MongoDB Database Tools
|
||||
id: cache-mongodb
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: backend/tools/mongodb
|
||||
key: mongodb-database-tools-100.10.0-v1
|
||||
|
||||
- name: Install MySQL dependencies
|
||||
- name: Install database client dependencies
|
||||
run: |
|
||||
sudo apt-get update -qq
|
||||
sudo apt-get install -y -qq libncurses6
|
||||
sudo ln -sf /usr/lib/x86_64-linux-gnu/libncurses.so.6 /usr/lib/x86_64-linux-gnu/libncurses.so.5
|
||||
sudo ln -sf /usr/lib/x86_64-linux-gnu/libtinfo.so.6 /usr/lib/x86_64-linux-gnu/libtinfo.so.5
|
||||
apt-get update -qq
|
||||
apt-get install -y -qq libncurses6 libpq5
|
||||
ln -sf /usr/lib/x86_64-linux-gnu/libncurses.so.6 /usr/lib/x86_64-linux-gnu/libncurses.so.5 || true
|
||||
ln -sf /usr/lib/x86_64-linux-gnu/libtinfo.so.6 /usr/lib/x86_64-linux-gnu/libtinfo.so.5 || true
|
||||
|
||||
- name: Install PostgreSQL, MySQL, MariaDB and MongoDB client tools
|
||||
if: steps.cache-postgres.outputs.cache-hit != 'true' || steps.cache-mysql.outputs.cache-hit != 'true' || steps.cache-mariadb.outputs.cache-hit != 'true' || steps.cache-mongodb.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
chmod +x backend/tools/download_linux.sh
|
||||
cd backend/tools
|
||||
./download_linux.sh
|
||||
|
||||
- name: Setup PostgreSQL symlinks (when using cache)
|
||||
if: steps.cache-postgres.outputs.cache-hit == 'true'
|
||||
- name: Setup PostgreSQL, MySQL and MariaDB client tools from pre-built assets
|
||||
run: |
|
||||
cd backend/tools
|
||||
mkdir -p postgresql
|
||||
|
||||
# Create directory structure
|
||||
mkdir -p postgresql mysql mariadb mongodb/bin
|
||||
|
||||
# Copy PostgreSQL client tools (12-18) from pre-built assets
|
||||
for version in 12 13 14 15 16 17 18; do
|
||||
version_dir="postgresql/postgresql-$version"
|
||||
mkdir -p "$version_dir/bin"
|
||||
pg_bin_dir="/usr/lib/postgresql/$version/bin"
|
||||
if [ -d "$pg_bin_dir" ]; then
|
||||
ln -sf "$pg_bin_dir/pg_dump" "$version_dir/bin/pg_dump"
|
||||
ln -sf "$pg_bin_dir/pg_dumpall" "$version_dir/bin/pg_dumpall"
|
||||
ln -sf "$pg_bin_dir/psql" "$version_dir/bin/psql"
|
||||
ln -sf "$pg_bin_dir/pg_restore" "$version_dir/bin/pg_restore"
|
||||
ln -sf "$pg_bin_dir/createdb" "$version_dir/bin/createdb"
|
||||
ln -sf "$pg_bin_dir/dropdb" "$version_dir/bin/dropdb"
|
||||
fi
|
||||
mkdir -p postgresql/postgresql-$version
|
||||
cp -r ../../assets/tools/x64/postgresql/postgresql-$version/bin postgresql/postgresql-$version/
|
||||
done
|
||||
|
||||
# Copy MySQL client tools (5.7, 8.0, 8.4, 9) from pre-built assets
|
||||
for version in 5.7 8.0 8.4 9; do
|
||||
mkdir -p mysql/mysql-$version
|
||||
cp -r ../../assets/tools/x64/mysql/mysql-$version/bin mysql/mysql-$version/
|
||||
done
|
||||
|
||||
# Copy MariaDB client tools (10.6, 12.1) from pre-built assets
|
||||
for version in 10.6 12.1; do
|
||||
mkdir -p mariadb/mariadb-$version
|
||||
cp -r ../../assets/tools/x64/mariadb/mariadb-$version/bin mariadb/mariadb-$version/
|
||||
done
|
||||
|
||||
# Make all binaries executable
|
||||
chmod +x postgresql/*/bin/*
|
||||
chmod +x mysql/*/bin/*
|
||||
chmod +x mariadb/*/bin/*
|
||||
|
||||
echo "Pre-built client tools setup complete"
|
||||
|
||||
- name: Install MongoDB Database Tools
|
||||
run: |
|
||||
cd backend/tools
|
||||
|
||||
# MongoDB Database Tools must be downloaded (not in pre-built assets)
|
||||
# They are backward compatible - single version supports all servers (4.0-8.0)
|
||||
MONGODB_TOOLS_URL="https://fastdl.mongodb.org/tools/db/mongodb-database-tools-debian12-x86_64-100.10.0.deb"
|
||||
|
||||
echo "Downloading MongoDB Database Tools..."
|
||||
wget -q "$MONGODB_TOOLS_URL" -O /tmp/mongodb-database-tools.deb
|
||||
|
||||
echo "Installing MongoDB Database Tools..."
|
||||
dpkg -i /tmp/mongodb-database-tools.deb || apt-get install -f -y --no-install-recommends
|
||||
|
||||
# Create symlinks to tools directory
|
||||
ln -sf /usr/bin/mongodump mongodb/bin/mongodump
|
||||
ln -sf /usr/bin/mongorestore mongodb/bin/mongorestore
|
||||
|
||||
rm -f /tmp/mongodb-database-tools.deb
|
||||
echo "MongoDB Database Tools installed successfully"
|
||||
|
||||
- name: Verify MariaDB client tools exist
|
||||
run: |
|
||||
cd backend/tools
|
||||
@@ -426,10 +419,28 @@ jobs:
|
||||
if: always()
|
||||
run: |
|
||||
cd backend
|
||||
# Stop and remove containers (keeping images for next run)
|
||||
docker compose -f docker-compose.yml.example down -v
|
||||
|
||||
# Clean up all data directories created by docker-compose
|
||||
echo "Cleaning up data directories..."
|
||||
rm -rf pgdata || true
|
||||
rm -rf valkey-data || true
|
||||
rm -rf mysqldata || true
|
||||
rm -rf mariadbdata || true
|
||||
rm -rf temp/nas || true
|
||||
rm -rf databasus-data || true
|
||||
|
||||
# Also clean root-level databasus-data if exists
|
||||
cd ..
|
||||
rm -rf databasus-data || true
|
||||
|
||||
echo "Cleanup complete"
|
||||
|
||||
determine-version:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
container:
|
||||
image: node:20
|
||||
needs: [test-backend, test-frontend]
|
||||
if: ${{ github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, '[skip-release]') }}
|
||||
outputs:
|
||||
@@ -442,10 +453,9 @@ jobs:
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "20"
|
||||
- name: Configure Git for container
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Install semver
|
||||
run: npm install -g semver
|
||||
@@ -459,6 +469,7 @@ jobs:
|
||||
|
||||
- name: Analyze commits and determine version bump
|
||||
id: version_bump
|
||||
shell: bash
|
||||
run: |
|
||||
CURRENT_VERSION="${{ steps.current_version.outputs.current_version }}"
|
||||
LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
|
||||
@@ -478,7 +489,7 @@ jobs:
|
||||
HAS_FIX=false
|
||||
HAS_BREAKING=false
|
||||
|
||||
# Analyze each commit
|
||||
# Analyze each commit - USE PROCESS SUBSTITUTION to avoid subshell variable scope issues
|
||||
while IFS= read -r commit; do
|
||||
if [[ "$commit" =~ ^FEATURE ]]; then
|
||||
HAS_FEATURE=true
|
||||
@@ -496,7 +507,7 @@ jobs:
|
||||
HAS_BREAKING=true
|
||||
echo "Found BREAKING CHANGE: $commit"
|
||||
fi
|
||||
done <<< "$COMMITS"
|
||||
done < <(printf '%s\n' "$COMMITS")
|
||||
|
||||
# Determine version bump
|
||||
if [ "$HAS_BREAKING" = true ]; then
|
||||
@@ -522,10 +533,15 @@ jobs:
|
||||
fi
|
||||
|
||||
build-only:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
needs: [test-backend, test-frontend]
|
||||
if: ${{ github.ref == 'refs/heads/main' && contains(github.event.head_commit.message, '[skip-release]') }}
|
||||
steps:
|
||||
- name: Clean workspace
|
||||
run: |
|
||||
sudo rm -rf "$GITHUB_WORKSPACE"/* || true
|
||||
sudo rm -rf "$GITHUB_WORKSPACE"/.* || true
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -554,12 +570,17 @@ jobs:
|
||||
databasus/databasus:${{ github.sha }}
|
||||
|
||||
build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
needs: [determine-version]
|
||||
if: ${{ needs.determine-version.outputs.should_release == 'true' }}
|
||||
permissions:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Clean workspace
|
||||
run: |
|
||||
sudo rm -rf "$GITHUB_WORKSPACE"/* || true
|
||||
sudo rm -rf "$GITHUB_WORKSPACE"/.* || true
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -589,21 +610,33 @@ jobs:
|
||||
databasus/databasus:${{ github.sha }}
|
||||
|
||||
release:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
container:
|
||||
image: node:20
|
||||
needs: [determine-version, build-and-push]
|
||||
if: ${{ needs.determine-version.outputs.should_release == 'true' }}
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Clean workspace
|
||||
run: |
|
||||
rm -rf "$GITHUB_WORKSPACE"/* || true
|
||||
rm -rf "$GITHUB_WORKSPACE"/.* || true
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Configure Git for container
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Generate changelog
|
||||
id: changelog
|
||||
shell: bash
|
||||
run: |
|
||||
NEW_VERSION="${{ needs.determine-version.outputs.new_version }}"
|
||||
LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
|
||||
@@ -623,6 +656,7 @@ jobs:
|
||||
FIXES=""
|
||||
REFACTORS=""
|
||||
|
||||
# USE PROCESS SUBSTITUTION to avoid subshell variable scope issues
|
||||
while IFS= read -r line; do
|
||||
if [ -n "$line" ]; then
|
||||
COMMIT_MSG=$(echo "$line" | cut -d'|' -f1)
|
||||
@@ -656,7 +690,7 @@ jobs:
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
done <<< "$COMMITS"
|
||||
done < <(printf '%s\n' "$COMMITS")
|
||||
|
||||
# Build changelog sections
|
||||
if [ -n "$FEATURES" ]; then
|
||||
@@ -695,16 +729,33 @@ jobs:
|
||||
prerelease: false
|
||||
|
||||
publish-helm-chart:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
container:
|
||||
image: alpine:3.19
|
||||
volumes:
|
||||
- /runner-cache/apk-cache:/etc/apk/cache
|
||||
needs: [determine-version, build-and-push]
|
||||
if: ${{ needs.determine-version.outputs.should_release == 'true' }}
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
steps:
|
||||
- name: Clean workspace
|
||||
run: |
|
||||
rm -rf "$GITHUB_WORKSPACE"/* || true
|
||||
rm -rf "$GITHUB_WORKSPACE"/.* || true
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apk add --no-cache git bash curl
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Configure Git for container
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@v4
|
||||
with:
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -1,3 +1,4 @@
|
||||
ansible/
|
||||
postgresus_data/
|
||||
postgresus-data/
|
||||
databasus-data/
|
||||
@@ -9,4 +10,5 @@ node_modules/
|
||||
/articles
|
||||
|
||||
.DS_Store
|
||||
/scripts
|
||||
/scripts
|
||||
.vscode/settings.json
|
||||
|
||||
@@ -18,6 +18,13 @@ repos:
|
||||
files: ^frontend/.*\.(ts|tsx|js|jsx)$
|
||||
pass_filenames: false
|
||||
|
||||
- id: frontend-build
|
||||
name: Frontend Build
|
||||
entry: bash -c "cd frontend && npm run build"
|
||||
language: system
|
||||
files: ^frontend/.*\.(ts|tsx|js|jsx|json|css)$
|
||||
pass_filenames: false
|
||||
|
||||
# Backend checks
|
||||
- repo: local
|
||||
hooks:
|
||||
@@ -27,3 +34,10 @@ repos:
|
||||
language: system
|
||||
files: ^backend/.*\.go$
|
||||
pass_filenames: false
|
||||
|
||||
- id: backend-go-mod-tidy
|
||||
name: Backend Go Mod Tidy
|
||||
entry: bash -c "cd backend && go mod tidy"
|
||||
language: system
|
||||
files: ^backend/.*\.go$
|
||||
pass_filenames: false
|
||||
|
||||
61
Dockerfile
61
Dockerfile
@@ -251,6 +251,37 @@ fi
|
||||
# PostgreSQL 17 binary paths
|
||||
PG_BIN="/usr/lib/postgresql/17/bin"
|
||||
|
||||
# Generate runtime configuration for frontend
|
||||
echo "Generating runtime configuration..."
|
||||
|
||||
# Detect if email is configured (both SMTP_HOST and DATABASUS_URL must be set)
|
||||
if [ -n "\${SMTP_HOST:-}" ] && [ -n "\${DATABASUS_URL:-}" ]; then
|
||||
IS_EMAIL_CONFIGURED="true"
|
||||
else
|
||||
IS_EMAIL_CONFIGURED="false"
|
||||
fi
|
||||
|
||||
cat > /app/ui/build/runtime-config.js <<JSEOF
|
||||
// Runtime configuration injected at container startup
|
||||
// This file is generated dynamically and should not be edited manually
|
||||
window.__RUNTIME_CONFIG__ = {
|
||||
IS_CLOUD: '\${IS_CLOUD:-false}',
|
||||
GITHUB_CLIENT_ID: '\${GITHUB_CLIENT_ID:-}',
|
||||
GOOGLE_CLIENT_ID: '\${GOOGLE_CLIENT_ID:-}',
|
||||
IS_EMAIL_CONFIGURED: '\$IS_EMAIL_CONFIGURED',
|
||||
CLOUDFLARE_TURNSTILE_SITE_KEY: '\${CLOUDFLARE_TURNSTILE_SITE_KEY:-}'
|
||||
};
|
||||
JSEOF
|
||||
|
||||
# Inject analytics script if provided (only if not already injected)
|
||||
if [ -n "\${ANALYTICS_SCRIPT:-}" ]; then
|
||||
if ! grep -q "rybbit.databasus.com" /app/ui/build/index.html 2>/dev/null; then
|
||||
echo "Injecting analytics script..."
|
||||
sed -i "s#</head># \${ANALYTICS_SCRIPT}\\
|
||||
</head>#" /app/ui/build/index.html
|
||||
fi
|
||||
fi
|
||||
|
||||
# Ensure proper ownership of data directory
|
||||
echo "Setting up data directory permissions..."
|
||||
mkdir -p /databasus-data/pgdata
|
||||
@@ -372,9 +403,37 @@ SQL
|
||||
|
||||
# Start the main application
|
||||
echo "Starting Databasus application..."
|
||||
|
||||
# Check and warn about external database/Valkey usage
|
||||
if [ -n "\${DANGEROUS_EXTERNAL_DATABASE_DSN:-}" ]; then
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "WARNING: Using external database"
|
||||
echo "=========================================="
|
||||
echo "DANGEROUS_EXTERNAL_DATABASE_DSN is set."
|
||||
echo "Application will connect to external PostgreSQL instead of internal instance."
|
||||
echo "Internal PostgreSQL is still running in the background."
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
fi
|
||||
|
||||
if [ -n "\${DANGEROUS_VALKEY_HOST:-}" ]; then
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "WARNING: Using external Valkey"
|
||||
echo "=========================================="
|
||||
echo "DANGEROUS_VALKEY_HOST is set."
|
||||
echo "Application will connect to external Valkey instead of internal instance."
|
||||
echo "Internal Valkey is still running in the background."
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
fi
|
||||
|
||||
exec ./main
|
||||
EOF
|
||||
|
||||
LABEL org.opencontainers.image.source="https://github.com/databasus/databasus"
|
||||
|
||||
RUN chmod +x /app/start.sh
|
||||
|
||||
EXPOSE 4005
|
||||
@@ -383,4 +442,4 @@ EXPOSE 4005
|
||||
VOLUME ["/databasus-data"]
|
||||
|
||||
ENTRYPOINT ["/app/start.sh"]
|
||||
CMD []
|
||||
CMD []
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
[](https://www.mongodb.com/)
|
||||
<br />
|
||||
[](LICENSE)
|
||||
[](https://hub.docker.com/r/rostislavdugin/postgresus)
|
||||
[](https://hub.docker.com/r/databasus/databasus)
|
||||
[](https://github.com/databasus/databasus)
|
||||
[](https://github.com/databasus/databasus)
|
||||
[](https://github.com/databasus/databasus)
|
||||
@@ -31,8 +31,6 @@
|
||||
<img src="assets/dashboard-dark.svg" alt="Databasus Dark Dashboard" width="800" style="margin-bottom: 10px;"/>
|
||||
|
||||
<img src="assets/dashboard.svg" alt="Databasus Dashboard" width="800"/>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
---
|
||||
|
||||
@@ -1,152 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
Always place private methods to the bottom of file
|
||||
|
||||
**This rule applies to ALL Go files including tests, services, controllers, repositories, etc.**
|
||||
|
||||
In Go, exported (public) functions/methods start with uppercase letters, while unexported (private) ones start with lowercase letters.
|
||||
|
||||
## Structure Order:
|
||||
|
||||
1. Type definitions and constants
|
||||
2. Public methods/functions (uppercase)
|
||||
3. Private methods/functions (lowercase)
|
||||
|
||||
## Examples:
|
||||
|
||||
### Service with methods:
|
||||
|
||||
```go
|
||||
type UserService struct {
|
||||
repository *UserRepository
|
||||
}
|
||||
|
||||
// Public methods first
|
||||
func (s *UserService) CreateUser(user *User) error {
|
||||
if err := s.validateUser(user); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.repository.Save(user)
|
||||
}
|
||||
|
||||
func (s *UserService) GetUser(id uuid.UUID) (*User, error) {
|
||||
return s.repository.FindByID(id)
|
||||
}
|
||||
|
||||
// Private methods at the bottom
|
||||
func (s *UserService) validateUser(user *User) error {
|
||||
if user.Name == "" {
|
||||
return errors.New("name is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
### Package-level functions:
|
||||
|
||||
```go
|
||||
package utils
|
||||
|
||||
// Public functions first
|
||||
func ProcessData(data []byte) (Result, error) {
|
||||
cleaned := sanitizeInput(data)
|
||||
return parseData(cleaned)
|
||||
}
|
||||
|
||||
func ValidateInput(input string) bool {
|
||||
return isValidFormat(input) && checkLength(input)
|
||||
}
|
||||
|
||||
// Private functions at the bottom
|
||||
func sanitizeInput(data []byte) []byte {
|
||||
// implementation
|
||||
}
|
||||
|
||||
func parseData(data []byte) (Result, error) {
|
||||
// implementation
|
||||
}
|
||||
|
||||
func isValidFormat(input string) bool {
|
||||
// implementation
|
||||
}
|
||||
|
||||
func checkLength(input string) bool {
|
||||
// implementation
|
||||
}
|
||||
```
|
||||
|
||||
### Test files:
|
||||
|
||||
```go
|
||||
package user_test
|
||||
|
||||
// Public test functions first
|
||||
func Test_CreateUser_ValidInput_UserCreated(t *testing.T) {
|
||||
user := createTestUser()
|
||||
result, err := service.CreateUser(user)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
|
||||
func Test_GetUser_ExistingUser_ReturnsUser(t *testing.T) {
|
||||
user := createTestUser()
|
||||
// test implementation
|
||||
}
|
||||
|
||||
// Private helper functions at the bottom
|
||||
func createTestUser() *User {
|
||||
return &User{
|
||||
Name: "Test User",
|
||||
Email: "test@example.com",
|
||||
}
|
||||
}
|
||||
|
||||
func setupTestDatabase() *Database {
|
||||
// setup implementation
|
||||
}
|
||||
```
|
||||
|
||||
### Controller example:
|
||||
|
||||
```go
|
||||
type ProjectController struct {
|
||||
service *ProjectService
|
||||
}
|
||||
|
||||
// Public HTTP handlers first
|
||||
func (c *ProjectController) CreateProject(ctx *gin.Context) {
|
||||
var request CreateProjectRequest
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
c.handleError(ctx, err)
|
||||
return
|
||||
}
|
||||
// handler logic
|
||||
}
|
||||
|
||||
func (c *ProjectController) GetProject(ctx *gin.Context) {
|
||||
projectID := c.extractProjectID(ctx)
|
||||
// handler logic
|
||||
}
|
||||
|
||||
// Private helper methods at the bottom
|
||||
func (c *ProjectController) handleError(ctx *gin.Context, err error) {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
}
|
||||
|
||||
func (c *ProjectController) extractProjectID(ctx *gin.Context) uuid.UUID {
|
||||
return uuid.MustParse(ctx.Param("projectId"))
|
||||
}
|
||||
```
|
||||
|
||||
## Key Points:
|
||||
|
||||
- **Exported/Public** = starts with uppercase letter (CreateUser, GetProject)
|
||||
- **Unexported/Private** = starts with lowercase letter (validateUser, handleError)
|
||||
- This improves code readability by showing the public API first
|
||||
- Private helpers are implementation details, so they go at the bottom
|
||||
- Apply this rule consistently across ALL Go files in the project
|
||||
@@ -1,45 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
## Comment Guidelines
|
||||
|
||||
1. **No obvious comments** - Don't state what the code already clearly shows
|
||||
2. **Functions and variables should have meaningful names** - Code should be self-documenting
|
||||
3. **Comments for unclear code only** - Only add comments when code logic isn't immediately clear
|
||||
|
||||
## Key Principles:
|
||||
|
||||
- **Code should tell a story** - Use descriptive variable and function names
|
||||
- **Comments explain WHY, not WHAT** - The code shows what happens, comments explain business logic or complex decisions
|
||||
- **Prefer refactoring over commenting** - If code needs explaining, consider making it clearer instead
|
||||
- **API documentation is required** - Swagger comments for all HTTP endpoints are mandatory
|
||||
- **Complex algorithms deserve comments** - Mathematical formulas, business rules, or non-obvious optimizations
|
||||
|
||||
Example of useless comment:
|
||||
|
||||
1.
|
||||
|
||||
```sql
|
||||
// Create projects table
|
||||
CREATE TABLE projects (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
name TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
```
|
||||
|
||||
2.
|
||||
|
||||
```go
|
||||
// Create test project
|
||||
project := CreateTestProject(projectName, user, router)
|
||||
```
|
||||
|
||||
3.
|
||||
|
||||
```go
|
||||
// CreateValidLogItems creates valid log items for testing
|
||||
func CreateValidLogItems(count int, uniqueID string) []logs_receiving.LogItemRequestDTO {
|
||||
```
|
||||
@@ -1,133 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
1. When we write controller:
|
||||
|
||||
- we combine all routes to single controller
|
||||
- names them as .WhatWeDo (not "handlers") concept
|
||||
|
||||
2. We use gin and \*gin.Context for all routes.
|
||||
Example:
|
||||
|
||||
func (c *TasksController) GetAvailableTasks(ctx *gin.Context) ...
|
||||
|
||||
3. We document all routes with Swagger in the following format:
|
||||
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
user_models "databasus-backend/internal/features/users/models"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
)
|
||||
|
||||
type AuditLogController struct {
|
||||
auditLogService \*AuditLogService
|
||||
}
|
||||
|
||||
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// All audit log endpoints require authentication (handled in main.go)
|
||||
auditRoutes := router.Group("/audit-logs")
|
||||
|
||||
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
|
||||
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
|
||||
|
||||
}
|
||||
|
||||
// GetGlobalAuditLogs
|
||||
// @Summary Get global audit logs (ADMIN only)
|
||||
// @Description Retrieve all audit logs across the system
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/global [get]
|
||||
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(\*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "only administrators can view global audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
|
||||
}
|
||||
|
||||
// GetUserAuditLogs
|
||||
// @Summary Get user audit logs
|
||||
// @Description Retrieve audit logs for a specific user
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param userId path string true "User ID"
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/users/{userId} [get]
|
||||
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(\*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
userIDStr := ctx.Param("userId")
|
||||
targetUserID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "insufficient permissions to view user audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
|
||||
}
|
||||
@@ -1,671 +0,0 @@
|
||||
---
|
||||
alwaysApply: false
|
||||
---
|
||||
|
||||
This is example of CRUD:
|
||||
|
||||
------ backend/internal/features/audit_logs/controller.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
user_models "databasus-backend/internal/features/users/models"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogController struct {
|
||||
auditLogService *AuditLogService
|
||||
}
|
||||
|
||||
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// All audit log endpoints require authentication (handled in main.go)
|
||||
auditRoutes := router.Group("/audit-logs")
|
||||
|
||||
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
|
||||
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
|
||||
}
|
||||
|
||||
// GetGlobalAuditLogs
|
||||
// @Summary Get global audit logs (ADMIN only)
|
||||
// @Description Retrieve all audit logs across the system
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/global [get]
|
||||
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "only administrators can view global audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// GetUserAuditLogs
|
||||
// @Summary Get user audit logs
|
||||
// @Description Retrieve audit logs for a specific user
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param userId path string true "User ID"
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/users/{userId} [get]
|
||||
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
userIDStr := ctx.Param("userId")
|
||||
targetUserID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "insufficient permissions to view user audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/controller_test.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
"databasus-backend/internal/storage"
|
||||
test_utils "databasus-backend/internal/util/testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_GetGlobalAuditLogs_AdminSucceedsAndMemberGetsForbidden(t *testing.T) {
|
||||
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
|
||||
memberUser := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
service := GetAuditLogService()
|
||||
projectID := uuid.New()
|
||||
|
||||
// Create test logs
|
||||
createAuditLog(service, "Test log with user", &adminUser.UserID, nil)
|
||||
createAuditLog(service, "Test log with project", nil, &projectID)
|
||||
createAuditLog(service, "Test log standalone", nil, nil)
|
||||
|
||||
// Test ADMIN can access global logs
|
||||
var response GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
"/api/v1/audit-logs/global?limit=10", "Bearer "+adminUser.Token, http.StatusOK, &response)
|
||||
|
||||
assert.GreaterOrEqual(t, len(response.AuditLogs), 3)
|
||||
assert.GreaterOrEqual(t, response.Total, int64(3))
|
||||
|
||||
messages := extractMessages(response.AuditLogs)
|
||||
assert.Contains(t, messages, "Test log with user")
|
||||
assert.Contains(t, messages, "Test log with project")
|
||||
assert.Contains(t, messages, "Test log standalone")
|
||||
|
||||
// Test MEMBER cannot access global logs
|
||||
resp := test_utils.MakeGetRequest(t, router, "/api/v1/audit-logs/global",
|
||||
"Bearer "+memberUser.Token, http.StatusForbidden)
|
||||
assert.Contains(t, string(resp.Body), "only administrators can view global audit logs")
|
||||
}
|
||||
|
||||
func Test_GetUserAuditLogs_PermissionsEnforcedCorrectly(t *testing.T) {
|
||||
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
|
||||
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
service := GetAuditLogService()
|
||||
projectID := uuid.New()
|
||||
|
||||
// Create test logs for different users
|
||||
createAuditLog(service, "Test log user1 first", &user1.UserID, nil)
|
||||
createAuditLog(service, "Test log user1 second", &user1.UserID, &projectID)
|
||||
createAuditLog(service, "Test log user2 first", &user2.UserID, nil)
|
||||
createAuditLog(service, "Test log user2 second", &user2.UserID, &projectID)
|
||||
createAuditLog(service, "Test project log", nil, &projectID)
|
||||
|
||||
// Test ADMIN can view any user's logs
|
||||
var user1Response GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/users/%s?limit=10", user1.UserID.String()),
|
||||
"Bearer "+adminUser.Token, http.StatusOK, &user1Response)
|
||||
|
||||
assert.Equal(t, 2, len(user1Response.AuditLogs))
|
||||
messages := extractMessages(user1Response.AuditLogs)
|
||||
assert.Contains(t, messages, "Test log user1 first")
|
||||
assert.Contains(t, messages, "Test log user1 second")
|
||||
|
||||
// Test user can view own logs
|
||||
var ownLogsResponse GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/users/%s", user2.UserID.String()),
|
||||
"Bearer "+user2.Token, http.StatusOK, &ownLogsResponse)
|
||||
assert.Equal(t, 2, len(ownLogsResponse.AuditLogs))
|
||||
|
||||
// Test user cannot view other user's logs
|
||||
resp := test_utils.MakeGetRequest(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/users/%s", user1.UserID.String()),
|
||||
"Bearer "+user2.Token, http.StatusForbidden)
|
||||
|
||||
assert.Contains(t, string(resp.Body), "insufficient permissions")
|
||||
}
|
||||
|
||||
func Test_FilterAuditLogsByTime_ReturnsOnlyLogsBeforeDate(t *testing.T) {
|
||||
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
|
||||
router := createRouter()
|
||||
service := GetAuditLogService()
|
||||
db := storage.GetDb()
|
||||
baseTime := time.Now().UTC()
|
||||
|
||||
// Create logs with different timestamps
|
||||
createTimedLog(db, &adminUser.UserID, "Test old log", baseTime.Add(-2*time.Hour))
|
||||
createTimedLog(db, &adminUser.UserID, "Test recent log", baseTime.Add(-30*time.Minute))
|
||||
createAuditLog(service, "Test current log", &adminUser.UserID, nil)
|
||||
|
||||
// Test filtering - get logs before 1 hour ago
|
||||
beforeTime := baseTime.Add(-1 * time.Hour)
|
||||
var filteredResponse GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/global?beforeDate=%s", beforeTime.Format(time.RFC3339)),
|
||||
"Bearer "+adminUser.Token, http.StatusOK, &filteredResponse)
|
||||
|
||||
// Verify only old log is returned
|
||||
messages := extractMessages(filteredResponse.AuditLogs)
|
||||
assert.Contains(t, messages, "Test old log")
|
||||
assert.NotContains(t, messages, "Test recent log")
|
||||
assert.NotContains(t, messages, "Test current log")
|
||||
|
||||
// Test without filter - should get all logs
|
||||
var allResponse GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router, "/api/v1/audit-logs/global",
|
||||
"Bearer "+adminUser.Token, http.StatusOK, &allResponse)
|
||||
assert.GreaterOrEqual(t, len(allResponse.AuditLogs), 3)
|
||||
}
|
||||
|
||||
func createRouter() *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
SetupDependencies()
|
||||
|
||||
v1 := router.Group("/api/v1")
|
||||
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
|
||||
GetAuditLogController().RegisterRoutes(protected.(*gin.RouterGroup))
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/di.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var auditLogRepository = &AuditLogRepository{}
|
||||
var auditLogService = &AuditLogService{
|
||||
auditLogRepository: auditLogRepository,
|
||||
logger: logger.GetLogger(),
|
||||
}
|
||||
var auditLogController = &AuditLogController{
|
||||
auditLogService: auditLogService,
|
||||
}
|
||||
|
||||
func GetAuditLogService() *AuditLogService {
|
||||
return auditLogService
|
||||
}
|
||||
|
||||
func GetAuditLogController() *AuditLogController {
|
||||
return auditLogController
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
users_services.GetUserService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/dto.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import "time"
|
||||
|
||||
type GetAuditLogsRequest struct {
|
||||
Limit int `form:"limit" json:"limit"`
|
||||
Offset int `form:"offset" json:"offset"`
|
||||
BeforeDate *time.Time `form:"beforeDate" json:"beforeDate"`
|
||||
}
|
||||
|
||||
type GetAuditLogsResponse struct {
|
||||
AuditLogs []*AuditLog `json:"auditLogs"`
|
||||
Total int64 `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/models.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLog struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id"`
|
||||
UserID *uuid.UUID `json:"userId" gorm:"column:user_id"`
|
||||
ProjectID *uuid.UUID `json:"projectId" gorm:"column:project_id"`
|
||||
Message string `json:"message" gorm:"column:message"`
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
|
||||
}
|
||||
|
||||
func (AuditLog) TableName() string {
|
||||
return "audit_logs"
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/repository.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/storage"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogRepository struct{}
|
||||
|
||||
func (r *AuditLogRepository) Create(auditLog *AuditLog) error {
|
||||
if auditLog.ID == uuid.Nil {
|
||||
auditLog.ID = uuid.New()
|
||||
}
|
||||
|
||||
return storage.GetDb().Create(auditLog).Error
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) GetGlobal(limit, offset int, beforeDate *time.Time) ([]*AuditLog, error) {
|
||||
var auditLogs []*AuditLog
|
||||
|
||||
query := storage.GetDb().Order("created_at DESC")
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&auditLogs).Error
|
||||
|
||||
return auditLogs, err
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) GetByUser(
|
||||
userID uuid.UUID,
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLog, error) {
|
||||
var auditLogs []*AuditLog
|
||||
|
||||
query := storage.GetDb().
|
||||
Where("user_id = ?", userID).
|
||||
Order("created_at DESC")
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&auditLogs).Error
|
||||
|
||||
return auditLogs, err
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) GetByProject(
|
||||
projectID uuid.UUID,
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLog, error) {
|
||||
var auditLogs []*AuditLog
|
||||
|
||||
query := storage.GetDb().
|
||||
Where("project_id = ?", projectID).
|
||||
Order("created_at DESC")
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&auditLogs).Error
|
||||
|
||||
return auditLogs, err
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) CountGlobal(beforeDate *time.Time) (int64, error) {
|
||||
var count int64
|
||||
query := storage.GetDb().Model(&AuditLog{})
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/service.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
user_models "databasus-backend/internal/features/users/models"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogService struct {
|
||||
auditLogRepository *AuditLogRepository
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *AuditLogService) WriteAuditLog(
|
||||
message string,
|
||||
userID *uuid.UUID,
|
||||
projectID *uuid.UUID,
|
||||
) {
|
||||
auditLog := &AuditLog{
|
||||
UserID: userID,
|
||||
ProjectID: projectID,
|
||||
Message: message,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
err := s.auditLogRepository.Create(auditLog)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to create audit log", "error", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuditLogService) CreateAuditLog(auditLog *AuditLog) error {
|
||||
return s.auditLogRepository.Create(auditLog)
|
||||
}
|
||||
|
||||
func (s *AuditLogService) GetGlobalAuditLogs(
|
||||
user *user_models.User,
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
if user.Role != user_enums.UserRoleAdmin {
|
||||
return nil, errors.New("only administrators can view global audit logs")
|
||||
}
|
||||
|
||||
limit := request.Limit
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
offset := max(request.Offset, 0)
|
||||
|
||||
auditLogs, err := s.auditLogRepository.GetGlobal(limit, offset, request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
total, err := s.auditLogRepository.CountGlobal(request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetAuditLogsResponse{
|
||||
AuditLogs: auditLogs,
|
||||
Total: total,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AuditLogService) GetUserAuditLogs(
|
||||
targetUserID uuid.UUID,
|
||||
user *user_models.User,
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
// Users can view their own logs, ADMIN can view any user's logs
|
||||
if user.Role != user_enums.UserRoleAdmin && user.ID != targetUserID {
|
||||
return nil, errors.New("insufficient permissions to view user audit logs")
|
||||
}
|
||||
|
||||
limit := request.Limit
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
offset := max(request.Offset, 0)
|
||||
|
||||
auditLogs, err := s.auditLogRepository.GetByUser(targetUserID, limit, offset, request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetAuditLogsResponse{
|
||||
AuditLogs: auditLogs,
|
||||
Total: int64(len(auditLogs)),
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AuditLogService) GetProjectAuditLogs(
|
||||
projectID uuid.UUID,
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
limit := request.Limit
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
offset := max(request.Offset, 0)
|
||||
|
||||
auditLogs, err := s.auditLogRepository.GetByProject(projectID, limit, offset, request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetAuditLogsResponse{
|
||||
AuditLogs: auditLogs,
|
||||
Total: int64(len(auditLogs)),
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/service_test.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func Test_AuditLogs_ProjectSpecificLogs(t *testing.T) {
|
||||
service := GetAuditLogService()
|
||||
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
project1ID, project2ID := uuid.New(), uuid.New()
|
||||
|
||||
// Create test logs for projects
|
||||
createAuditLog(service, "Test project1 log first", &user1.UserID, &project1ID)
|
||||
createAuditLog(service, "Test project1 log second", &user2.UserID, &project1ID)
|
||||
createAuditLog(service, "Test project2 log first", &user1.UserID, &project2ID)
|
||||
createAuditLog(service, "Test project2 log second", &user2.UserID, &project2ID)
|
||||
createAuditLog(service, "Test no project log", &user1.UserID, nil)
|
||||
|
||||
request := &GetAuditLogsRequest{Limit: 10, Offset: 0}
|
||||
|
||||
// Test project 1 logs
|
||||
project1Response, err := service.GetProjectAuditLogs(project1ID, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(project1Response.AuditLogs))
|
||||
|
||||
messages := extractMessages(project1Response.AuditLogs)
|
||||
assert.Contains(t, messages, "Test project1 log first")
|
||||
assert.Contains(t, messages, "Test project1 log second")
|
||||
for _, log := range project1Response.AuditLogs {
|
||||
assert.Equal(t, &project1ID, log.ProjectID)
|
||||
}
|
||||
|
||||
// Test project 2 logs
|
||||
project2Response, err := service.GetProjectAuditLogs(project2ID, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(project2Response.AuditLogs))
|
||||
|
||||
messages2 := extractMessages(project2Response.AuditLogs)
|
||||
assert.Contains(t, messages2, "Test project2 log first")
|
||||
assert.Contains(t, messages2, "Test project2 log second")
|
||||
|
||||
// Test pagination
|
||||
limitedResponse, err := service.GetProjectAuditLogs(project1ID,
|
||||
&GetAuditLogsRequest{Limit: 1, Offset: 0})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(limitedResponse.AuditLogs))
|
||||
assert.Equal(t, 1, limitedResponse.Limit)
|
||||
|
||||
// Test beforeDate filter
|
||||
beforeTime := time.Now().UTC().Add(-1 * time.Minute)
|
||||
filteredResponse, err := service.GetProjectAuditLogs(project1ID,
|
||||
&GetAuditLogsRequest{Limit: 10, BeforeDate: &beforeTime})
|
||||
assert.NoError(t, err)
|
||||
for _, log := range filteredResponse.AuditLogs {
|
||||
assert.True(t, log.CreatedAt.Before(beforeTime))
|
||||
}
|
||||
}
|
||||
|
||||
func createAuditLog(service *AuditLogService, message string, userID, projectID *uuid.UUID) {
|
||||
service.WriteAuditLog(message, userID, projectID)
|
||||
}
|
||||
|
||||
func extractMessages(logs []*AuditLog) []string {
|
||||
messages := make([]string, len(logs))
|
||||
for i, log := range logs {
|
||||
messages[i] = log.Message
|
||||
}
|
||||
return messages
|
||||
}
|
||||
|
||||
func createTimedLog(db *gorm.DB, userID *uuid.UUID, message string, createdAt time.Time) {
|
||||
log := &AuditLog{
|
||||
ID: uuid.New(),
|
||||
UserID: userID,
|
||||
Message: message,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
db.Create(log)
|
||||
}
|
||||
|
||||
```
|
||||
@@ -1,74 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
For DI files use implicit fields declaration styles (espesially
|
||||
for controllers, services, repositories, use cases, etc., not simple
|
||||
data structures).
|
||||
|
||||
So, instead of:
|
||||
|
||||
var orderController = &OrderController{
|
||||
orderService: orderService,
|
||||
botUserService: bot_users.GetBotUserService(),
|
||||
botService: bots.GetBotService(),
|
||||
userService: users.GetUserService(),
|
||||
}
|
||||
|
||||
Use:
|
||||
|
||||
var orderController = &OrderController{
|
||||
orderService,
|
||||
bot_users.GetBotUserService(),
|
||||
bots.GetBotService(),
|
||||
users.GetUserService(),
|
||||
}
|
||||
|
||||
This is needed to avoid forgetting to update DI style
|
||||
when we add new dependency.
|
||||
|
||||
---
|
||||
|
||||
Please force such usage if file look like this (see some
|
||||
services\controllers\repos definitions and getters):
|
||||
|
||||
var orderBackgroundService = &OrderBackgroundService{
|
||||
orderService: orderService,
|
||||
orderPaymentRepository: orderPaymentRepository,
|
||||
botService: bots.GetBotService(),
|
||||
paymentSettingsService: payment_settings.GetPaymentSettingsService(),
|
||||
|
||||
orderSubscriptionListeners: []OrderSubscriptionListener{},
|
||||
}
|
||||
|
||||
var orderController = &OrderController{
|
||||
orderService: orderService,
|
||||
botUserService: bot_users.GetBotUserService(),
|
||||
botService: bots.GetBotService(),
|
||||
userService: users.GetUserService(),
|
||||
}
|
||||
|
||||
func GetUniquePaymentRepository() *repositories.UniquePaymentRepository {
|
||||
return uniquePaymentRepository
|
||||
}
|
||||
|
||||
func GetOrderPaymentRepository() *repositories.OrderPaymentRepository {
|
||||
return orderPaymentRepository
|
||||
}
|
||||
|
||||
func GetOrderService() *OrderService {
|
||||
return orderService
|
||||
}
|
||||
|
||||
func GetOrderController() *OrderController {
|
||||
return orderController
|
||||
}
|
||||
|
||||
func GetOrderBackgroundService() *OrderBackgroundService {
|
||||
return orderBackgroundService
|
||||
}
|
||||
|
||||
func GetOrderRepository() *repositories.OrderRepository {
|
||||
return orderRepository
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
When writting migrations:
|
||||
|
||||
- write them for PostgreSQL
|
||||
- for PRIMARY UUID keys use gen_random_uuid()
|
||||
- for time use TIMESTAMPTZ (timestamp with zone)
|
||||
- split table, constraint and indexes declaration (table first, them other one by one)
|
||||
- format SQL in pretty way (add spaces, align columns types), constraints split by lines. The example:
|
||||
|
||||
CREATE TABLE marketplace_info (
|
||||
bot_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
title TEXT NOT NULL,
|
||||
description TEXT NOT NULL,
|
||||
short_description TEXT NOT NULL,
|
||||
tutorial_url TEXT,
|
||||
info_order BIGINT NOT NULL DEFAULT 0,
|
||||
is_published BOOLEAN NOT NULL DEFAULT FALSE
|
||||
);
|
||||
|
||||
ALTER TABLE marketplace_info_images
|
||||
ADD CONSTRAINT fk_marketplace_info_images_bot_id
|
||||
FOREIGN KEY (bot_id)
|
||||
REFERENCES marketplace_info (bot_id);
|
||||
@@ -1,12 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
When applying changes, do not forget to refactor old code.
|
||||
|
||||
You can shortify, make more readable, improve code quality, etc.
|
||||
Common logic can be extracted to functions, constants, files, etc.
|
||||
|
||||
After each large change with more than ~50-100 lines of code - always run `make lint` (from backend root folder) and, if you change frontend, run `npm run format` (from frontend root folder).
|
||||
@@ -1,147 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
After writing tests, always launch them and verify that they pass.
|
||||
|
||||
## Test Naming Format
|
||||
|
||||
Use these naming patterns:
|
||||
|
||||
- `Test_WhatWeDo_WhatWeExpect`
|
||||
- `Test_WhatWeDo_WhichConditions_WhatWeExpect`
|
||||
|
||||
## Examples from Real Codebase:
|
||||
|
||||
- `Test_CreateApiKey_WhenUserIsProjectOwner_ApiKeyCreated`
|
||||
- `Test_UpdateProject_WhenUserIsProjectAdmin_ProjectUpdated`
|
||||
- `Test_DeleteApiKey_WhenUserIsProjectMember_ReturnsForbidden`
|
||||
- `Test_GetProjectAuditLogs_WithDifferentUserRoles_EnforcesPermissionsCorrectly`
|
||||
- `Test_ProjectLifecycleE2E_CompletesSuccessfully`
|
||||
|
||||
## Testing Philosophy
|
||||
|
||||
**Prefer Controllers Over Unit Tests:**
|
||||
|
||||
- Test through HTTP endpoints via controllers whenever possible
|
||||
- Avoid testing repositories, services in isolation - test via API instead
|
||||
- Only use unit tests for complex model logic when no API exists
|
||||
- Name test files `controller_test.go` or `service_test.go`, not `integration_test.go`
|
||||
|
||||
**Extract Common Logic to Testing Utilities:**
|
||||
|
||||
- Create `testing.go` or `testing/testing.go` files for shared test utilities
|
||||
- Extract router creation, user setup, models creation helpers (in API, not just structs creation)
|
||||
- Reuse common patterns across different test files
|
||||
|
||||
**Refactor Existing Tests:**
|
||||
|
||||
- When working with existing tests, always look for opportunities to refactor and improve
|
||||
- Extract repetitive setup code to common utilities
|
||||
- Simplify complex tests by breaking them into smaller, focused tests
|
||||
- Replace inline test data creation with reusable helper functions
|
||||
- Consolidate similar test patterns across different test files
|
||||
- Make tests more readable and maintainable for other developers
|
||||
|
||||
## Testing Utilities Structure
|
||||
|
||||
**Create `testing.go` or `testing/testing.go` files with common utilities:**
|
||||
|
||||
```go
|
||||
package projects_testing
|
||||
|
||||
// CreateTestRouter creates unified router for all controllers
|
||||
func CreateTestRouter(controllers ...ControllerInterface) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
v1 := router.Group("/api/v1")
|
||||
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
|
||||
|
||||
for _, controller := range controllers {
|
||||
if routerGroup, ok := protected.(*gin.RouterGroup); ok {
|
||||
controller.RegisterRoutes(routerGroup)
|
||||
}
|
||||
}
|
||||
return router
|
||||
}
|
||||
|
||||
// CreateTestProjectViaAPI creates project through HTTP API
|
||||
func CreateTestProjectViaAPI(name string, owner *users_dto.SignInResponseDTO, router *gin.Engine) (*projects_models.Project, string) {
|
||||
request := projects_dto.CreateProjectRequestDTO{Name: name}
|
||||
w := MakeAPIRequest(router, "POST", "/api/v1/projects", "Bearer "+owner.Token, request)
|
||||
// Handle response...
|
||||
return project, owner.Token
|
||||
}
|
||||
|
||||
// AddMemberToProject adds member via API call
|
||||
func AddMemberToProject(project *projects_models.Project, member *users_dto.SignInResponseDTO, role users_enums.ProjectRole, ownerToken string, router *gin.Engine) {
|
||||
// Implementation...
|
||||
}
|
||||
```
|
||||
|
||||
## Controller Test Examples
|
||||
|
||||
**Permission-based testing:**
|
||||
|
||||
```go
|
||||
func Test_CreateApiKey_WhenUserIsProjectOwner_ApiKeyCreated(t *testing.T) {
|
||||
router := CreateApiKeyTestRouter(GetProjectController(), GetMembershipController())
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
project, _ := projects_testing.CreateTestProjectViaAPI("Test Project", owner, router)
|
||||
|
||||
request := CreateApiKeyRequestDTO{Name: "Test API Key"}
|
||||
var response ApiKey
|
||||
test_utils.MakePostRequestAndUnmarshal(t, router, "/api/v1/projects/api-keys/"+project.ID.String(), "Bearer "+owner.Token, request, http.StatusOK, &response)
|
||||
|
||||
assert.Equal(t, "Test API Key", response.Name)
|
||||
assert.NotEmpty(t, response.Token)
|
||||
}
|
||||
```
|
||||
|
||||
**Cross-project security testing:**
|
||||
|
||||
```go
|
||||
func Test_UpdateApiKey_WithApiKeyFromDifferentProject_ReturnsBadRequest(t *testing.T) {
|
||||
router := CreateApiKeyTestRouter(GetProjectController(), GetMembershipController())
|
||||
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
project1, _ := projects_testing.CreateTestProjectViaAPI("Project 1", owner1, router)
|
||||
project2, _ := projects_testing.CreateTestProjectViaAPI("Project 2", owner2, router)
|
||||
|
||||
apiKey := CreateTestApiKey("Cross Project Key", project1.ID, owner1.Token, router)
|
||||
|
||||
// Try to update via different project endpoint
|
||||
request := UpdateApiKeyRequestDTO{Name: &"Hacked Key"}
|
||||
resp := test_utils.MakePutRequest(t, router, "/api/v1/projects/api-keys/"+project2.ID.String()+"/"+apiKey.ID.String(), "Bearer "+owner2.Token, request, http.StatusBadRequest)
|
||||
|
||||
assert.Contains(t, string(resp.Body), "API key does not belong to this project")
|
||||
}
|
||||
```
|
||||
|
||||
**E2E lifecycle testing:**
|
||||
|
||||
```go
|
||||
func Test_ProjectLifecycleE2E_CompletesSuccessfully(t *testing.T) {
|
||||
router := projects_testing.CreateTestRouter(GetProjectController(), GetMembershipController())
|
||||
|
||||
// 1. Create project
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
project := projects_testing.CreateTestProject("E2E Project", owner, router)
|
||||
|
||||
// 2. Add member
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
projects_testing.AddMemberToProject(project, member, users_enums.ProjectRoleMember, owner.Token, router)
|
||||
|
||||
// 3. Promote to admin
|
||||
projects_testing.ChangeMemberRole(project, member.UserID, users_enums.ProjectRoleAdmin, owner.Token, router)
|
||||
|
||||
// 4. Transfer ownership
|
||||
projects_testing.TransferProjectOwnership(project, member.UserID, owner.Token, router)
|
||||
|
||||
// 5. Verify new owner can manage project
|
||||
finalProject := projects_testing.GetProject(project.ID, member.Token, router)
|
||||
assert.Equal(t, project.ID, finalProject.ID)
|
||||
}
|
||||
```
|
||||
@@ -1,6 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
Always use time.Now().UTC() instead of time.Now()
|
||||
@@ -2,8 +2,18 @@
|
||||
DEV_DB_NAME=databasus
|
||||
DEV_DB_USERNAME=postgres
|
||||
DEV_DB_PASSWORD=Q1234567
|
||||
#app
|
||||
# app
|
||||
ENV_MODE=development
|
||||
# logging
|
||||
SHOW_DB_INSTALLATION_VERIFICATION_LOGS=true
|
||||
VICTORIA_LOGS_URL=http://localhost:9428
|
||||
VICTORIA_LOGS_PASSWORD=devpassword
|
||||
# tests
|
||||
TEST_LOCALHOST=localhost
|
||||
IS_SKIP_EXTERNAL_RESOURCES_TESTS=false
|
||||
# cloudflare turnstile
|
||||
CLOUDFLARE_TURNSTILE_SITE_KEY=
|
||||
CLOUDFLARE_TURNSTILE_SECRET_KEY=
|
||||
# db
|
||||
DATABASE_DSN=host=dev-db user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
|
||||
DATABASE_URL=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable
|
||||
|
||||
3
backend/.gitignore
vendored
3
backend/.gitignore
vendored
@@ -18,4 +18,5 @@ pgdata-for-restore/
|
||||
temp/
|
||||
cmd.exe
|
||||
temp/
|
||||
valkey-data/
|
||||
valkey-data/
|
||||
victoria-logs-data/
|
||||
@@ -2,10 +2,10 @@ run:
|
||||
go run cmd/main.go
|
||||
|
||||
test:
|
||||
go test -p=1 -count=1 -failfast -timeout 10m ./internal/...
|
||||
go test -p=1 -count=1 -failfast -timeout 15m ./internal/...
|
||||
|
||||
lint:
|
||||
golangci-lint fmt && golangci-lint run
|
||||
golangci-lint fmt ./cmd/... ./internal/... && golangci-lint run ./cmd/... ./internal/...
|
||||
|
||||
migration-create:
|
||||
goose create $(name) sql
|
||||
|
||||
@@ -15,6 +15,8 @@ import (
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/disk"
|
||||
@@ -23,8 +25,10 @@ import (
|
||||
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/restores"
|
||||
"databasus-backend/internal/features/restores/restoring"
|
||||
"databasus-backend/internal/features/storages"
|
||||
system_healthcheck "databasus-backend/internal/features/system/healthcheck"
|
||||
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
users_controllers "databasus-backend/internal/features/users/controllers"
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
@@ -54,16 +58,25 @@ func main() {
|
||||
log := logger.GetLogger()
|
||||
|
||||
cache_utils.TestCacheConnection()
|
||||
err := cache_utils.ClearAllCache()
|
||||
if err != nil {
|
||||
log.Error("Failed to clear cache", "error", err)
|
||||
os.Exit(1)
|
||||
|
||||
if config.GetEnv().IsPrimaryNode {
|
||||
log.Info("Clearing cache...")
|
||||
|
||||
err := cache_utils.ClearAllCache()
|
||||
if err != nil {
|
||||
log.Error("Failed to clear cache", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
runMigrations(log)
|
||||
if config.GetEnv().IsPrimaryNode {
|
||||
runMigrations(log)
|
||||
} else {
|
||||
log.Info("Skipping migrations (IS_PRIMARY_NODE is false)")
|
||||
}
|
||||
|
||||
// create directories that used for backups and restore
|
||||
err = files_utils.EnsureDirectories([]string{
|
||||
err := files_utils.EnsureDirectories([]string{
|
||||
config.GetEnv().TempFolder,
|
||||
config.GetEnv().DataFolder,
|
||||
})
|
||||
@@ -104,7 +117,9 @@ func main() {
|
||||
enableCors(ginApp)
|
||||
setUpRoutes(ginApp)
|
||||
setUpDependencies()
|
||||
|
||||
runBackgroundTasks(log)
|
||||
|
||||
mountFrontend(ginApp)
|
||||
|
||||
startServerWithGracefulShutdown(log, ginApp)
|
||||
@@ -170,6 +185,9 @@ func startServerWithGracefulShutdown(log *slog.Logger, app *gin.Engine) {
|
||||
<-quit
|
||||
log.Info("Shutdown signal received")
|
||||
|
||||
// Gracefully shutdown VictoriaLogs writer
|
||||
logger.ShutdownVictoriaLogs(5 * time.Second)
|
||||
|
||||
// The context is used to inform the server it has 10 seconds to finish
|
||||
// the request it is currently handling
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -227,35 +245,80 @@ func setUpDependencies() {
|
||||
notifiers.SetupDependencies()
|
||||
storages.SetupDependencies()
|
||||
backups_config.SetupDependencies()
|
||||
task_cancellation.SetupDependencies()
|
||||
}
|
||||
|
||||
func runBackgroundTasks(log *slog.Logger) {
|
||||
log.Info("Preparing to run background tasks...")
|
||||
|
||||
// Create context that will be cancelled on shutdown
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Set up signal handling for graceful shutdown
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-quit
|
||||
log.Info("Shutdown signal received, cancelling all background tasks")
|
||||
cancel()
|
||||
}()
|
||||
|
||||
err := files_utils.CleanFolder(config.GetEnv().TempFolder)
|
||||
if err != nil {
|
||||
log.Error("Failed to clean temp folder", "error", err)
|
||||
}
|
||||
|
||||
go runWithPanicLogging(log, "backup background service", func() {
|
||||
backups.GetBackupBackgroundService().Run()
|
||||
})
|
||||
if config.GetEnv().IsPrimaryNode {
|
||||
log.Info("Starting primary node background tasks...")
|
||||
|
||||
go runWithPanicLogging(log, "restore background service", func() {
|
||||
restores.GetRestoreBackgroundService().Run()
|
||||
})
|
||||
go runWithPanicLogging(log, "backup background service", func() {
|
||||
backuping.GetBackupsScheduler().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "healthcheck attempt background service", func() {
|
||||
healthcheck_attempt.GetHealthcheckAttemptBackgroundService().Run()
|
||||
})
|
||||
go runWithPanicLogging(log, "backup cleaner background service", func() {
|
||||
backuping.GetBackupCleaner().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "audit log cleanup background service", func() {
|
||||
audit_logs.GetAuditLogBackgroundService().Run()
|
||||
})
|
||||
go runWithPanicLogging(log, "restore background service", func() {
|
||||
restoring.GetRestoresScheduler().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "download token cleanup background service", func() {
|
||||
backups.GetDownloadTokenBackgroundService().Run()
|
||||
})
|
||||
go runWithPanicLogging(log, "healthcheck attempt background service", func() {
|
||||
healthcheck_attempt.GetHealthcheckAttemptBackgroundService().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "audit log cleanup background service", func() {
|
||||
audit_logs.GetAuditLogBackgroundService().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "download token cleanup background service", func() {
|
||||
backups_download.GetDownloadTokenBackgroundService().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "backup nodes registry background service", func() {
|
||||
backuping.GetBackupNodesRegistry().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "restore nodes registry background service", func() {
|
||||
restoring.GetRestoreNodesRegistry().Run(ctx)
|
||||
})
|
||||
} else {
|
||||
log.Info("Skipping primary node tasks as not primary node")
|
||||
}
|
||||
|
||||
if config.GetEnv().IsProcessingNode {
|
||||
log.Info("Starting backup node background tasks...")
|
||||
|
||||
go runWithPanicLogging(log, "backup node", func() {
|
||||
backuping.GetBackuperNode().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "restore node", func() {
|
||||
restoring.GetRestorerNode().Run(ctx)
|
||||
})
|
||||
} else {
|
||||
log.Info("Skipping backup/restore node tasks as not backup node")
|
||||
}
|
||||
}
|
||||
|
||||
func runWithPanicLogging(log *slog.Logger, serviceName string, fn func()) {
|
||||
|
||||
@@ -34,6 +34,20 @@ services:
|
||||
retries: 5
|
||||
start_period: 20s
|
||||
|
||||
# VictoriaLogs for external logging
|
||||
victoria-logs:
|
||||
image: victoriametrics/victoria-logs:latest
|
||||
container_name: victoria-logs
|
||||
ports:
|
||||
- "9428:9428"
|
||||
command:
|
||||
- -storageDataPath=/victoria-logs-data
|
||||
- -retentionPeriod=7d
|
||||
- -httpAuth.password=devpassword
|
||||
volumes:
|
||||
- ./victoria-logs-data:/victoria-logs-data
|
||||
restart: unless-stopped
|
||||
|
||||
# Test MinIO container
|
||||
test-minio:
|
||||
image: minio/minio:latest
|
||||
|
||||
@@ -28,7 +28,6 @@ require (
|
||||
github.com/valkey-io/valkey-go v1.0.70
|
||||
go.mongodb.org/mongo-driver v1.17.6
|
||||
golang.org/x/crypto v0.46.0
|
||||
golang.org/x/time v0.14.0
|
||||
gorm.io/driver/postgres v1.5.11
|
||||
gorm.io/gorm v1.26.1
|
||||
)
|
||||
@@ -186,6 +185,7 @@ require (
|
||||
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
golang.org/x/term v0.38.0 // indirect
|
||||
golang.org/x/time v0.14.0 // indirect
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
|
||||
gopkg.in/validator.v2 v2.0.1 // indirect
|
||||
moul.io/http2curl/v2 v2.3.0 // indirect
|
||||
|
||||
@@ -22,13 +22,32 @@ const (
|
||||
|
||||
type EnvVariables struct {
|
||||
IsTesting bool
|
||||
DatabaseDsn string `env:"DATABASE_DSN" required:"true"`
|
||||
EnvMode env_utils.EnvMode `env:"ENV_MODE" required:"true"`
|
||||
PostgresesInstallDir string `env:"POSTGRES_INSTALL_DIR"`
|
||||
MysqlInstallDir string `env:"MYSQL_INSTALL_DIR"`
|
||||
MariadbInstallDir string `env:"MARIADB_INSTALL_DIR"`
|
||||
MongodbInstallDir string `env:"MONGODB_INSTALL_DIR"`
|
||||
|
||||
// Internal database
|
||||
DatabaseDsn string `env:"DATABASE_DSN" required:"true"`
|
||||
// Internal Valkey
|
||||
ValkeyHost string `env:"VALKEY_HOST" required:"true"`
|
||||
ValkeyPort string `env:"VALKEY_PORT" required:"true"`
|
||||
ValkeyUsername string `env:"VALKEY_USERNAME" required:"true"`
|
||||
ValkeyPassword string `env:"VALKEY_PASSWORD" required:"true"`
|
||||
ValkeyIsSsl bool `env:"VALKEY_IS_SSL" required:"true"`
|
||||
|
||||
IsCloud bool `env:"IS_CLOUD"`
|
||||
TestLocalhost string `env:"TEST_LOCALHOST"`
|
||||
|
||||
ShowDbInstallationVerificationLogs bool `env:"SHOW_DB_INSTALLATION_VERIFICATION_LOGS"`
|
||||
IsSkipExternalResourcesTests bool `env:"IS_SKIP_EXTERNAL_RESOURCES_TESTS"`
|
||||
|
||||
IsManyNodesMode bool `env:"IS_MANY_NODES_MODE"`
|
||||
IsPrimaryNode bool `env:"IS_PRIMARY_NODE"`
|
||||
IsProcessingNode bool `env:"IS_PROCESSING_NODE"`
|
||||
NodeNetworkThroughputMBs int `env:"NODE_NETWORK_THROUGHPUT_MBPS"`
|
||||
|
||||
DataFolder string
|
||||
TempFolder string
|
||||
SecretKeyPath string
|
||||
@@ -79,19 +98,16 @@ type EnvVariables struct {
|
||||
TestMongodb70Port string `env:"TEST_MONGODB_70_PORT"`
|
||||
TestMongodb82Port string `env:"TEST_MONGODB_82_PORT"`
|
||||
|
||||
// Valkey
|
||||
ValkeyHost string `env:"VALKEY_HOST" required:"true"`
|
||||
ValkeyPort string `env:"VALKEY_PORT" required:"true"`
|
||||
ValkeyUsername string `env:"VALKEY_USERNAME"`
|
||||
ValkeyPassword string `env:"VALKEY_PASSWORD"`
|
||||
ValkeyIsSsl bool `env:"VALKEY_IS_SSL" required:"true"`
|
||||
|
||||
// oauth
|
||||
GitHubClientID string `env:"GITHUB_CLIENT_ID"`
|
||||
GitHubClientSecret string `env:"GITHUB_CLIENT_SECRET"`
|
||||
GoogleClientID string `env:"GOOGLE_CLIENT_ID"`
|
||||
GoogleClientSecret string `env:"GOOGLE_CLIENT_SECRET"`
|
||||
|
||||
// Cloudflare Turnstile
|
||||
CloudflareTurnstileSecretKey string `env:"CLOUDFLARE_TURNSTILE_SECRET_KEY"`
|
||||
CloudflareTurnstileSiteKey string `env:"CLOUDFLARE_TURNSTILE_SITE_KEY"`
|
||||
|
||||
// testing Telegram
|
||||
TestTelegramBotToken string `env:"TEST_TELEGRAM_BOT_TOKEN"`
|
||||
TestTelegramChatID string `env:"TEST_TELEGRAM_CHAT_ID"`
|
||||
@@ -102,6 +118,15 @@ type EnvVariables struct {
|
||||
TestSupabaseUsername string `env:"TEST_SUPABASE_USERNAME"`
|
||||
TestSupabasePassword string `env:"TEST_SUPABASE_PASSWORD"`
|
||||
TestSupabaseDatabase string `env:"TEST_SUPABASE_DATABASE"`
|
||||
|
||||
// SMTP configuration (optional)
|
||||
SMTPHost string `env:"SMTP_HOST"`
|
||||
SMTPPort int `env:"SMTP_PORT"`
|
||||
SMTPUser string `env:"SMTP_USER"`
|
||||
SMTPPassword string `env:"SMTP_PASSWORD"`
|
||||
|
||||
// Application URL (optional) - used for email links
|
||||
DatabasusURL string `env:"DATABASUS_URL"`
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -162,6 +187,21 @@ func loadEnvVariables() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Set default value for ShowDbInstallationVerificationLogs if not defined
|
||||
if os.Getenv("SHOW_DB_INSTALLATION_VERIFICATION_LOGS") == "" {
|
||||
env.ShowDbInstallationVerificationLogs = true
|
||||
}
|
||||
|
||||
// Set default value for IsSkipExternalTests if not defined
|
||||
if os.Getenv("IS_SKIP_EXTERNAL_RESOURCES_TESTS") == "" {
|
||||
env.IsSkipExternalResourcesTests = false
|
||||
}
|
||||
|
||||
// Set default value for IsCloud if not defined
|
||||
if os.Getenv("IS_CLOUD") == "" {
|
||||
env.IsCloud = false
|
||||
}
|
||||
|
||||
for _, arg := range os.Args {
|
||||
if strings.Contains(arg, "test") {
|
||||
env.IsTesting = true
|
||||
@@ -169,6 +209,14 @@ func loadEnvVariables() {
|
||||
}
|
||||
}
|
||||
|
||||
// Check for external database override
|
||||
if externalDsn := os.Getenv("DANGEROUS_EXTERNAL_DATABASE_DSN"); externalDsn != "" {
|
||||
log.Warn(
|
||||
"Using DANGEROUS_EXTERNAL_DATABASE_DSN - connecting to external database instead of internal PostgreSQL",
|
||||
)
|
||||
env.DatabaseDsn = externalDsn
|
||||
}
|
||||
|
||||
if env.DatabaseDsn == "" {
|
||||
log.Error("DATABASE_DSN is empty")
|
||||
os.Exit(1)
|
||||
@@ -185,16 +233,49 @@ func loadEnvVariables() {
|
||||
log.Info("ENV_MODE loaded", "mode", env.EnvMode)
|
||||
|
||||
env.PostgresesInstallDir = filepath.Join(backendRoot, "tools", "postgresql")
|
||||
tools.VerifyPostgresesInstallation(log, env.EnvMode, env.PostgresesInstallDir)
|
||||
tools.VerifyPostgresesInstallation(
|
||||
log,
|
||||
env.EnvMode,
|
||||
env.PostgresesInstallDir,
|
||||
env.ShowDbInstallationVerificationLogs,
|
||||
)
|
||||
|
||||
env.MysqlInstallDir = filepath.Join(backendRoot, "tools", "mysql")
|
||||
tools.VerifyMysqlInstallation(log, env.EnvMode, env.MysqlInstallDir)
|
||||
tools.VerifyMysqlInstallation(
|
||||
log,
|
||||
env.EnvMode,
|
||||
env.MysqlInstallDir,
|
||||
env.ShowDbInstallationVerificationLogs,
|
||||
)
|
||||
|
||||
env.MariadbInstallDir = filepath.Join(backendRoot, "tools", "mariadb")
|
||||
tools.VerifyMariadbInstallation(log, env.EnvMode, env.MariadbInstallDir)
|
||||
tools.VerifyMariadbInstallation(
|
||||
log,
|
||||
env.EnvMode,
|
||||
env.MariadbInstallDir,
|
||||
env.ShowDbInstallationVerificationLogs,
|
||||
)
|
||||
|
||||
env.MongodbInstallDir = filepath.Join(backendRoot, "tools", "mongodb")
|
||||
tools.VerifyMongodbInstallation(log, env.EnvMode, env.MongodbInstallDir)
|
||||
tools.VerifyMongodbInstallation(
|
||||
log,
|
||||
env.EnvMode,
|
||||
env.MongodbInstallDir,
|
||||
env.ShowDbInstallationVerificationLogs,
|
||||
)
|
||||
|
||||
if env.NodeNetworkThroughputMBs == 0 {
|
||||
env.NodeNetworkThroughputMBs = 125 // 1 Gbit/s
|
||||
}
|
||||
|
||||
if !env.IsManyNodesMode {
|
||||
env.IsPrimaryNode = true
|
||||
env.IsProcessingNode = true
|
||||
}
|
||||
|
||||
if env.TestLocalhost == "" {
|
||||
env.TestLocalhost = "localhost"
|
||||
}
|
||||
|
||||
// Valkey
|
||||
if env.ValkeyHost == "" {
|
||||
@@ -206,6 +287,27 @@ func loadEnvVariables() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Check for external Valkey override
|
||||
if externalValkeyHost := os.Getenv("DANGEROUS_VALKEY_HOST"); externalValkeyHost != "" {
|
||||
log.Warn(
|
||||
"Using DANGEROUS_VALKEY_* variables - connecting to external Valkey instead of internal instance",
|
||||
)
|
||||
env.ValkeyHost = externalValkeyHost
|
||||
|
||||
if externalValkeyPort := os.Getenv("DANGEROUS_VALKEY_PORT"); externalValkeyPort != "" {
|
||||
env.ValkeyPort = externalValkeyPort
|
||||
}
|
||||
if externalValkeyUsername := os.Getenv("DANGEROUS_VALKEY_USERNAME"); externalValkeyUsername != "" {
|
||||
env.ValkeyUsername = externalValkeyUsername
|
||||
}
|
||||
if externalValkeyPassword := os.Getenv("DANGEROUS_VALKEY_PASSWORD"); externalValkeyPassword != "" {
|
||||
env.ValkeyPassword = externalValkeyPassword
|
||||
}
|
||||
if externalValkeyIsSsl := os.Getenv("DANGEROUS_VALKEY_IS_SSL"); externalValkeyIsSsl != "" {
|
||||
env.ValkeyIsSsl = externalValkeyIsSsl == "true"
|
||||
}
|
||||
}
|
||||
|
||||
// Store the data and temp folders one level below the root
|
||||
// (projectRoot/databasus-data -> /databasus-data)
|
||||
env.DataFolder = filepath.Join(filepath.Dir(backendRoot), "databasus-data", "backups")
|
||||
|
||||
@@ -1,33 +1,51 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type AuditLogBackgroundService struct {
|
||||
auditLogService *AuditLogService
|
||||
logger *slog.Logger
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (s *AuditLogBackgroundService) Run() {
|
||||
s.logger.Info("Starting audit log cleanup background service")
|
||||
func (s *AuditLogBackgroundService) Run(ctx context.Context) {
|
||||
wasAlreadyRun := s.hasRun.Load()
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
return
|
||||
}
|
||||
s.runOnce.Do(func() {
|
||||
s.hasRun.Store(true)
|
||||
|
||||
for {
|
||||
if config.IsShouldShutdown() {
|
||||
s.logger.Info("Starting audit log cleanup background service")
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.cleanOldAuditLogs(); err != nil {
|
||||
s.logger.Error("Failed to clean old audit logs", "error", err)
|
||||
}
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
time.Sleep(1 * time.Hour)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.cleanOldAuditLogs(); err != nil {
|
||||
s.logger.Error("Failed to clean old audit logs", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
@@ -14,8 +17,10 @@ var auditLogController = &AuditLogController{
|
||||
auditLogService,
|
||||
}
|
||||
var auditLogBackgroundService = &AuditLogBackgroundService{
|
||||
auditLogService,
|
||||
logger.GetLogger(),
|
||||
auditLogService: auditLogService,
|
||||
logger: logger.GetLogger(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
|
||||
func GetAuditLogService() *AuditLogService {
|
||||
@@ -30,8 +35,23 @@ func GetAuditLogBackgroundService() *AuditLogBackgroundService {
|
||||
return auditLogBackgroundService
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
users_services.GetUserService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
users_services.GetUserService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,254 +0,0 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/storages"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/period"
|
||||
"log/slog"
|
||||
"time"
|
||||
)
|
||||
|
||||
type BackupBackgroundService struct {
|
||||
backupService *BackupService
|
||||
backupRepository *BackupRepository
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
storageService *storages.StorageService
|
||||
|
||||
lastBackupTime time.Time
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) Run() {
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
|
||||
if err := s.failBackupsInProgress(); err != nil {
|
||||
s.logger.Error("Failed to fail backups in progress", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
if config.IsShouldShutdown() {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.cleanOldBackups(); err != nil {
|
||||
s.logger.Error("Failed to clean old backups", "error", err)
|
||||
}
|
||||
|
||||
if err := s.runPendingBackups(); err != nil {
|
||||
s.logger.Error("Failed to run pending backups", "error", err)
|
||||
}
|
||||
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
time.Sleep(1 * time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) IsBackupsWorkerRunning() bool {
|
||||
// if last backup time is more than 5 minutes ago, return false
|
||||
return s.lastBackupTime.After(time.Now().UTC().Add(-5 * time.Minute))
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) failBackupsInProgress() error {
|
||||
backupsInProgress, err := s.backupRepository.FindByStatus(BackupStatusInProgress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backup := range backupsInProgress {
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(backup.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
failMessage := "Backup failed due to application restart"
|
||||
backup.FailMessage = &failMessage
|
||||
backup.Status = BackupStatusFailed
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
s.backupService.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupFailed,
|
||||
&failMessage,
|
||||
)
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) cleanOldBackups() error {
|
||||
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
backupStorePeriod := backupConfig.StorePeriod
|
||||
|
||||
if backupStorePeriod == period.PeriodForever {
|
||||
continue
|
||||
}
|
||||
|
||||
storeDuration := backupStorePeriod.ToDuration()
|
||||
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
|
||||
|
||||
oldBackups, err := s.backupRepository.FindBackupsBeforeDate(
|
||||
backupConfig.DatabaseID,
|
||||
dateBeforeBackupsShouldBeDeleted,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to find old backups for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, backup := range oldBackups {
|
||||
storage, err := s.storageService.GetStorageByID(backup.StorageID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to get storage by ID",
|
||||
"storageId",
|
||||
backup.StorageID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
err = storage.DeleteFile(encryptor, backup.ID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
|
||||
}
|
||||
|
||||
if err := s.backupRepository.DeleteByID(backup.ID); err != nil {
|
||||
s.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
s.logger.Info(
|
||||
"Deleted old backup",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) runPendingBackups() error {
|
||||
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
if backupConfig.BackupInterval == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
lastBackup, err := s.backupRepository.FindLastByDatabaseID(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to get last backup for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
var lastBackupTime *time.Time
|
||||
if lastBackup != nil {
|
||||
lastBackupTime = &lastBackup.CreatedAt
|
||||
}
|
||||
|
||||
remainedBackupTryCount := s.GetRemainedBackupTryCount(lastBackup)
|
||||
|
||||
if backupConfig.BackupInterval.ShouldTriggerBackup(time.Now().UTC(), lastBackupTime) ||
|
||||
remainedBackupTryCount > 0 {
|
||||
s.logger.Info(
|
||||
"Triggering scheduled backup",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"intervalType",
|
||||
backupConfig.BackupInterval.Interval,
|
||||
)
|
||||
|
||||
go s.backupService.MakeBackup(backupConfig.DatabaseID, remainedBackupTryCount == 1)
|
||||
s.logger.Info(
|
||||
"Successfully triggered scheduled backup",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRemainedBackupTryCount returns the number of remaining backup tries for a given backup.
|
||||
// If the backup is not failed or the backup config does not allow retries, it returns 0.
|
||||
// If the backup is failed and the backup config allows retries, it returns the number of remaining tries.
|
||||
// If the backup is failed and the backup config does not allow retries, it returns 0.
|
||||
func (s *BackupBackgroundService) GetRemainedBackupTryCount(lastBackup *Backup) int {
|
||||
if lastBackup == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
if lastBackup.Status != BackupStatusFailed {
|
||||
return 0
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(lastBackup.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
if !backupConfig.IsRetryIfFailed {
|
||||
return 0
|
||||
}
|
||||
|
||||
maxFailedTriesCount := backupConfig.MaxFailedTriesCount
|
||||
|
||||
lastBackups, err := s.backupRepository.FindByDatabaseIDWithLimit(
|
||||
lastBackup.DatabaseID,
|
||||
maxFailedTriesCount,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to find last backups by database ID", "error", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
lastFailedBackups := make([]*Backup, 0)
|
||||
|
||||
for _, backup := range lastBackups {
|
||||
if backup.Status == BackupStatusFailed {
|
||||
lastFailedBackups = append(lastFailedBackups, backup)
|
||||
}
|
||||
}
|
||||
|
||||
return maxFailedTriesCount - len(lastFailedBackups)
|
||||
}
|
||||
@@ -1,407 +0,0 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/util/period"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Enable backups for the database
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add old backup
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusCompleted,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
// Wait for backup to complete (runs in goroutine)
|
||||
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 2)
|
||||
|
||||
// Wait for any cleanup operations to complete before defer cleanup runs
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Enable backups for the database
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add recent backup (1 hour ago)
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusCompleted,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1) // Should still be 1 backup, no new backup created
|
||||
|
||||
// Wait for any cleanup operations to complete before defer cleanup runs
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T) {
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Enable backups for the database with retries disabled
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
backupConfig.IsRetryIfFailed = false
|
||||
backupConfig.MaxFailedTriesCount = 0
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add failed backup
|
||||
failMessage := "backup failed"
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1) // Should still be 1 backup, no retry attempted
|
||||
|
||||
// Wait for any cleanup operations to complete before defer cleanup runs
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Enable backups for the database with retries enabled
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
backupConfig.IsRetryIfFailed = true
|
||||
backupConfig.MaxFailedTriesCount = 3
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add failed backup
|
||||
failMessage := "backup failed"
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
// Wait for backup to complete (runs in goroutine)
|
||||
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 2) // Should have 2 backups, retry was attempted
|
||||
|
||||
// Wait for any cleanup operations to complete before defer cleanup runs
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *testing.T) {
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Enable backups for the database with retries enabled
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
backupConfig.IsRetryIfFailed = true
|
||||
backupConfig.MaxFailedTriesCount = 3
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
failMessage := "backup failed"
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
}
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 3) // Should have 3 backups, not more than max
|
||||
|
||||
// Wait for any cleanup operations to complete before defer cleanup runs
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_MakeBackgroundBackupWhenBakupsDisabled_BackupSkipped(t *testing.T) {
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = false
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add old backup that would trigger new backup if enabled
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusCompleted,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1)
|
||||
|
||||
// Wait for any cleanup operations to complete before defer cleanup runs
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
@@ -1,84 +0,0 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"context"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/logger"
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const backupCancelChannel = "backup:cancel"
|
||||
|
||||
type BackupContextManager struct {
|
||||
mu sync.RWMutex
|
||||
cancelFuncs map[uuid.UUID]context.CancelFunc
|
||||
pubsub *cache_utils.PubSubManager
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewBackupContextManager() *BackupContextManager {
|
||||
return &BackupContextManager{
|
||||
cancelFuncs: make(map[uuid.UUID]context.CancelFunc),
|
||||
pubsub: cache_utils.NewPubSubManager(),
|
||||
logger: logger.GetLogger(),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) StartSubscription() {
|
||||
ctx := context.Background()
|
||||
|
||||
handler := func(message string) {
|
||||
backupID, err := uuid.Parse(message)
|
||||
if err != nil {
|
||||
m.logger.Error("Invalid backup ID in cancel message", "message", message, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
cancelFunc, exists := m.cancelFuncs[backupID]
|
||||
if exists {
|
||||
cancelFunc()
|
||||
delete(m.cancelFuncs, backupID)
|
||||
m.logger.Info("Cancelled backup via Pub/Sub", "backupID", backupID)
|
||||
}
|
||||
}
|
||||
|
||||
err := m.pubsub.Subscribe(ctx, backupCancelChannel, handler)
|
||||
if err != nil {
|
||||
m.logger.Error("Failed to subscribe to backup cancel channel", "error", err)
|
||||
} else {
|
||||
m.logger.Info("Successfully subscribed to backup cancel channel")
|
||||
}
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) RegisterBackup(backupID uuid.UUID, cancelFunc context.CancelFunc) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.cancelFuncs[backupID] = cancelFunc
|
||||
m.logger.Debug("Registered backup", "backupID", backupID)
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) CancelBackup(backupID uuid.UUID) error {
|
||||
ctx := context.Background()
|
||||
|
||||
err := m.pubsub.Publish(ctx, backupCancelChannel, backupID.String())
|
||||
if err != nil {
|
||||
m.logger.Error("Failed to publish cancel message", "backupID", backupID, "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
m.logger.Info("Published backup cancel message", "backupID", backupID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) UnregisterBackup(backupID uuid.UUID) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.cancelFuncs, backupID)
|
||||
m.logger.Debug("Unregistered backup", "backupID", backupID)
|
||||
}
|
||||
427
backend/internal/features/backups/backups/backuping/backuper.go
Normal file
427
backend/internal/features/backups/backups/backuping/backuper.go
Normal file
@@ -0,0 +1,427 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/storages"
|
||||
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
)
|
||||
|
||||
const (
|
||||
heartbeatTickerInterval = 15 * time.Second
|
||||
backuperHeathcheckThreshold = 5 * time.Minute
|
||||
)
|
||||
|
||||
type BackuperNode struct {
|
||||
databaseService *databases.DatabaseService
|
||||
fieldEncryptor util_encryption.FieldEncryptor
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
backupRepository *backups_core.BackupRepository
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
storageService *storages.StorageService
|
||||
notificationSender backups_core.NotificationSender
|
||||
backupCancelManager *tasks_cancellation.TaskCancelManager
|
||||
backupNodesRegistry *BackupNodesRegistry
|
||||
logger *slog.Logger
|
||||
createBackupUseCase backups_core.CreateBackupUsecase
|
||||
nodeID uuid.UUID
|
||||
|
||||
lastHeartbeat time.Time
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (n *BackuperNode) Run(ctx context.Context) {
|
||||
wasAlreadyRun := n.hasRun.Load()
|
||||
|
||||
n.runOnce.Do(func() {
|
||||
n.hasRun.Store(true)
|
||||
|
||||
n.lastHeartbeat = time.Now().UTC()
|
||||
|
||||
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
|
||||
|
||||
backupNode := BackupNode{
|
||||
ID: n.nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: time.Now().UTC(),
|
||||
}
|
||||
|
||||
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
|
||||
n.logger.Error("Failed to register node in registry", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
|
||||
go func() {
|
||||
n.MakeBackup(backupID, isCallNotifier)
|
||||
if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil {
|
||||
n.logger.Error(
|
||||
"Failed to publish backup completion",
|
||||
"error",
|
||||
err,
|
||||
"backupID",
|
||||
backupID,
|
||||
)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
err := n.backupNodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID, backupHandler)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := n.backupNodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
|
||||
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ticker := time.NewTicker(heartbeatTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
|
||||
|
||||
if err := n.backupNodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
|
||||
n.logger.Error("Failed to unregister node from registry", "error", err)
|
||||
}
|
||||
|
||||
return
|
||||
case <-ticker.C:
|
||||
n.sendHeartbeat(&backupNode)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", n))
|
||||
}
|
||||
}
|
||||
|
||||
func (n *BackuperNode) IsBackuperRunning() bool {
|
||||
return n.lastHeartbeat.After(time.Now().UTC().Add(-backuperHeathcheckThreshold))
|
||||
}
|
||||
|
||||
func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
|
||||
backup, err := n.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get backup by ID", "backupId", backupID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
databaseID := backup.DatabaseID
|
||||
|
||||
database, err := n.databaseService.GetDatabaseByID(databaseID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get database by ID", "databaseId", databaseID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
backupConfig, err := n.backupConfigService.GetBackupConfigByDbId(databaseID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if backupConfig.StorageID == nil {
|
||||
n.logger.Error("Backup config storage ID is not defined")
|
||||
return
|
||||
}
|
||||
|
||||
storage, err := n.storageService.GetStorageByID(*backupConfig.StorageID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get storage by ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now().UTC()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
n.backupCancelManager.RegisterTask(backup.ID, cancel)
|
||||
defer n.backupCancelManager.UnregisterTask(backup.ID)
|
||||
|
||||
backupProgressListener := func(
|
||||
completedMBs float64,
|
||||
) {
|
||||
backup.BackupSizeMb = completedMBs
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
// Check size limit (0 = unlimited)
|
||||
if backupConfig.MaxBackupSizeMB > 0 &&
|
||||
completedMBs > float64(backupConfig.MaxBackupSizeMB) {
|
||||
errMsg := fmt.Sprintf(
|
||||
"backup size (%.2f MB) exceeded maximum allowed size (%d MB)",
|
||||
completedMBs,
|
||||
backupConfig.MaxBackupSizeMB,
|
||||
)
|
||||
|
||||
backup.Status = backups_core.BackupStatusFailed
|
||||
backup.IsSkipRetry = true
|
||||
backup.FailMessage = &errMsg
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save backup with size exceeded error", "error", err)
|
||||
}
|
||||
cancel() // Cancel the backup context
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to update backup progress", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
backupMetadata, err := n.createBackupUseCase.Execute(
|
||||
ctx,
|
||||
backup,
|
||||
backupConfig,
|
||||
database,
|
||||
storage,
|
||||
backupProgressListener,
|
||||
)
|
||||
if err != nil {
|
||||
// Check if backup was already marked as failed by progress listener (e.g., size limit exceeded)
|
||||
// If so, skip error handling to avoid overwriting the status
|
||||
currentBackup, fetchErr := n.backupRepository.FindByID(backup.ID)
|
||||
if fetchErr == nil && currentBackup.Status == backups_core.BackupStatusFailed {
|
||||
n.logger.Warn(
|
||||
"Backup already marked as failed by progress listener, skipping error handling",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"failMessage",
|
||||
*currentBackup.FailMessage,
|
||||
)
|
||||
|
||||
// Still call notification for size limit failures
|
||||
n.SendBackupNotification(
|
||||
backupConfig,
|
||||
currentBackup,
|
||||
backups_config.NotificationBackupFailed,
|
||||
currentBackup.FailMessage,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
errMsg := err.Error()
|
||||
|
||||
// Log detailed error information for debugging
|
||||
n.logger.Error("Backup execution failed",
|
||||
"backupId", backup.ID,
|
||||
"databaseId", databaseID,
|
||||
"databaseType", database.Type,
|
||||
"storageId", storage.ID,
|
||||
"storageType", storage.Type,
|
||||
"error", err,
|
||||
"errorMessage", errMsg,
|
||||
)
|
||||
|
||||
// Check if backup was cancelled (not due to shutdown)
|
||||
isCancelled := strings.Contains(errMsg, "backup cancelled") ||
|
||||
strings.Contains(errMsg, "context canceled") ||
|
||||
errors.Is(err, context.Canceled)
|
||||
isShutdown := strings.Contains(errMsg, "shutdown")
|
||||
|
||||
if isCancelled && !isShutdown {
|
||||
n.logger.Warn("Backup was cancelled by user or system",
|
||||
"backupId", backup.ID,
|
||||
"isCancelled", isCancelled,
|
||||
"isShutdown", isShutdown,
|
||||
)
|
||||
|
||||
backup.Status = backups_core.BackupStatusCanceled
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save cancelled backup", "error", err)
|
||||
}
|
||||
|
||||
// Delete partial backup from storage
|
||||
storage, storageErr := n.storageService.GetStorageByID(backup.StorageID)
|
||||
if storageErr == nil {
|
||||
if deleteErr := storage.DeleteFile(n.fieldEncryptor, backup.ID.String()); deleteErr != nil {
|
||||
n.logger.Error(
|
||||
"Failed to delete partial backup file",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
deleteErr,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
backup.FailMessage = &errMsg
|
||||
backup.Status = backups_core.BackupStatusFailed
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
if updateErr := n.databaseService.SetBackupError(databaseID, errMsg); updateErr != nil {
|
||||
n.logger.Error(
|
||||
"Failed to update database last backup time",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"error",
|
||||
updateErr,
|
||||
)
|
||||
}
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save backup", "error", err)
|
||||
}
|
||||
|
||||
n.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupFailed,
|
||||
&errMsg,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
backup.Status = backups_core.BackupStatusCompleted
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
// Update backup with encryption metadata if provided
|
||||
if backupMetadata != nil {
|
||||
backup.EncryptionSalt = backupMetadata.EncryptionSalt
|
||||
backup.EncryptionIV = backupMetadata.EncryptionIV
|
||||
backup.Encryption = backupMetadata.Encryption
|
||||
}
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save backup", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Update database last backup time
|
||||
now := time.Now().UTC()
|
||||
if updateErr := n.databaseService.SetLastBackupTime(databaseID, now); updateErr != nil {
|
||||
n.logger.Error(
|
||||
"Failed to update database last backup time",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"error",
|
||||
updateErr,
|
||||
)
|
||||
}
|
||||
|
||||
if backup.Status != backups_core.BackupStatusCompleted && !isCallNotifier {
|
||||
return
|
||||
}
|
||||
|
||||
n.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupSuccess,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
func (n *BackuperNode) SendBackupNotification(
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
backup *backups_core.Backup,
|
||||
notificationType backups_config.BackupNotificationType,
|
||||
errorMessage *string,
|
||||
) {
|
||||
database, err := n.databaseService.GetDatabaseByID(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
workspace, err := n.workspaceService.GetWorkspaceByID(*database.WorkspaceID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, notifier := range database.Notifiers {
|
||||
if !slices.Contains(
|
||||
backupConfig.SendNotificationsOn,
|
||||
notificationType,
|
||||
) {
|
||||
continue
|
||||
}
|
||||
|
||||
title := ""
|
||||
switch notificationType {
|
||||
case backups_config.NotificationBackupFailed:
|
||||
title = fmt.Sprintf(
|
||||
"❌ Backup failed for database \"%s\" (workspace \"%s\")",
|
||||
database.Name,
|
||||
workspace.Name,
|
||||
)
|
||||
case backups_config.NotificationBackupSuccess:
|
||||
title = fmt.Sprintf(
|
||||
"✅ Backup completed for database \"%s\" (workspace \"%s\")",
|
||||
database.Name,
|
||||
workspace.Name,
|
||||
)
|
||||
}
|
||||
|
||||
message := ""
|
||||
if errorMessage != nil {
|
||||
message = *errorMessage
|
||||
} else {
|
||||
// Format size conditionally
|
||||
var sizeStr string
|
||||
if backup.BackupSizeMb < 1024 {
|
||||
sizeStr = fmt.Sprintf("%.2f MB", backup.BackupSizeMb)
|
||||
} else {
|
||||
sizeGB := backup.BackupSizeMb / 1024
|
||||
sizeStr = fmt.Sprintf("%.2f GB", sizeGB)
|
||||
}
|
||||
|
||||
// Format duration as "0m 0s 0ms"
|
||||
totalMs := backup.BackupDurationMs
|
||||
minutes := totalMs / (1000 * 60)
|
||||
seconds := (totalMs % (1000 * 60)) / 1000
|
||||
durationStr := fmt.Sprintf("%dm %ds", minutes, seconds)
|
||||
|
||||
message = fmt.Sprintf(
|
||||
"Backup completed successfully in %s.\nCompressed backup size: %s",
|
||||
durationStr,
|
||||
sizeStr,
|
||||
)
|
||||
}
|
||||
|
||||
n.notificationSender.SendNotification(
|
||||
¬ifier,
|
||||
title,
|
||||
message,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (n *BackuperNode) sendHeartbeat(backupNode *BackupNode) {
|
||||
n.lastHeartbeat = time.Now().UTC()
|
||||
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *backupNode); err != nil {
|
||||
n.logger.Error("Failed to send heartbeat", "error", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,273 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
t.Run("BackupFailed_FailNotificationSent", func(t *testing.T) {
|
||||
mockNotificationSender := &MockNotificationSender{}
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.notificationSender = mockNotificationSender
|
||||
backuperNode.createBackupUseCase = &CreateFailedBackupUsecase{}
|
||||
|
||||
// Create a backup record directly that will be looked up by MakeBackup
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err := backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Set up expectations
|
||||
mockNotificationSender.On("SendNotification",
|
||||
mock.Anything,
|
||||
mock.MatchedBy(func(title string) bool {
|
||||
return strings.Contains(title, "❌ Backup failed")
|
||||
}),
|
||||
mock.MatchedBy(func(message string) bool {
|
||||
return strings.Contains(message, "backup failed")
|
||||
}),
|
||||
).Once()
|
||||
|
||||
backuperNode.MakeBackup(backup.ID, true)
|
||||
|
||||
// Verify all expectations were met
|
||||
mockNotificationSender.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("BackupSuccess_SuccessNotificationSent", func(t *testing.T) {
|
||||
mockNotificationSender := &MockNotificationSender{}
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.notificationSender = mockNotificationSender
|
||||
backuperNode.createBackupUseCase = &CreateSuccessBackupUsecase{}
|
||||
|
||||
// Create a backup record directly that will be looked up by MakeBackup
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err := backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Set up expectations
|
||||
mockNotificationSender.On("SendNotification",
|
||||
mock.Anything,
|
||||
mock.MatchedBy(func(title string) bool {
|
||||
return strings.Contains(title, "✅ Backup completed")
|
||||
}),
|
||||
mock.MatchedBy(func(message string) bool {
|
||||
return strings.Contains(message, "Backup completed successfully")
|
||||
}),
|
||||
).Once()
|
||||
|
||||
backuperNode.MakeBackup(backup.ID, true)
|
||||
|
||||
// Verify all expectations were met
|
||||
mockNotificationSender.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("BackupSuccess_VerifyNotificationContent", func(t *testing.T) {
|
||||
mockNotificationSender := &MockNotificationSender{}
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.notificationSender = mockNotificationSender
|
||||
backuperNode.createBackupUseCase = &CreateSuccessBackupUsecase{}
|
||||
|
||||
// Create a backup record directly that will be looked up by MakeBackup
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err := backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// capture arguments
|
||||
var capturedNotifier *notifiers.Notifier
|
||||
var capturedTitle string
|
||||
var capturedMessage string
|
||||
|
||||
mockNotificationSender.On("SendNotification",
|
||||
mock.Anything,
|
||||
mock.AnythingOfType("string"),
|
||||
mock.AnythingOfType("string"),
|
||||
).Run(func(args mock.Arguments) {
|
||||
capturedNotifier = args.Get(0).(*notifiers.Notifier)
|
||||
capturedTitle = args.Get(1).(string)
|
||||
capturedMessage = args.Get(2).(string)
|
||||
}).Once()
|
||||
|
||||
backuperNode.MakeBackup(backup.ID, true)
|
||||
|
||||
// Verify expectations were met
|
||||
mockNotificationSender.AssertExpectations(t)
|
||||
|
||||
// Additional detailed assertions
|
||||
assert.Contains(t, capturedTitle, "✅ Backup completed")
|
||||
assert.Contains(t, capturedTitle, database.Name)
|
||||
assert.Contains(t, capturedMessage, "Backup completed successfully")
|
||||
assert.Contains(t, capturedMessage, "10.00 MB")
|
||||
assert.Equal(t, notifier.ID, capturedNotifier.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_BackupSizeLimits(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
t.Run("UnlimitedSize_MaxBackupSizeMBIsZero_BackupCompletes", func(t *testing.T) {
|
||||
// Enable backups with unlimited size (0)
|
||||
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
backupConfig.MaxBackupSizeMB = 0 // unlimited
|
||||
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.createBackupUseCase = &CreateLargeBackupUsecase{}
|
||||
|
||||
// Create a backup record
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode.MakeBackup(backup.ID, false)
|
||||
|
||||
// Verify backup completed successfully even with large size
|
||||
updatedBackup, err := backupRepository.FindByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
|
||||
assert.Equal(t, float64(10000), updatedBackup.BackupSizeMb)
|
||||
assert.Nil(t, updatedBackup.FailMessage)
|
||||
})
|
||||
|
||||
t.Run("SizeExceeded_BackupFailedWithIsSkipRetry", func(t *testing.T) {
|
||||
// Enable backups with 5 MB limit
|
||||
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
backupConfig.MaxBackupSizeMB = 5
|
||||
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.createBackupUseCase = &CreateProgressiveBackupUsecase{}
|
||||
|
||||
// Create a backup record
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode.MakeBackup(backup.ID, false)
|
||||
|
||||
// Verify backup was marked as failed with IsSkipRetry=true
|
||||
updatedBackup, err := backupRepository.FindByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusFailed, updatedBackup.Status)
|
||||
assert.True(t, updatedBackup.IsSkipRetry)
|
||||
assert.NotNil(t, updatedBackup.FailMessage)
|
||||
assert.Contains(t, *updatedBackup.FailMessage, "exceeded maximum allowed size")
|
||||
assert.Contains(t, *updatedBackup.FailMessage, "10.00 MB")
|
||||
assert.Contains(t, *updatedBackup.FailMessage, "5 MB")
|
||||
assert.Greater(t, updatedBackup.BackupSizeMb, float64(5))
|
||||
})
|
||||
|
||||
t.Run("SizeWithinLimit_BackupCompletes", func(t *testing.T) {
|
||||
// Enable backups with 100 MB limit
|
||||
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
backupConfig.MaxBackupSizeMB = 100
|
||||
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.createBackupUseCase = &CreateMediumBackupUsecase{}
|
||||
|
||||
// Create a backup record
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode.MakeBackup(backup.ID, false)
|
||||
|
||||
// Verify backup completed successfully
|
||||
updatedBackup, err := backupRepository.FindByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
|
||||
assert.Equal(t, float64(50), updatedBackup.BackupSizeMb)
|
||||
assert.Nil(t, updatedBackup.FailMessage)
|
||||
})
|
||||
}
|
||||
242
backend/internal/features/backups/backups/backuping/cleaner.go
Normal file
242
backend/internal/features/backups/backups/backuping/cleaner.go
Normal file
@@ -0,0 +1,242 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/storages"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/period"
|
||||
)
|
||||
|
||||
const (
|
||||
cleanerTickerInterval = 1 * time.Minute
|
||||
)
|
||||
|
||||
type BackupCleaner struct {
|
||||
backupRepository *backups_core.BackupRepository
|
||||
storageService *storages.StorageService
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
fieldEncryptor util_encryption.FieldEncryptor
|
||||
logger *slog.Logger
|
||||
backupRemoveListeners []backups_core.BackupRemoveListener
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) Run(ctx context.Context) {
|
||||
wasAlreadyRun := c.hasRun.Load()
|
||||
|
||||
c.runOnce.Do(func() {
|
||||
c.hasRun.Store(true)
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(cleanerTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := c.cleanOldBackups(); err != nil {
|
||||
c.logger.Error("Failed to clean old backups", "error", err)
|
||||
}
|
||||
|
||||
if err := c.cleanExceededBackups(); err != nil {
|
||||
c.logger.Error("Failed to clean exceeded backups", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", c))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) DeleteBackup(backup *backups_core.Backup) error {
|
||||
for _, listener := range c.backupRemoveListeners {
|
||||
if err := listener.OnBeforeBackupRemove(backup); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
storage, err := c.storageService.GetStorageByID(backup.StorageID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = storage.DeleteFile(c.fieldEncryptor, backup.ID.String())
|
||||
if err != nil {
|
||||
// we do not return error here, because sometimes clean up performed
|
||||
// before unavailable storage removal or change - therefore we should
|
||||
// proceed even in case of error. It's possible that some S3 or
|
||||
// storage is not available yet, it should not block us
|
||||
c.logger.Error("Failed to delete backup file", "error", err)
|
||||
}
|
||||
|
||||
return c.backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) AddBackupRemoveListener(listener backups_core.BackupRemoveListener) {
|
||||
c.backupRemoveListeners = append(c.backupRemoveListeners, listener)
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanOldBackups() error {
|
||||
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
backupStorePeriod := backupConfig.StorePeriod
|
||||
|
||||
if backupStorePeriod == period.PeriodForever {
|
||||
continue
|
||||
}
|
||||
|
||||
storeDuration := backupStorePeriod.ToDuration()
|
||||
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
|
||||
|
||||
oldBackups, err := c.backupRepository.FindBackupsBeforeDate(
|
||||
backupConfig.DatabaseID,
|
||||
dateBeforeBackupsShouldBeDeleted,
|
||||
)
|
||||
if err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to find old backups for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, backup := range oldBackups {
|
||||
if err := c.DeleteBackup(backup); err != nil {
|
||||
c.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
c.logger.Info(
|
||||
"Deleted old backup",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanExceededBackups() error {
|
||||
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
if backupConfig.MaxBackupsTotalSizeMB <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := c.cleanExceededBackupsForDatabase(
|
||||
backupConfig.DatabaseID,
|
||||
backupConfig.MaxBackupsTotalSizeMB,
|
||||
); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to clean exceeded backups for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanExceededBackupsForDatabase(
|
||||
databaseID uuid.UUID,
|
||||
limitperDbMB int64,
|
||||
) error {
|
||||
for {
|
||||
backupsTotalSizeMB, err := c.backupRepository.GetTotalSizeByDatabase(databaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if backupsTotalSizeMB <= float64(limitperDbMB) {
|
||||
break
|
||||
}
|
||||
|
||||
oldestBackups, err := c.backupRepository.FindOldestByDatabaseExcludingInProgress(
|
||||
databaseID,
|
||||
1,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(oldestBackups) == 0 {
|
||||
c.logger.Warn(
|
||||
"No backups to delete but still over limit",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"totalSizeMB",
|
||||
backupsTotalSizeMB,
|
||||
"limitMB",
|
||||
limitperDbMB,
|
||||
)
|
||||
break
|
||||
}
|
||||
|
||||
backup := oldestBackups[0]
|
||||
if err := c.DeleteBackup(backup); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to delete exceeded backup",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
c.logger.Info(
|
||||
"Deleted exceeded backup",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"backupSizeMB",
|
||||
backup.BackupSizeMb,
|
||||
"totalSizeMB",
|
||||
backupsTotalSizeMB,
|
||||
"limitMB",
|
||||
limitperDbMB,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,595 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/storage"
|
||||
"databasus-backend/internal/util/period"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_CleanOldBackups_DeletesBackupsOlderThanStorePeriod(t *testing.T) {
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Create backup interval
|
||||
interval := createTestInterval()
|
||||
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodWeek,
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create backups with different ages
|
||||
now := time.Now().UTC()
|
||||
oldBackup1 := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10,
|
||||
CreatedAt: now.Add(-10 * 24 * time.Hour), // 10 days old
|
||||
}
|
||||
oldBackup2 := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10,
|
||||
CreatedAt: now.Add(-8 * 24 * time.Hour), // 8 days old
|
||||
}
|
||||
recentBackup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10,
|
||||
CreatedAt: now.Add(-3 * 24 * time.Hour), // 3 days old
|
||||
}
|
||||
|
||||
err = backupRepository.Save(oldBackup1)
|
||||
assert.NoError(t, err)
|
||||
err = backupRepository.Save(oldBackup2)
|
||||
assert.NoError(t, err)
|
||||
err = backupRepository.Save(recentBackup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Run cleanup
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanOldBackups()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify old backups deleted, recent backup remains
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(remainingBackups))
|
||||
assert.Equal(t, recentBackup.ID, remainingBackups[0].ID)
|
||||
}
|
||||
|
||||
func Test_CleanOldBackups_SkipsDatabaseWithForeverStorePeriod(t *testing.T) {
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Create backup interval
|
||||
interval := createTestInterval()
|
||||
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create very old backup
|
||||
oldBackup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10,
|
||||
CreatedAt: time.Now().UTC().Add(-365 * 24 * time.Hour), // 1 year old
|
||||
}
|
||||
err = backupRepository.Save(oldBackup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Run cleanup
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanOldBackups()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify backup still exists
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(remainingBackups))
|
||||
assert.Equal(t, oldBackup.ID, remainingBackups[0].ID)
|
||||
}
|
||||
|
||||
func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Create backup interval
|
||||
interval := createTestInterval()
|
||||
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
MaxBackupsTotalSizeMB: 100, // 100 MB limit
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create 3 backups totaling 50MB (under limit)
|
||||
for i := 0; i < 3; i++ {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 16.67,
|
||||
CreatedAt: time.Now().UTC().Add(-time.Duration(i) * time.Hour),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Run cleanup
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanExceededBackups()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify all backups remain
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, len(remainingBackups))
|
||||
}
|
||||
|
||||
func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T) {
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Create backup interval
|
||||
interval := createTestInterval()
|
||||
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
MaxBackupsTotalSizeMB: 30, // 30 MB limit
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create 5 backups of 10MB each (total 50MB, over 30MB limit)
|
||||
now := time.Now().UTC()
|
||||
var backupIDs []uuid.UUID
|
||||
for i := 0; i < 5; i++ {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10,
|
||||
CreatedAt: now.Add(-time.Duration(4-i) * time.Hour), // Oldest first
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
backupIDs = append(backupIDs, backup.ID)
|
||||
}
|
||||
|
||||
// Run cleanup
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanExceededBackups()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify 2 oldest backups deleted, 3 newest remain
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, len(remainingBackups))
|
||||
|
||||
// Check that the newest 3 backups remain
|
||||
remainingIDs := make(map[uuid.UUID]bool)
|
||||
for _, backup := range remainingBackups {
|
||||
remainingIDs[backup.ID] = true
|
||||
}
|
||||
assert.False(t, remainingIDs[backupIDs[0]]) // Oldest deleted
|
||||
assert.False(t, remainingIDs[backupIDs[1]]) // 2nd oldest deleted
|
||||
assert.True(t, remainingIDs[backupIDs[2]]) // 3rd remains
|
||||
assert.True(t, remainingIDs[backupIDs[3]]) // 4th remains
|
||||
assert.True(t, remainingIDs[backupIDs[4]]) // Newest remains
|
||||
}
|
||||
|
||||
func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Create backup interval
|
||||
interval := createTestInterval()
|
||||
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
MaxBackupsTotalSizeMB: 50, // 50 MB limit
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
// Create 3 completed backups of 30MB each
|
||||
completedBackups := make([]*backups_core.Backup, 3)
|
||||
for i := 0; i < 3; i++ {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 30,
|
||||
CreatedAt: now.Add(-time.Duration(3-i) * time.Hour),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
completedBackups[i] = backup
|
||||
}
|
||||
|
||||
// Create 1 in-progress backup (should be excluded from size calculation and deletion)
|
||||
inProgressBackup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
BackupSizeMb: 10,
|
||||
CreatedAt: now,
|
||||
}
|
||||
err = backupRepository.Save(inProgressBackup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Run cleanup
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanExceededBackups()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify: only completed backups deleted, in-progress remains
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Should have in-progress + 1 completed (total 40MB completed + 10MB in-progress)
|
||||
assert.GreaterOrEqual(t, len(remainingBackups), 2)
|
||||
|
||||
// Verify in-progress backup still exists
|
||||
var inProgressFound bool
|
||||
for _, backup := range remainingBackups {
|
||||
if backup.ID == inProgressBackup.ID {
|
||||
inProgressFound = true
|
||||
assert.Equal(t, backups_core.BackupStatusInProgress, backup.Status)
|
||||
}
|
||||
}
|
||||
assert.True(t, inProgressFound, "In-progress backup should not be deleted")
|
||||
}
|
||||
|
||||
func Test_CleanExceededBackups_WithZeroLimit_SkipsDatabase(t *testing.T) {
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Create backup interval
|
||||
interval := createTestInterval()
|
||||
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
MaxBackupsTotalSizeMB: 0, // No size limit
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create large backups
|
||||
for i := 0; i < 10; i++ {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 100,
|
||||
CreatedAt: time.Now().UTC().Add(-time.Duration(i) * time.Hour),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Run cleanup
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanExceededBackups()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify all backups remain
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 10, len(remainingBackups))
|
||||
}
|
||||
|
||||
func Test_GetTotalSizeByDatabase_CalculatesCorrectly(t *testing.T) {
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Create completed backups
|
||||
completedBackup1 := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
completedBackup2 := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 20.3,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
// Create failed backup (should be included)
|
||||
failedBackup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusFailed,
|
||||
BackupSizeMb: 5.2,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
// Create in-progress backup (should be excluded)
|
||||
inProgressBackup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
BackupSizeMb: 100,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
err := backupRepository.Save(completedBackup1)
|
||||
assert.NoError(t, err)
|
||||
err = backupRepository.Save(completedBackup2)
|
||||
assert.NoError(t, err)
|
||||
err = backupRepository.Save(failedBackup)
|
||||
assert.NoError(t, err)
|
||||
err = backupRepository.Save(inProgressBackup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Calculate total size
|
||||
totalSize, err := backupRepository.GetTotalSizeByDatabase(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Should be 10.5 + 20.3 + 5.2 = 36.0 (excluding in-progress 100)
|
||||
assert.InDelta(t, 36.0, totalSize, 0.1)
|
||||
}
|
||||
|
||||
// Mock listener for testing
|
||||
type mockBackupRemoveListener struct {
|
||||
onBeforeBackupRemove func(*backups_core.Backup) error
|
||||
}
|
||||
|
||||
func (m *mockBackupRemoveListener) OnBeforeBackupRemove(backup *backups_core.Backup) error {
|
||||
if m.onBeforeBackupRemove != nil {
|
||||
return m.onBeforeBackupRemove(backup)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Test_DeleteBackup_WhenStorageDeleteFails_BackupStillRemovedFromDatabase verifies resilience
|
||||
// when storage becomes unavailable. Even if storage.DeleteFile fails (e.g., storage is offline,
|
||||
// credentials changed, or storage was deleted), the backup record should still be removed from
|
||||
// the database. This prevents orphaned backup records when storage is no longer accessible.
|
||||
func Test_DeleteBackup_WhenStorageDeleteFails_BackupStillRemovedFromDatabase(t *testing.T) {
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
testStorage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, testStorage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(testStorage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: testStorage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err := backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
|
||||
err = cleaner.DeleteBackup(backup)
|
||||
assert.NoError(t, err, "DeleteBackup should succeed even when storage file doesn't exist")
|
||||
|
||||
deletedBackup, err := backupRepository.FindByID(backup.ID)
|
||||
assert.Error(t, err, "Backup should not exist in database")
|
||||
assert.Nil(t, deletedBackup)
|
||||
}
|
||||
|
||||
func createTestInterval() *intervals.Interval {
|
||||
timeOfDay := "04:00"
|
||||
interval := &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
|
||||
err := storage.GetDb().Create(interval).Error
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return interval
|
||||
}
|
||||
98
backend/internal/features/backups/backups/backuping/di.go
Normal file
98
backend/internal/features/backups/backups/backuping/di.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/usecases"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var backupRepository = &backups_core.BackupRepository{}
|
||||
|
||||
var taskCancelManager = tasks_cancellation.GetTaskCancelManager()
|
||||
|
||||
var backupCleaner = &BackupCleaner{
|
||||
backupRepository,
|
||||
storages.GetStorageService(),
|
||||
backups_config.GetBackupConfigService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
logger.GetLogger(),
|
||||
[]backups_core.BackupRemoveListener{},
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
|
||||
var backupNodesRegistry = &BackupNodesRegistry{
|
||||
cache_utils.GetValkeyClient(),
|
||||
logger.GetLogger(),
|
||||
cache_utils.DefaultCacheTimeout,
|
||||
cache_utils.NewPubSubManager(),
|
||||
cache_utils.NewPubSubManager(),
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
|
||||
func getNodeID() uuid.UUID {
|
||||
return uuid.New()
|
||||
}
|
||||
|
||||
var backuperNode = &BackuperNode{
|
||||
databases.GetDatabaseService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
notifiers.GetNotifierService(),
|
||||
taskCancelManager,
|
||||
backupNodesRegistry,
|
||||
logger.GetLogger(),
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
getNodeID(),
|
||||
time.Time{},
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
|
||||
var backupsScheduler = &BackupsScheduler{
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
taskCancelManager,
|
||||
backupNodesRegistry,
|
||||
databases.GetDatabaseService(),
|
||||
time.Now().UTC(),
|
||||
logger.GetLogger(),
|
||||
make(map[uuid.UUID]BackupToNodeRelation),
|
||||
backuperNode,
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
|
||||
func GetBackupsScheduler() *BackupsScheduler {
|
||||
return backupsScheduler
|
||||
}
|
||||
|
||||
func GetBackuperNode() *BackuperNode {
|
||||
return backuperNode
|
||||
}
|
||||
|
||||
func GetBackupNodesRegistry() *BackupNodesRegistry {
|
||||
return backupNodesRegistry
|
||||
}
|
||||
|
||||
func GetBackupCleaner() *BackupCleaner {
|
||||
return backupCleaner
|
||||
}
|
||||
34
backend/internal/features/backups/backups/backuping/dto.go
Normal file
34
backend/internal/features/backups/backups/backuping/dto.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type BackupToNodeRelation struct {
|
||||
NodeID uuid.UUID `json:"nodeId"`
|
||||
BackupsIDs []uuid.UUID `json:"backupsIds"`
|
||||
}
|
||||
|
||||
type BackupNode struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
ThroughputMBs int `json:"throughputMBs"`
|
||||
LastHeartbeat time.Time `json:"lastHeartbeat"`
|
||||
}
|
||||
|
||||
type BackupNodeStats struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
ActiveBackups int `json:"activeBackups"`
|
||||
}
|
||||
|
||||
type BackupSubmitMessage struct {
|
||||
NodeID uuid.UUID `json:"nodeId"`
|
||||
BackupID uuid.UUID `json:"backupId"`
|
||||
IsCallNotifier bool `json:"isCallNotifier"`
|
||||
}
|
||||
|
||||
type BackupCompletionMessage struct {
|
||||
NodeID uuid.UUID `json:"nodeId"`
|
||||
BackupID uuid.UUID `json:"backupId"`
|
||||
}
|
||||
195
backend/internal/features/backups/backups/backuping/mocks.go
Normal file
195
backend/internal/features/backups/backups/backuping/mocks.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
common "databasus-backend/internal/features/backups/backups/common"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type MockNotificationSender struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockNotificationSender) SendNotification(
|
||||
notifier *notifiers.Notifier,
|
||||
title string,
|
||||
message string,
|
||||
) {
|
||||
m.Called(notifier, title, message)
|
||||
}
|
||||
|
||||
type CreateFailedBackupUsecase struct{}
|
||||
|
||||
func (uc *CreateFailedBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*common.BackupMetadata, error) {
|
||||
backupProgressListener(10)
|
||||
return nil, errors.New("backup failed")
|
||||
}
|
||||
|
||||
type CreateSuccessBackupUsecase struct{}
|
||||
|
||||
func (uc *CreateSuccessBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*common.BackupMetadata, error) {
|
||||
backupProgressListener(10)
|
||||
return &common.BackupMetadata{
|
||||
EncryptionSalt: nil,
|
||||
EncryptionIV: nil,
|
||||
Encryption: backups_config.BackupEncryptionNone,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateLargeBackupUsecase simulates a large backup (10000 MB)
|
||||
type CreateLargeBackupUsecase struct{}
|
||||
|
||||
func (uc *CreateLargeBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*common.BackupMetadata, error) {
|
||||
backupProgressListener(10000)
|
||||
return &common.BackupMetadata{
|
||||
EncryptionSalt: nil,
|
||||
EncryptionIV: nil,
|
||||
Encryption: backups_config.BackupEncryptionNone,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateProgressiveBackupUsecase simulates progressive size updates that exceed limit
|
||||
type CreateProgressiveBackupUsecase struct{}
|
||||
|
||||
func (uc *CreateProgressiveBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*common.BackupMetadata, error) {
|
||||
// Simulate progressive backup that grows beyond limit
|
||||
backupProgressListener(1)
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
backupProgressListener(3)
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
backupProgressListener(5)
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
backupProgressListener(10) // This exceeds the 5 MB limit
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
// Should not reach here due to cancellation
|
||||
return &common.BackupMetadata{
|
||||
EncryptionSalt: nil,
|
||||
EncryptionIV: nil,
|
||||
Encryption: backups_config.BackupEncryptionNone,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateMediumBackupUsecase simulates a 50 MB backup
|
||||
type CreateMediumBackupUsecase struct{}
|
||||
|
||||
func (uc *CreateMediumBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*common.BackupMetadata, error) {
|
||||
backupProgressListener(50)
|
||||
return &common.BackupMetadata{
|
||||
EncryptionSalt: nil,
|
||||
EncryptionIV: nil,
|
||||
Encryption: backups_config.BackupEncryptionNone,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// MockTrackingBackupUsecase tracks backup use case calls for testing parallel execution
|
||||
type MockTrackingBackupUsecase struct {
|
||||
callCount atomic.Int32
|
||||
calledBackupIDs chan uuid.UUID
|
||||
}
|
||||
|
||||
func NewMockTrackingBackupUsecase() *MockTrackingBackupUsecase {
|
||||
return &MockTrackingBackupUsecase{
|
||||
calledBackupIDs: make(chan uuid.UUID, 10),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockTrackingBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*common.BackupMetadata, error) {
|
||||
m.callCount.Add(1)
|
||||
|
||||
// Send backup ID to channel (non-blocking)
|
||||
select {
|
||||
case m.calledBackupIDs <- backup.ID:
|
||||
default:
|
||||
}
|
||||
|
||||
// Simulate backup work
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
backupProgressListener(10)
|
||||
|
||||
return &common.BackupMetadata{
|
||||
EncryptionSalt: nil,
|
||||
EncryptionIV: nil,
|
||||
Encryption: backups_config.BackupEncryptionNone,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *MockTrackingBackupUsecase) GetCallCount() int32 {
|
||||
return m.callCount.Load()
|
||||
}
|
||||
|
||||
func (m *MockTrackingBackupUsecase) GetCalledBackupIDs() []uuid.UUID {
|
||||
ids := []uuid.UUID{}
|
||||
for {
|
||||
select {
|
||||
case id := <-m.calledBackupIDs:
|
||||
ids = append(ids, id)
|
||||
default:
|
||||
return ids
|
||||
}
|
||||
}
|
||||
}
|
||||
633
backend/internal/features/backups/backups/backuping/registry.go
Normal file
633
backend/internal/features/backups/backups/backuping/registry.go
Normal file
@@ -0,0 +1,633 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/valkey-io/valkey-go"
|
||||
)
|
||||
|
||||
const (
|
||||
nodeInfoKeyPrefix = "backup:node:"
|
||||
nodeInfoKeySuffix = ":info"
|
||||
nodeActiveBackupsPrefix = "backup:node:"
|
||||
nodeActiveBackupsSuffix = ":active_backups"
|
||||
backupSubmitChannel = "backup:submit"
|
||||
backupCompletionChannel = "backup:completion"
|
||||
|
||||
deadNodeThreshold = 2 * time.Minute
|
||||
cleanupTickerInterval = 1 * time.Second
|
||||
)
|
||||
|
||||
// BackupNodesRegistry helps to sync backups scheduler and backup nodes.
|
||||
//
|
||||
// Features:
|
||||
// - Track node availability and load level
|
||||
// - Assign from scheduler to node backups needed to be processed
|
||||
// - Notify scheduler from node about backup completion
|
||||
//
|
||||
// Important things to remember:
|
||||
// - Nodes without heartbeat for more than 2 minutes are not included
|
||||
// in available nodes list and stats
|
||||
//
|
||||
// Cleanup dead nodes performed on 2 levels:
|
||||
// - List and stats functions do not return dead nodes
|
||||
// - Periodically dead nodes are cleaned up in cache (to not
|
||||
// accumulate too many dead nodes in cache)
|
||||
type BackupNodesRegistry struct {
|
||||
client valkey.Client
|
||||
logger *slog.Logger
|
||||
timeout time.Duration
|
||||
pubsubBackups *cache_utils.PubSubManager
|
||||
pubsubCompletions *cache_utils.PubSubManager
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) Run(ctx context.Context) {
|
||||
wasAlreadyRun := r.hasRun.Load()
|
||||
|
||||
r.runOnce.Do(func() {
|
||||
r.hasRun.Store(true)
|
||||
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(cleanupTickerInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", r))
|
||||
}
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) GetAvailableNodes() ([]BackupNode, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
var allKeys []string
|
||||
cursor := uint64(0)
|
||||
pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix
|
||||
|
||||
for {
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return nil, fmt.Errorf("failed to scan node keys: %w", result.Error())
|
||||
}
|
||||
|
||||
scanResult, err := result.AsScanEntry()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse scan result: %w", err)
|
||||
}
|
||||
|
||||
allKeys = append(allKeys, scanResult.Elements...)
|
||||
|
||||
cursor = scanResult.Cursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(allKeys) == 0 {
|
||||
return []BackupNode{}, nil
|
||||
}
|
||||
|
||||
keyDataMap, err := r.pipelineGetKeys(allKeys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pipeline get node keys: %w", err)
|
||||
}
|
||||
|
||||
threshold := time.Now().UTC().Add(-deadNodeThreshold)
|
||||
var nodes []BackupNode
|
||||
|
||||
for key, data := range keyDataMap {
|
||||
// Skip if the key doesn't exist (data is empty)
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var node BackupNode
|
||||
if err := json.Unmarshal(data, &node); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal node data", "key", key, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip nodes with zero/uninitialized heartbeat
|
||||
if node.LastHeartbeat.IsZero() {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.LastHeartbeat.Before(threshold) {
|
||||
continue
|
||||
}
|
||||
|
||||
nodes = append(nodes, node)
|
||||
}
|
||||
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) GetBackupNodesStats() ([]BackupNodeStats, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
var allKeys []string
|
||||
cursor := uint64(0)
|
||||
pattern := nodeActiveBackupsPrefix + "*" + nodeActiveBackupsSuffix
|
||||
|
||||
for {
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(100).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return nil, fmt.Errorf("failed to scan active backups keys: %w", result.Error())
|
||||
}
|
||||
|
||||
scanResult, err := result.AsScanEntry()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse scan result: %w", err)
|
||||
}
|
||||
|
||||
allKeys = append(allKeys, scanResult.Elements...)
|
||||
|
||||
cursor = scanResult.Cursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(allKeys) == 0 {
|
||||
return []BackupNodeStats{}, nil
|
||||
}
|
||||
|
||||
keyDataMap, err := r.pipelineGetKeys(allKeys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pipeline get active backups keys: %w", err)
|
||||
}
|
||||
|
||||
var nodeInfoKeys []string
|
||||
nodeIDToStatsKey := make(map[string]string)
|
||||
for key := range keyDataMap {
|
||||
nodeID := r.extractNodeIDFromKey(key, nodeActiveBackupsPrefix, nodeActiveBackupsSuffix)
|
||||
nodeIDStr := nodeID.String()
|
||||
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeIDStr, nodeInfoKeySuffix)
|
||||
nodeInfoKeys = append(nodeInfoKeys, infoKey)
|
||||
nodeIDToStatsKey[infoKey] = key
|
||||
}
|
||||
|
||||
nodeInfoMap, err := r.pipelineGetKeys(nodeInfoKeys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pipeline get node info keys: %w", err)
|
||||
}
|
||||
|
||||
threshold := time.Now().UTC().Add(-deadNodeThreshold)
|
||||
var stats []BackupNodeStats
|
||||
for infoKey, nodeData := range nodeInfoMap {
|
||||
// Skip if the info key doesn't exist (nodeData is empty)
|
||||
if len(nodeData) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var node BackupNode
|
||||
if err := json.Unmarshal(nodeData, &node); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal node data", "key", infoKey, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip nodes with zero/uninitialized heartbeat
|
||||
if node.LastHeartbeat.IsZero() {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.LastHeartbeat.Before(threshold) {
|
||||
continue
|
||||
}
|
||||
|
||||
statsKey := nodeIDToStatsKey[infoKey]
|
||||
tasksData := keyDataMap[statsKey]
|
||||
count, err := r.parseIntFromBytes(tasksData)
|
||||
if err != nil {
|
||||
r.logger.Warn("Failed to parse active backups count", "key", statsKey, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
stat := BackupNodeStats{
|
||||
ID: node.ID,
|
||||
ActiveBackups: int(count),
|
||||
}
|
||||
stats = append(stats, stat)
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) IncrementBackupsInProgress(nodeID uuid.UUID) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID.String(), nodeActiveBackupsSuffix)
|
||||
result := r.client.Do(ctx, r.client.B().Incr().Key(key).Build())
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to increment backups in progress for node %s: %w",
|
||||
nodeID,
|
||||
result.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) DecrementBackupsInProgress(nodeID uuid.UUID) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID.String(), nodeActiveBackupsSuffix)
|
||||
result := r.client.Do(ctx, r.client.B().Decr().Key(key).Build())
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to decrement backups in progress for node %s: %w",
|
||||
nodeID,
|
||||
result.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
newValue, err := result.AsInt64()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse decremented value for node %s: %w", nodeID, err)
|
||||
}
|
||||
|
||||
if newValue < 0 {
|
||||
setCtx, setCancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
r.client.Do(setCtx, r.client.B().Set().Key(key).Value("0").Build())
|
||||
setCancel()
|
||||
r.logger.Warn("Active backups counter went below 0, reset to 0", "nodeID", nodeID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) HearthbeatNodeInRegistry(now time.Time, backupNode BackupNode) error {
|
||||
if now.IsZero() {
|
||||
return fmt.Errorf("cannot register node with zero heartbeat timestamp")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
backupNode.LastHeartbeat = now
|
||||
|
||||
data, err := json.Marshal(backupNode)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal backup node: %w", err)
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, backupNode.ID.String(), nodeInfoKeySuffix)
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Set().Key(key).Value(string(data)).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf("failed to register node %s: %w", backupNode.ID, result.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) UnregisterNodeFromRegistry(backupNode BackupNode) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, backupNode.ID.String(), nodeInfoKeySuffix)
|
||||
counterKey := fmt.Sprintf(
|
||||
"%s%s%s",
|
||||
nodeActiveBackupsPrefix,
|
||||
backupNode.ID.String(),
|
||||
nodeActiveBackupsSuffix,
|
||||
)
|
||||
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Del().Key(infoKey, counterKey).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf("failed to unregister node %s: %w", backupNode.ID, result.Error())
|
||||
}
|
||||
|
||||
r.logger.Info("Unregistered node from registry", "nodeID", backupNode.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) AssignBackupToNode(
|
||||
targetNodeID uuid.UUID,
|
||||
backupID uuid.UUID,
|
||||
isCallNotifier bool,
|
||||
) error {
|
||||
ctx := context.Background()
|
||||
|
||||
message := BackupSubmitMessage{
|
||||
NodeID: targetNodeID,
|
||||
BackupID: backupID,
|
||||
IsCallNotifier: isCallNotifier,
|
||||
}
|
||||
|
||||
messageJSON, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal backup submit message: %w", err)
|
||||
}
|
||||
|
||||
err = r.pubsubBackups.Publish(ctx, backupSubmitChannel, string(messageJSON))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to publish backup submit message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) SubscribeNodeForBackupsAssignment(
|
||||
nodeID uuid.UUID,
|
||||
handler func(backupID uuid.UUID, isCallNotifier bool),
|
||||
) error {
|
||||
ctx := context.Background()
|
||||
|
||||
wrappedHandler := func(message string) {
|
||||
var msg BackupSubmitMessage
|
||||
if err := json.Unmarshal([]byte(message), &msg); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal backup submit message", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if msg.NodeID != nodeID {
|
||||
return
|
||||
}
|
||||
|
||||
handler(msg.BackupID, msg.IsCallNotifier)
|
||||
}
|
||||
|
||||
err := r.pubsubBackups.Subscribe(ctx, backupSubmitChannel, wrappedHandler)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to subscribe to backup submit channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Subscribed to backup submit channel", "nodeID", nodeID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) UnsubscribeNodeForBackupsAssignments() error {
|
||||
err := r.pubsubBackups.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unsubscribe from backup submit channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Unsubscribed from backup submit channel")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) PublishBackupCompletion(nodeID uuid.UUID, backupID uuid.UUID) error {
|
||||
ctx := context.Background()
|
||||
|
||||
message := BackupCompletionMessage{
|
||||
NodeID: nodeID,
|
||||
BackupID: backupID,
|
||||
}
|
||||
|
||||
messageJSON, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal backup completion message: %w", err)
|
||||
}
|
||||
|
||||
err = r.pubsubCompletions.Publish(ctx, backupCompletionChannel, string(messageJSON))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to publish backup completion message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) SubscribeForBackupsCompletions(
|
||||
handler func(nodeID uuid.UUID, backupID uuid.UUID),
|
||||
) error {
|
||||
ctx := context.Background()
|
||||
|
||||
wrappedHandler := func(message string) {
|
||||
var msg BackupCompletionMessage
|
||||
if err := json.Unmarshal([]byte(message), &msg); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal backup completion message", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
handler(msg.NodeID, msg.BackupID)
|
||||
}
|
||||
|
||||
err := r.pubsubCompletions.Subscribe(ctx, backupCompletionChannel, wrappedHandler)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to subscribe to backup completion channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Subscribed to backup completion channel")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) UnsubscribeForBackupsCompletions() error {
|
||||
err := r.pubsubCompletions.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unsubscribe from backup completion channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Unsubscribed from backup completion channel")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uuid.UUID {
|
||||
nodeIDStr := strings.TrimPrefix(key, prefix)
|
||||
nodeIDStr = strings.TrimSuffix(nodeIDStr, suffix)
|
||||
|
||||
nodeID, err := uuid.Parse(nodeIDStr)
|
||||
if err != nil {
|
||||
r.logger.Warn("Failed to parse node ID from key", "key", key, "error", err)
|
||||
return uuid.Nil
|
||||
}
|
||||
|
||||
return nodeID
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, error) {
|
||||
if len(keys) == 0 {
|
||||
return make(map[string][]byte), nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
commands := make([]valkey.Completed, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
commands = append(commands, r.client.B().Get().Key(key).Build())
|
||||
}
|
||||
|
||||
results := r.client.DoMulti(ctx, commands...)
|
||||
|
||||
keyDataMap := make(map[string][]byte, len(keys))
|
||||
for i, result := range results {
|
||||
if result.Error() != nil {
|
||||
r.logger.Warn("Failed to get key in pipeline", "key", keys[i], "error", result.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
data, err := result.AsBytes()
|
||||
if err != nil {
|
||||
r.logger.Warn("Failed to parse key data in pipeline", "key", keys[i], "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
keyDataMap[keys[i]] = data
|
||||
}
|
||||
|
||||
return keyDataMap, nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
|
||||
str := string(data)
|
||||
var count int64
|
||||
_, err := fmt.Sscanf(str, "%d", &count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to parse integer from bytes: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) cleanupDeadNodes() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
var allKeys []string
|
||||
cursor := uint64(0)
|
||||
pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix
|
||||
|
||||
for {
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf("failed to scan node keys: %w", result.Error())
|
||||
}
|
||||
|
||||
scanResult, err := result.AsScanEntry()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse scan result: %w", err)
|
||||
}
|
||||
|
||||
allKeys = append(allKeys, scanResult.Elements...)
|
||||
|
||||
cursor = scanResult.Cursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(allKeys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
keyDataMap, err := r.pipelineGetKeys(allKeys)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to pipeline get node keys: %w", err)
|
||||
}
|
||||
|
||||
threshold := time.Now().UTC().Add(-deadNodeThreshold)
|
||||
var deadNodeKeys []string
|
||||
|
||||
for key, data := range keyDataMap {
|
||||
// Skip if the key doesn't exist (data is empty)
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var node BackupNode
|
||||
if err := json.Unmarshal(data, &node); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal node data during cleanup", "key", key, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip nodes with zero/uninitialized heartbeat
|
||||
if node.LastHeartbeat.IsZero() {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.LastHeartbeat.Before(threshold) {
|
||||
nodeID := node.ID.String()
|
||||
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeID, nodeInfoKeySuffix)
|
||||
statsKey := fmt.Sprintf(
|
||||
"%s%s%s",
|
||||
nodeActiveBackupsPrefix,
|
||||
nodeID,
|
||||
nodeActiveBackupsSuffix,
|
||||
)
|
||||
|
||||
deadNodeKeys = append(deadNodeKeys, infoKey, statsKey)
|
||||
r.logger.Info(
|
||||
"Marking node for cleanup",
|
||||
"nodeID", nodeID,
|
||||
"lastHeartbeat", node.LastHeartbeat,
|
||||
"threshold", threshold,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if len(deadNodeKeys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
delCtx, delCancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer delCancel()
|
||||
|
||||
result := r.client.Do(
|
||||
delCtx,
|
||||
r.client.B().Del().Key(deadNodeKeys...).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf("failed to delete dead node keys: %w", result.Error())
|
||||
}
|
||||
|
||||
deletedCount, err := result.AsInt64()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse deleted count: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Cleaned up dead nodes", "deletedKeysCount", deletedCount)
|
||||
return nil
|
||||
}
|
||||
1134
backend/internal/features/backups/backups/backuping/registry_test.go
Normal file
1134
backend/internal/features/backups/backups/backuping/registry_test.go
Normal file
File diff suppressed because it is too large
Load Diff
592
backend/internal/features/backups/backups/backuping/scheduler.go
Normal file
592
backend/internal/features/backups/backups/backuping/scheduler.go
Normal file
@@ -0,0 +1,592 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
files_utils "databasus-backend/internal/util/files"
|
||||
)
|
||||
|
||||
const (
|
||||
schedulerStartupDelay = 1 * time.Minute
|
||||
schedulerTickerInterval = 1 * time.Minute
|
||||
schedulerHealthcheckThreshold = 5 * time.Minute
|
||||
)
|
||||
|
||||
type BackupsScheduler struct {
|
||||
backupRepository *backups_core.BackupRepository
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
taskCancelManager *task_cancellation.TaskCancelManager
|
||||
backupNodesRegistry *BackupNodesRegistry
|
||||
databaseService *databases.DatabaseService
|
||||
|
||||
lastBackupTime time.Time
|
||||
logger *slog.Logger
|
||||
|
||||
backupToNodeRelations map[uuid.UUID]BackupToNodeRelation
|
||||
backuperNode *BackuperNode
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) Run(ctx context.Context) {
|
||||
wasAlreadyRun := s.hasRun.Load()
|
||||
|
||||
s.runOnce.Do(func() {
|
||||
s.hasRun.Store(true)
|
||||
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
|
||||
if config.GetEnv().IsManyNodesMode {
|
||||
// wait other nodes to start
|
||||
time.Sleep(schedulerStartupDelay)
|
||||
}
|
||||
|
||||
if err := s.failBackupsInProgress(); err != nil {
|
||||
s.logger.Error("Failed to fail backups in progress", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err := s.backupNodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to subscribe to backup completions", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := s.backupNodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
|
||||
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(schedulerTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.checkDeadNodesAndFailBackups(); err != nil {
|
||||
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
|
||||
}
|
||||
|
||||
if err := s.runPendingBackups(); err != nil {
|
||||
s.logger.Error("Failed to run pending backups", "error", err)
|
||||
}
|
||||
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) IsSchedulerRunning() bool {
|
||||
// if last backup time is more than 5 minutes ago, return false
|
||||
return s.lastBackupTime.After(time.Now().UTC().Add(-schedulerHealthcheckThreshold))
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) IsBackupNodesAvailable() bool {
|
||||
nodes, err := s.backupNodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get available nodes for health check", "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
return len(nodes) > 0
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) StartBackup(database *databases.Database, isCallNotifier bool) {
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if backupConfig.StorageID == nil {
|
||||
s.logger.Error("Backup config storage ID is nil", "databaseId", database.ID)
|
||||
return
|
||||
}
|
||||
|
||||
// Check for existing in-progress backups
|
||||
inProgressBackups, err := s.backupRepository.FindByDatabaseIdAndStatus(
|
||||
database.ID,
|
||||
backups_core.BackupStatusInProgress,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to check for in-progress backups",
|
||||
"databaseId",
|
||||
database.ID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if len(inProgressBackups) > 0 {
|
||||
s.logger.Warn(
|
||||
"Backup already in progress for database, skipping new backup",
|
||||
"databaseId",
|
||||
database.ID,
|
||||
"existingBackupId",
|
||||
inProgressBackups[0].ID,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
leastBusyNodeID, err := s.calculateLeastBusyNode()
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to calculate least busy node",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
backupID := uuid.New()
|
||||
timestamp := time.Now().UTC()
|
||||
|
||||
backup := &backups_core.Backup{
|
||||
ID: backupID,
|
||||
FileName: fmt.Sprintf(
|
||||
"%s-%s-%s",
|
||||
files_utils.SanitizeFilename(database.Name),
|
||||
timestamp.Format("20060102-150405"),
|
||||
backupID.String(),
|
||||
),
|
||||
DatabaseID: backupConfig.DatabaseID,
|
||||
StorageID: *backupConfig.StorageID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
BackupSizeMb: 0,
|
||||
CreatedAt: timestamp,
|
||||
}
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to save backup",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.backupNodesRegistry.IncrementBackupsInProgress(*leastBusyNodeID); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to increment backups in progress",
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.backupNodesRegistry.AssignBackupToNode(*leastBusyNodeID, backup.ID, isCallNotifier); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to submit backup",
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
if decrementErr := s.backupNodesRegistry.DecrementBackupsInProgress(*leastBusyNodeID); decrementErr != nil {
|
||||
s.logger.Error(
|
||||
"Failed to decrement backups in progress after submit failure",
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
"error",
|
||||
decrementErr,
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if relation, exists := s.backupToNodeRelations[*leastBusyNodeID]; exists {
|
||||
relation.BackupsIDs = append(relation.BackupsIDs, backup.ID)
|
||||
s.backupToNodeRelations[*leastBusyNodeID] = relation
|
||||
} else {
|
||||
s.backupToNodeRelations[*leastBusyNodeID] = BackupToNodeRelation{
|
||||
*leastBusyNodeID,
|
||||
[]uuid.UUID{backup.ID},
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Info(
|
||||
"Successfully triggered scheduled backup",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
)
|
||||
}
|
||||
|
||||
// GetRemainedBackupTryCount returns the number of remaining backup tries for a given backup.
|
||||
// If the backup is not failed or the backup config does not allow retries, it returns 0.
|
||||
// If the backup is failed and the backup config allows retries, it returns the number of remaining tries.
|
||||
// If the backup is failed and the backup config does not allow retries, it returns 0.
|
||||
func (s *BackupsScheduler) GetRemainedBackupTryCount(lastBackup *backups_core.Backup) int {
|
||||
if lastBackup == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
if lastBackup.Status != backups_core.BackupStatusFailed {
|
||||
return 0
|
||||
}
|
||||
|
||||
if lastBackup.IsSkipRetry {
|
||||
return 0
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(lastBackup.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
if !backupConfig.IsRetryIfFailed {
|
||||
return 0
|
||||
}
|
||||
|
||||
maxFailedTriesCount := backupConfig.MaxFailedTriesCount
|
||||
|
||||
lastBackups, err := s.backupRepository.FindByDatabaseIDWithLimit(
|
||||
lastBackup.DatabaseID,
|
||||
maxFailedTriesCount,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to find last backups by database ID", "error", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
lastFailedBackups := make([]*backups_core.Backup, 0)
|
||||
|
||||
for _, backup := range lastBackups {
|
||||
if backup.Status == backups_core.BackupStatusFailed {
|
||||
lastFailedBackups = append(lastFailedBackups, backup)
|
||||
}
|
||||
}
|
||||
|
||||
return maxFailedTriesCount - len(lastFailedBackups)
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) runPendingBackups() error {
|
||||
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
if backupConfig.BackupInterval == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
lastBackup, err := s.backupRepository.FindLastByDatabaseID(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to get last backup for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
var lastBackupTime *time.Time
|
||||
if lastBackup != nil {
|
||||
lastBackupTime = &lastBackup.CreatedAt
|
||||
}
|
||||
|
||||
remainedBackupTryCount := s.GetRemainedBackupTryCount(lastBackup)
|
||||
|
||||
if backupConfig.BackupInterval.ShouldTriggerBackup(time.Now().UTC(), lastBackupTime) ||
|
||||
remainedBackupTryCount > 0 {
|
||||
s.logger.Info(
|
||||
"Triggering scheduled backup",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"intervalType",
|
||||
backupConfig.BackupInterval.Interval,
|
||||
)
|
||||
|
||||
database, err := s.databaseService.GetDatabaseByID(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get database by ID", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
s.StartBackup(database, remainedBackupTryCount == 1)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) failBackupsInProgress() error {
|
||||
backupsInProgress, err := s.backupRepository.FindByStatus(backups_core.BackupStatusInProgress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backup := range backupsInProgress {
|
||||
if err := s.taskCancelManager.CancelTask(backup.ID); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to cancel backup via task cancel manager",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(backup.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
failMessage := "Backup failed due to application restart"
|
||||
backup.FailMessage = &failMessage
|
||||
backup.Status = backups_core.BackupStatusFailed
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
s.backuperNode.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupFailed,
|
||||
&failMessage,
|
||||
)
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
|
||||
nodes, err := s.backupNodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get available nodes: %w", err)
|
||||
}
|
||||
|
||||
if len(nodes) == 0 {
|
||||
return nil, fmt.Errorf("no nodes available")
|
||||
}
|
||||
|
||||
stats, err := s.backupNodesRegistry.GetBackupNodesStats()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get backup nodes stats: %w", err)
|
||||
}
|
||||
|
||||
statsMap := make(map[uuid.UUID]int)
|
||||
for _, stat := range stats {
|
||||
statsMap[stat.ID] = stat.ActiveBackups
|
||||
}
|
||||
|
||||
var bestNode *BackupNode
|
||||
var bestScore float64 = -1
|
||||
|
||||
for i := range nodes {
|
||||
node := &nodes[i]
|
||||
|
||||
activeBackups := statsMap[node.ID]
|
||||
|
||||
var score float64
|
||||
if node.ThroughputMBs > 0 {
|
||||
score = float64(activeBackups) / float64(node.ThroughputMBs)
|
||||
} else {
|
||||
score = float64(activeBackups) * 1000
|
||||
}
|
||||
|
||||
if bestNode == nil || score < bestScore {
|
||||
bestNode = node
|
||||
bestScore = score
|
||||
}
|
||||
}
|
||||
|
||||
if bestNode == nil {
|
||||
return nil, fmt.Errorf("no suitable nodes available")
|
||||
}
|
||||
|
||||
return &bestNode.ID, nil
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) onBackupCompleted(nodeID uuid.UUID, backupID uuid.UUID) {
|
||||
// Verify this task is actually a backup (registry contains multiple task types)
|
||||
_, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
// Not a backup task, ignore it
|
||||
return
|
||||
}
|
||||
|
||||
relation, exists := s.backupToNodeRelations[nodeID]
|
||||
if !exists {
|
||||
s.logger.Warn(
|
||||
"Received completion for unknown node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
newBackupIDs := make([]uuid.UUID, 0)
|
||||
found := false
|
||||
for _, id := range relation.BackupsIDs {
|
||||
if id == backupID {
|
||||
found = true
|
||||
continue
|
||||
}
|
||||
newBackupIDs = append(newBackupIDs, id)
|
||||
}
|
||||
|
||||
if !found {
|
||||
s.logger.Warn(
|
||||
"Backup not found in node's backup list",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if len(newBackupIDs) == 0 {
|
||||
delete(s.backupToNodeRelations, nodeID)
|
||||
} else {
|
||||
relation.BackupsIDs = newBackupIDs
|
||||
s.backupToNodeRelations[nodeID] = relation
|
||||
}
|
||||
|
||||
if err := s.backupNodesRegistry.DecrementBackupsInProgress(nodeID); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to decrement backups in progress",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) checkDeadNodesAndFailBackups() error {
|
||||
nodes, err := s.backupNodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get available nodes: %w", err)
|
||||
}
|
||||
|
||||
aliveNodeIDs := make(map[uuid.UUID]bool)
|
||||
for _, node := range nodes {
|
||||
aliveNodeIDs[node.ID] = true
|
||||
}
|
||||
|
||||
for nodeID, relation := range s.backupToNodeRelations {
|
||||
if aliveNodeIDs[nodeID] {
|
||||
continue
|
||||
}
|
||||
|
||||
s.logger.Warn(
|
||||
"Node is dead, failing its backups",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupCount",
|
||||
len(relation.BackupsIDs),
|
||||
)
|
||||
|
||||
for _, backupID := range relation.BackupsIDs {
|
||||
backup, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to find backup for dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
failMessage := "Backup failed due to node unavailability"
|
||||
backup.FailMessage = &failMessage
|
||||
backup.Status = backups_core.BackupStatusFailed
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to save failed backup for dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := s.backupNodesRegistry.DecrementBackupsInProgress(nodeID); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to decrement backups in progress for dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
s.logger.Info(
|
||||
"Failed backup due to dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
)
|
||||
}
|
||||
|
||||
delete(s.backupToNodeRelations, nodeID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
318
backend/internal/features/backups/backups/backuping/testing.go
Normal file
318
backend/internal/features/backups/backups/backuping/testing.go
Normal file
@@ -0,0 +1,318 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/usecases"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func CreateTestRouter() *gin.Engine {
|
||||
router := workspaces_testing.CreateTestRouter(
|
||||
workspaces_controllers.GetWorkspaceController(),
|
||||
workspaces_controllers.GetMembershipController(),
|
||||
databases.GetDatabaseController(),
|
||||
backups_config.GetBackupConfigController(),
|
||||
)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func CreateTestBackuperNode() *BackuperNode {
|
||||
return &BackuperNode{
|
||||
databaseService: databases.GetDatabaseService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
workspaceService: workspaces_services.GetWorkspaceService(),
|
||||
backupRepository: backupRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
notificationSender: notifiers.GetNotifierService(),
|
||||
backupCancelManager: taskCancelManager,
|
||||
backupNodesRegistry: backupNodesRegistry,
|
||||
logger: logger.GetLogger(),
|
||||
createBackupUseCase: usecases.GetCreateBackupUsecase(),
|
||||
nodeID: uuid.New(),
|
||||
lastHeartbeat: time.Time{},
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func CreateTestBackuperNodeWithUseCase(useCase backups_core.CreateBackupUsecase) *BackuperNode {
|
||||
return &BackuperNode{
|
||||
databaseService: databases.GetDatabaseService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
workspaceService: workspaces_services.GetWorkspaceService(),
|
||||
backupRepository: backupRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
notificationSender: notifiers.GetNotifierService(),
|
||||
backupCancelManager: taskCancelManager,
|
||||
backupNodesRegistry: backupNodesRegistry,
|
||||
logger: logger.GetLogger(),
|
||||
createBackupUseCase: useCase,
|
||||
nodeID: uuid.New(),
|
||||
lastHeartbeat: time.Time{},
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func CreateTestScheduler() *BackupsScheduler {
|
||||
return &BackupsScheduler{
|
||||
backupRepository: backupRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
taskCancelManager: taskCancelManager,
|
||||
backupNodesRegistry: backupNodesRegistry,
|
||||
lastBackupTime: time.Now().UTC(),
|
||||
logger: logger.GetLogger(),
|
||||
backupToNodeRelations: make(map[uuid.UUID]BackupToNodeRelation),
|
||||
backuperNode: CreateTestBackuperNode(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
// WaitForBackupCompletion waits for a new backup to be created and completed (or failed)
|
||||
// for the given database. It checks for backups with count greater than expectedInitialCount.
|
||||
func WaitForBackupCompletion(
|
||||
t *testing.T,
|
||||
databaseID uuid.UUID,
|
||||
expectedInitialCount int,
|
||||
timeout time.Duration,
|
||||
) {
|
||||
deadline := time.Now().UTC().Add(timeout)
|
||||
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
backups, err := backupRepository.FindByDatabaseID(databaseID)
|
||||
if err != nil {
|
||||
t.Logf("WaitForBackupCompletion: error finding backups: %v", err)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
t.Logf(
|
||||
"WaitForBackupCompletion: found %d backups (expected > %d)",
|
||||
len(backups),
|
||||
expectedInitialCount,
|
||||
)
|
||||
|
||||
if len(backups) > expectedInitialCount {
|
||||
// Check if the newest backup has completed or failed
|
||||
newestBackup := backups[0]
|
||||
t.Logf("WaitForBackupCompletion: newest backup status: %s", newestBackup.Status)
|
||||
|
||||
if newestBackup.Status == backups_core.BackupStatusCompleted ||
|
||||
newestBackup.Status == backups_core.BackupStatusFailed ||
|
||||
newestBackup.Status == backups_core.BackupStatusCanceled {
|
||||
t.Logf(
|
||||
"WaitForBackupCompletion: backup finished with status %s",
|
||||
newestBackup.Status,
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Logf("WaitForBackupCompletion: timeout waiting for backup to complete")
|
||||
}
|
||||
|
||||
// StartBackuperNodeForTest starts a BackuperNode in a goroutine for testing.
|
||||
// The node registers itself in the registry and subscribes to backup assignments.
|
||||
// Returns a context cancel function that should be deferred to stop the node.
|
||||
func StartBackuperNodeForTest(t *testing.T, backuperNode *BackuperNode) context.CancelFunc {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
backuperNode.Run(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Poll registry for node presence instead of fixed sleep
|
||||
deadline := time.Now().UTC().Add(5 * time.Second)
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
nodes, err := backupNodesRegistry.GetAvailableNodes()
|
||||
if err == nil {
|
||||
for _, node := range nodes {
|
||||
if node.ID == backuperNode.nodeID {
|
||||
t.Logf("BackuperNode registered in registry: %s", backuperNode.nodeID)
|
||||
|
||||
return func() {
|
||||
cancel()
|
||||
select {
|
||||
case <-done:
|
||||
t.Log("BackuperNode stopped gracefully")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Log("BackuperNode stop timeout")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Fatalf("BackuperNode failed to register in registry within timeout")
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartSchedulerForTest starts the BackupsScheduler in a goroutine for testing.
|
||||
// The scheduler subscribes to task completions and manages backup lifecycle.
|
||||
// Returns a context cancel function that should be deferred to stop the scheduler.
|
||||
func StartSchedulerForTest(t *testing.T, scheduler *BackupsScheduler) context.CancelFunc {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
scheduler.Run(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Give scheduler time to subscribe to completions
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
t.Log("BackupsScheduler started")
|
||||
|
||||
return func() {
|
||||
cancel()
|
||||
select {
|
||||
case <-done:
|
||||
t.Log("BackupsScheduler stopped gracefully")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Log("BackupsScheduler stop timeout")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StopBackuperNodeForTest stops the BackuperNode by canceling its context.
|
||||
// It waits for the node to unregister from the registry.
|
||||
func StopBackuperNodeForTest(t *testing.T, cancel context.CancelFunc, backuperNode *BackuperNode) {
|
||||
cancel()
|
||||
|
||||
// Wait for node to unregister from registry
|
||||
deadline := time.Now().UTC().Add(2 * time.Second)
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
nodes, err := backupNodesRegistry.GetAvailableNodes()
|
||||
if err == nil {
|
||||
found := false
|
||||
for _, node := range nodes {
|
||||
if node.ID == backuperNode.nodeID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Logf("BackuperNode unregistered from registry: %s", backuperNode.nodeID)
|
||||
return
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Logf("BackuperNode stop completed for %s", backuperNode.nodeID)
|
||||
}
|
||||
|
||||
func CreateMockNodeInRegistry(nodeID uuid.UUID, throughputMBs int, lastHeartbeat time.Time) error {
|
||||
backupNode := BackupNode{
|
||||
ID: nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: lastHeartbeat,
|
||||
}
|
||||
|
||||
return backupNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
|
||||
}
|
||||
|
||||
func UpdateNodeHeartbeatDirectly(
|
||||
nodeID uuid.UUID,
|
||||
throughputMBs int,
|
||||
lastHeartbeat time.Time,
|
||||
) error {
|
||||
backupNode := BackupNode{
|
||||
ID: nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: lastHeartbeat,
|
||||
}
|
||||
|
||||
return backupNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
|
||||
}
|
||||
|
||||
func GetNodeFromRegistry(nodeID uuid.UUID) (*BackupNode, error) {
|
||||
nodes, err := backupNodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
if node.ID == nodeID {
|
||||
return &node, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("node not found")
|
||||
}
|
||||
|
||||
// WaitForActiveTasksDecrease waits for the active task count to decrease below the initial count.
|
||||
// It polls the registry every 500ms until the count decreases or the timeout is reached.
|
||||
// Returns true if the count decreased, false if timeout was reached.
|
||||
func WaitForActiveTasksDecrease(
|
||||
t *testing.T,
|
||||
nodeID uuid.UUID,
|
||||
initialCount int,
|
||||
timeout time.Duration,
|
||||
) bool {
|
||||
deadline := time.Now().UTC().Add(timeout)
|
||||
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
stats, err := backupNodesRegistry.GetBackupNodesStats()
|
||||
if err != nil {
|
||||
t.Logf("WaitForActiveTasksDecrease: error getting node stats: %v", err)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, stat := range stats {
|
||||
if stat.ID == nodeID {
|
||||
t.Logf(
|
||||
"WaitForActiveTasksDecrease: current active tasks = %d (initial = %d)",
|
||||
stat.ActiveBackups,
|
||||
initialCount,
|
||||
)
|
||||
if stat.ActiveBackups < initialCount {
|
||||
t.Logf(
|
||||
"WaitForActiveTasksDecrease: active tasks decreased from %d to %d",
|
||||
initialCount,
|
||||
stat.ActiveBackups,
|
||||
)
|
||||
return true
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Logf("WaitForActiveTasksDecrease: timeout waiting for active tasks to decrease")
|
||||
return false
|
||||
}
|
||||
@@ -1,11 +1,16 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"context"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
"databasus-backend/internal/features/databases"
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
files_utils "databasus-backend/internal/util/files"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -170,9 +175,10 @@ func (c *BackupController) CancelBackup(ctx *gin.Context) {
|
||||
// @Description Generate a token for downloading a backup file (valid for 5 minutes)
|
||||
// @Tags backups
|
||||
// @Param id path string true "Backup ID"
|
||||
// @Success 200 {object} GenerateDownloadTokenResponse
|
||||
// @Success 200 {object} backups_download.GenerateDownloadTokenResponse
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 409 {object} map[string]string "Download already in progress"
|
||||
// @Router /backups/{id}/download-token [post]
|
||||
func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
@@ -189,6 +195,15 @@ func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
|
||||
|
||||
response, err := c.backupService.GenerateDownloadToken(user, id)
|
||||
if err != nil {
|
||||
if err == backups_download.ErrDownloadAlreadyInProgress {
|
||||
ctx.JSON(
|
||||
http.StatusConflict,
|
||||
gin.H{
|
||||
"error": "Download already in progress for some of backups. Please wait until previous download completed or cancel it",
|
||||
},
|
||||
)
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -198,14 +213,22 @@ func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
|
||||
|
||||
// GetFile
|
||||
// @Summary Download a backup file
|
||||
// @Description Download the backup file for the specified backup using a download token
|
||||
// @Description Download the backup file for the specified backup using a download token.
|
||||
// @Description
|
||||
// @Description **Download Concurrency Control:**
|
||||
// @Description - Only one download per user is allowed at a time
|
||||
// @Description - If a download is already in progress, returns 409 Conflict
|
||||
// @Description - Downloads are tracked using cache with 5-second TTL and 3-second heartbeat
|
||||
// @Description - Browser cancellations automatically release the download lock
|
||||
// @Description - Server crashes are handled via automatic cache expiry (5 seconds)
|
||||
// @Tags backups
|
||||
// @Param id path string true "Backup ID"
|
||||
// @Param token query string true "Download token"
|
||||
// @Success 200 {file} file
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 500
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 409 {object} map[string]string "Download already in progress"
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /backups/{id}/file [get]
|
||||
func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
token := ctx.Query("token")
|
||||
@@ -214,7 +237,6 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Get backup ID from URL
|
||||
backupIDParam := ctx.Param("id")
|
||||
backupID, err := uuid.Parse(backupIDParam)
|
||||
if err != nil {
|
||||
@@ -222,13 +244,22 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
downloadToken, err := c.backupService.ValidateDownloadToken(token)
|
||||
downloadToken, rateLimiter, err := c.backupService.ValidateDownloadToken(token)
|
||||
if err != nil {
|
||||
if err == backups_download.ErrDownloadAlreadyInProgress {
|
||||
ctx.JSON(
|
||||
http.StatusConflict,
|
||||
gin.H{
|
||||
"error": "download already in progress for this user. Please wait until previous download completed or cancel it",
|
||||
},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"})
|
||||
return
|
||||
}
|
||||
|
||||
// Verify token is for the requested backup
|
||||
if downloadToken.BackupID != backupID {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"})
|
||||
return
|
||||
@@ -238,18 +269,28 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
downloadToken.BackupID,
|
||||
)
|
||||
if err != nil {
|
||||
c.backupService.UnregisterDownload(downloadToken.UserID)
|
||||
c.backupService.ReleaseDownloadLock(downloadToken.UserID)
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
rateLimitedReader := backups_download.NewRateLimitedReader(fileReader, rateLimiter)
|
||||
|
||||
heartbeatCtx, cancelHeartbeat := context.WithCancel(context.Background())
|
||||
defer func() {
|
||||
if err := fileReader.Close(); err != nil {
|
||||
cancelHeartbeat()
|
||||
c.backupService.UnregisterDownload(downloadToken.UserID)
|
||||
c.backupService.ReleaseDownloadLock(downloadToken.UserID)
|
||||
if err := rateLimitedReader.Close(); err != nil {
|
||||
fmt.Printf("Error closing file reader: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go c.startDownloadHeartbeat(heartbeatCtx, downloadToken.UserID)
|
||||
|
||||
filename := c.generateBackupFilename(backup, database)
|
||||
|
||||
// Set Content-Length for progress tracking
|
||||
if backup.BackupSizeMb > 0 {
|
||||
sizeBytes := int64(backup.BackupSizeMb * 1024 * 1024)
|
||||
ctx.Header("Content-Length", fmt.Sprintf("%d", sizeBytes))
|
||||
@@ -261,13 +302,12 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
fmt.Sprintf("attachment; filename=\"%s\"", filename),
|
||||
)
|
||||
|
||||
_, err = io.Copy(ctx.Writer, fileReader)
|
||||
_, err = io.Copy(ctx.Writer, rateLimitedReader)
|
||||
if err != nil {
|
||||
fmt.Printf("Error streaming file: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Write audit log after successful download
|
||||
c.backupService.WriteAuditLogForDownload(downloadToken.UserID, backup, database)
|
||||
}
|
||||
|
||||
@@ -276,14 +316,14 @@ type MakeBackupRequest struct {
|
||||
}
|
||||
|
||||
func (c *BackupController) generateBackupFilename(
|
||||
backup *Backup,
|
||||
backup *backups_core.Backup,
|
||||
database *databases.Database,
|
||||
) string {
|
||||
// Format timestamp as YYYY-MM-DD_HH-mm-ss
|
||||
timestamp := backup.CreatedAt.Format("2006-01-02_15-04-05")
|
||||
|
||||
// Sanitize database name for filename (replace spaces and special chars)
|
||||
safeName := sanitizeFilename(database.Name)
|
||||
safeName := files_utils.SanitizeFilename(database.Name)
|
||||
|
||||
// Determine extension based on database type
|
||||
extension := c.getBackupExtension(database.Type)
|
||||
@@ -307,29 +347,16 @@ func (c *BackupController) getBackupExtension(
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeFilename(name string) string {
|
||||
// Replace characters that are invalid in filenames
|
||||
replacer := map[rune]rune{
|
||||
' ': '_',
|
||||
'/': '-',
|
||||
'\\': '-',
|
||||
':': '-',
|
||||
'*': '-',
|
||||
'?': '-',
|
||||
'"': '-',
|
||||
'<': '-',
|
||||
'>': '-',
|
||||
'|': '-',
|
||||
}
|
||||
func (c *BackupController) startDownloadHeartbeat(ctx context.Context, userID uuid.UUID) {
|
||||
ticker := time.NewTicker(backups_download.GetDownloadHeartbeatInterval())
|
||||
defer ticker.Stop()
|
||||
|
||||
result := make([]rune, 0, len(name))
|
||||
for _, char := range name {
|
||||
if replacement, exists := replacer[char]; exists {
|
||||
result = append(result, replacement)
|
||||
} else {
|
||||
result = append(result, char)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
c.backupService.RefreshDownloadLock(userID)
|
||||
}
|
||||
}
|
||||
|
||||
return string(result)
|
||||
}
|
||||
|
||||
@@ -18,7 +18,8 @@ import (
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups/download_token"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
@@ -31,6 +32,7 @@ import (
|
||||
workspaces_models "databasus-backend/internal/features/workspaces/models"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
files_utils "databasus-backend/internal/util/files"
|
||||
test_utils "databasus-backend/internal/util/testing"
|
||||
"databasus-backend/internal/util/tools"
|
||||
)
|
||||
@@ -79,7 +81,7 @@ func Test_GetBackups_PermissionsEnforced(t *testing.T) {
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, _ := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
database, _, storage := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
@@ -121,6 +123,12 @@ func Test_GetBackups_PermissionsEnforced(t *testing.T) {
|
||||
} else {
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -217,6 +225,10 @@ func Test_CreateBackup_PermissionsEnforced(t *testing.T) {
|
||||
} else {
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -260,6 +272,10 @@ func Test_CreateBackup_AuditLogWritten(t *testing.T) {
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Audit log for backup creation not found")
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_DeleteBackup_PermissionsEnforced(t *testing.T) {
|
||||
@@ -313,7 +329,7 @@ func Test_DeleteBackup_PermissionsEnforced(t *testing.T) {
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
@@ -357,6 +373,12 @@ func Test_DeleteBackup_PermissionsEnforced(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, len(response.Backups))
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -366,7 +388,7 @@ func Test_DeleteBackup_AuditLogWritten(t *testing.T) {
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
test_utils.MakeDeleteRequest(
|
||||
t,
|
||||
@@ -397,6 +419,12 @@ func Test_DeleteBackup_AuditLogWritten(t *testing.T) {
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Audit log for backup deletion not found")
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_GenerateDownloadToken_PermissionsEnforced(t *testing.T) {
|
||||
@@ -443,7 +471,7 @@ func Test_GenerateDownloadToken_PermissionsEnforced(t *testing.T) {
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
@@ -478,7 +506,7 @@ func Test_GenerateDownloadToken_PermissionsEnforced(t *testing.T) {
|
||||
)
|
||||
|
||||
if tt.expectSuccess {
|
||||
var response GenerateDownloadTokenResponse
|
||||
var response backups_download.GenerateDownloadTokenResponse
|
||||
err := json.Unmarshal(testResp.Body, &response)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, response.Token)
|
||||
@@ -487,6 +515,12 @@ func Test_GenerateDownloadToken_PermissionsEnforced(t *testing.T) {
|
||||
} else {
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -496,10 +530,10 @@ func Test_DownloadBackup_WithValidToken_Success(t *testing.T) {
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
// Generate download token
|
||||
var tokenResponse GenerateDownloadTokenResponse
|
||||
var tokenResponse backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
@@ -523,6 +557,12 @@ func Test_DownloadBackup_WithValidToken_Success(t *testing.T) {
|
||||
contentDisposition := testResp.Headers.Get("Content-Disposition")
|
||||
assert.Contains(t, contentDisposition, "attachment")
|
||||
assert.Contains(t, contentDisposition, tokenResponse.Filename)
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_DownloadBackup_WithoutToken_Unauthorized(t *testing.T) {
|
||||
@@ -530,7 +570,7 @@ func Test_DownloadBackup_WithoutToken_Unauthorized(t *testing.T) {
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
// Try to download without token
|
||||
testResp := test_utils.MakeGetRequest(
|
||||
@@ -542,6 +582,12 @@ func Test_DownloadBackup_WithoutToken_Unauthorized(t *testing.T) {
|
||||
)
|
||||
|
||||
assert.Contains(t, string(testResp.Body), "download token is required")
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_DownloadBackup_WithInvalidToken_Unauthorized(t *testing.T) {
|
||||
@@ -549,7 +595,7 @@ func Test_DownloadBackup_WithInvalidToken_Unauthorized(t *testing.T) {
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
// Try to download with invalid token
|
||||
testResp := test_utils.MakeGetRequest(
|
||||
@@ -561,6 +607,12 @@ func Test_DownloadBackup_WithInvalidToken_Unauthorized(t *testing.T) {
|
||||
)
|
||||
|
||||
assert.Contains(t, string(testResp.Body), "invalid or expired download token")
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_DownloadBackup_WithExpiredToken_Unauthorized(t *testing.T) {
|
||||
@@ -568,7 +620,7 @@ func Test_DownloadBackup_WithExpiredToken_Unauthorized(t *testing.T) {
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
// Get user for token generation
|
||||
userService := users_services.GetUserService()
|
||||
@@ -610,6 +662,12 @@ func Test_DownloadBackup_WithExpiredToken_Unauthorized(t *testing.T) {
|
||||
}
|
||||
}
|
||||
assert.False(t, found, "Audit log should NOT be created for failed download with expired token")
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_DownloadBackup_TokenUsedOnce_CannotReuseToken(t *testing.T) {
|
||||
@@ -617,10 +675,10 @@ func Test_DownloadBackup_TokenUsedOnce_CannotReuseToken(t *testing.T) {
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
// Generate download token
|
||||
var tokenResponse GenerateDownloadTokenResponse
|
||||
var tokenResponse backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
@@ -650,6 +708,12 @@ func Test_DownloadBackup_TokenUsedOnce_CannotReuseToken(t *testing.T) {
|
||||
)
|
||||
|
||||
assert.Contains(t, string(testResp.Body), "invalid or expired download token")
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_DownloadBackup_WithDifferentBackupToken_Unauthorized(t *testing.T) {
|
||||
@@ -683,7 +747,7 @@ func Test_DownloadBackup_WithDifferentBackupToken_Unauthorized(t *testing.T) {
|
||||
backup2 := createTestBackup(database2, owner)
|
||||
|
||||
// Generate token for backup1
|
||||
var tokenResponse GenerateDownloadTokenResponse
|
||||
var tokenResponse backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
@@ -704,6 +768,13 @@ func Test_DownloadBackup_WithDifferentBackupToken_Unauthorized(t *testing.T) {
|
||||
)
|
||||
|
||||
assert.Contains(t, string(testResp.Body), "invalid or expired download token")
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database1)
|
||||
databases.RemoveTestDatabase(database2)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_DownloadBackup_AuditLogWritten(t *testing.T) {
|
||||
@@ -711,10 +782,10 @@ func Test_DownloadBackup_AuditLogWritten(t *testing.T) {
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
// Generate download token
|
||||
var tokenResponse GenerateDownloadTokenResponse
|
||||
var tokenResponse backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
@@ -755,6 +826,12 @@ func Test_DownloadBackup_AuditLogWritten(t *testing.T) {
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Audit log for backup download not found")
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_DownloadBackup_ProperFilenameForPostgreSQL(t *testing.T) {
|
||||
@@ -806,7 +883,7 @@ func Test_DownloadBackup_ProperFilenameForPostgreSQL(t *testing.T) {
|
||||
backup := createTestBackup(database, owner)
|
||||
|
||||
// Generate download token
|
||||
var tokenResponse GenerateDownloadTokenResponse
|
||||
var tokenResponse backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
@@ -855,6 +932,12 @@ func Test_DownloadBackup_ProperFilenameForPostgreSQL(t *testing.T) {
|
||||
contentDisposition,
|
||||
"Filename should contain timestamp",
|
||||
)
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -874,7 +957,7 @@ func Test_SanitizeFilename(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
result := sanitizeFilename(tt.input)
|
||||
result := files_utils.SanitizeFilename(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
@@ -897,22 +980,22 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
|
||||
_, err = configService.SaveBackupConfig(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backup := &Backup{
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: BackupStatusInProgress,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
BackupSizeMb: 0,
|
||||
BackupDurationMs: 0,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
repo := &BackupRepository{}
|
||||
repo := &backups_core.BackupRepository{}
|
||||
err = repo.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Register a cancellable context for the backup
|
||||
GetBackupService().backupContextManager.RegisterBackup(backup.ID, func() {})
|
||||
GetBackupService().taskCancelManager.RegisterTask(backup.ID, func() {})
|
||||
|
||||
resp := test_utils.MakePostRequest(
|
||||
t,
|
||||
@@ -947,6 +1030,219 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
|
||||
}
|
||||
}
|
||||
assert.True(t, foundCancelLog, "Cancel audit log should be created")
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_ConcurrentDownloadPrevention(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
var token1Response backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token1Response,
|
||||
)
|
||||
|
||||
var token2Response backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token2Response,
|
||||
)
|
||||
|
||||
downloadInProgress := make(chan bool, 1)
|
||||
downloadComplete := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf(
|
||||
"/api/v1/backups/%s/file?token=%s",
|
||||
backup.ID.String(),
|
||||
token1Response.Token,
|
||||
),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
downloadComplete <- true
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
service := GetBackupService()
|
||||
if !service.IsDownloadInProgress(owner.UserID) {
|
||||
t.Log("Warning: First download completed before we could test concurrency")
|
||||
<-downloadComplete
|
||||
|
||||
// Cleanup before early return
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
return
|
||||
}
|
||||
|
||||
downloadInProgress <- true
|
||||
|
||||
resp := test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), token2Response.Token),
|
||||
"",
|
||||
http.StatusConflict,
|
||||
)
|
||||
|
||||
var errorResponse map[string]string
|
||||
err := json.Unmarshal(resp.Body, &errorResponse)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, errorResponse["error"], "download already in progress")
|
||||
|
||||
<-downloadComplete
|
||||
<-downloadInProgress
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
var token3Response backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token3Response,
|
||||
)
|
||||
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), token3Response.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
|
||||
t.Log("Database:", database.Name)
|
||||
t.Log(
|
||||
"Successfully prevented concurrent downloads and allowed subsequent downloads after completion",
|
||||
)
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_GenerateDownloadToken_BlockedWhenDownloadInProgress(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
var token1Response backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token1Response,
|
||||
)
|
||||
|
||||
downloadComplete := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf(
|
||||
"/api/v1/backups/%s/file?token=%s",
|
||||
backup.ID.String(),
|
||||
token1Response.Token,
|
||||
),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
downloadComplete <- true
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
service := GetBackupService()
|
||||
if !service.IsDownloadInProgress(owner.UserID) {
|
||||
t.Log("Warning: First download completed before we could test token generation blocking")
|
||||
<-downloadComplete
|
||||
|
||||
// Cleanup before early return
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
return
|
||||
}
|
||||
|
||||
resp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusConflict,
|
||||
)
|
||||
|
||||
var errorResponse map[string]string
|
||||
err := json.Unmarshal(resp.Body, &errorResponse)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, errorResponse["error"], "download already in progress")
|
||||
|
||||
<-downloadComplete
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
var token2Response backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token2Response,
|
||||
)
|
||||
|
||||
assert.NotEmpty(t, token2Response.Token)
|
||||
assert.NotEqual(t, token1Response.Token, token2Response.Token)
|
||||
|
||||
t.Log("Database:", database.Name)
|
||||
t.Log(
|
||||
"Successfully blocked token generation during download and allowed generation after completion",
|
||||
)
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func createTestRouter() *gin.Engine {
|
||||
@@ -972,7 +1268,7 @@ func createTestDatabase(
|
||||
Type: databases.DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
@@ -1038,7 +1334,7 @@ func createTestDatabaseWithBackups(
|
||||
workspace *workspaces_models.Workspace,
|
||||
owner *users_dto.SignInResponseDTO,
|
||||
router *gin.Engine,
|
||||
) (*databases.Database, *Backup) {
|
||||
) (*databases.Database, *backups_core.Backup, *storages.Storage) {
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
@@ -1058,35 +1354,48 @@ func createTestDatabaseWithBackups(
|
||||
|
||||
backup := createTestBackup(database, owner)
|
||||
|
||||
return database, backup
|
||||
return database, backup, storage
|
||||
}
|
||||
|
||||
func createTestBackup(
|
||||
database *databases.Database,
|
||||
owner *users_dto.SignInResponseDTO,
|
||||
) *Backup {
|
||||
) *backups_core.Backup {
|
||||
userService := users_services.GetUserService()
|
||||
user, err := userService.GetUserFromToken(owner.Token)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
storages, err := storages.GetStorageService().GetStorages(user, *database.WorkspaceID)
|
||||
if err != nil || len(storages) == 0 {
|
||||
loadedStorages, err := storages.GetStorageService().GetStorages(user, *database.WorkspaceID)
|
||||
if err != nil || len(loadedStorages) == 0 {
|
||||
panic("No storage found for workspace")
|
||||
}
|
||||
|
||||
backup := &Backup{
|
||||
// Filter out system storages
|
||||
var nonSystemStorages []*storages.Storage
|
||||
for _, storage := range loadedStorages {
|
||||
if !storage.IsSystem {
|
||||
nonSystemStorages = append(nonSystemStorages, storage)
|
||||
}
|
||||
}
|
||||
if len(nonSystemStorages) == 0 {
|
||||
panic("No non-system storage found for workspace")
|
||||
}
|
||||
|
||||
storages := nonSystemStorages
|
||||
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storages[0].ID,
|
||||
Status: BackupStatusCompleted,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10.5,
|
||||
BackupDurationMs: 1000,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
repo := &BackupRepository{}
|
||||
repo := &backups_core.BackupRepository{}
|
||||
if err := repo.Save(backup); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@@ -1099,7 +1408,7 @@ func createTestBackup(
|
||||
context.Background(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
logger,
|
||||
backup.ID,
|
||||
backup.ID.String(),
|
||||
reader,
|
||||
); err != nil {
|
||||
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
|
||||
@@ -1116,7 +1425,7 @@ func createExpiredDownloadToken(backupID, userID uuid.UUID) string {
|
||||
}
|
||||
|
||||
// Manually update the token to be expired
|
||||
repo := &download_token.DownloadTokenRepository{}
|
||||
repo := &backups_download.DownloadTokenRepository{}
|
||||
downloadToken, err := repo.FindByToken(token)
|
||||
if err != nil || downloadToken == nil {
|
||||
panic(fmt.Sprintf("Failed to find generated token: %v", err))
|
||||
@@ -1130,3 +1439,285 @@ func createExpiredDownloadToken(backupID, userID uuid.UUID) string {
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
func Test_BandwidthThrottling_SingleDownload_Uses75Percent(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
bandwidthManager := backups_download.GetBandwidthManager()
|
||||
initialCount := bandwidthManager.GetActiveDownloadCount()
|
||||
|
||||
var tokenResponse backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&tokenResponse,
|
||||
)
|
||||
|
||||
downloadStarted := make(chan bool, 1)
|
||||
downloadComplete := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf(
|
||||
"/api/v1/backups/%s/file?token=%s",
|
||||
backup.ID.String(),
|
||||
tokenResponse.Token,
|
||||
),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
downloadComplete <- true
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
activeCount := bandwidthManager.GetActiveDownloadCount()
|
||||
if activeCount > initialCount {
|
||||
downloadStarted <- true
|
||||
assert.Equal(t, initialCount+1, activeCount, "Should have one active download")
|
||||
}
|
||||
|
||||
<-downloadComplete
|
||||
if len(downloadStarted) > 0 {
|
||||
<-downloadStarted
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
finalCount := bandwidthManager.GetActiveDownloadCount()
|
||||
assert.Equal(t, initialCount, finalCount, "Download should be unregistered after completion")
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_BandwidthThrottling_MultipleDownloads_ShareBandwidth(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
owner3 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner1, router)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
owner2,
|
||||
users_enums.WorkspaceRoleMember,
|
||||
owner1.Token,
|
||||
router,
|
||||
)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
owner3,
|
||||
users_enums.WorkspaceRoleMember,
|
||||
owner1.Token,
|
||||
router,
|
||||
)
|
||||
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner1.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
configService := backups_config.GetBackupConfigService()
|
||||
config, err := configService.GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
config.IsBackupsEnabled = true
|
||||
config.StorageID = &storage.ID
|
||||
config.Storage = storage
|
||||
_, err = configService.SaveBackupConfig(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backup1 := createTestBackup(database, owner1)
|
||||
backup2 := createTestBackup(database, owner2)
|
||||
backup3 := createTestBackup(database, owner3)
|
||||
|
||||
var token1, token2, token3 backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup1.ID.String()),
|
||||
"Bearer "+owner1.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token1,
|
||||
)
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup2.ID.String()),
|
||||
"Bearer "+owner2.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token2,
|
||||
)
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup3.ID.String()),
|
||||
"Bearer "+owner3.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token3,
|
||||
)
|
||||
|
||||
bandwidthManager := backups_download.GetBandwidthManager()
|
||||
initialCount := bandwidthManager.GetActiveDownloadCount()
|
||||
|
||||
complete1 := make(chan bool, 1)
|
||||
complete2 := make(chan bool, 1)
|
||||
complete3 := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup1.ID.String(), token1.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
complete1 <- true
|
||||
}()
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup2.ID.String(), token2.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
complete2 <- true
|
||||
}()
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup3.ID.String(), token3.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
complete3 <- true
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
<-complete1
|
||||
<-complete2
|
||||
<-complete3
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
finalCount := bandwidthManager.GetActiveDownloadCount()
|
||||
assert.Equal(t, initialCount, finalCount, "All downloads should be unregistered")
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_BandwidthThrottling_DynamicAdjustment(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner1, router)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
owner2,
|
||||
users_enums.WorkspaceRoleMember,
|
||||
owner1.Token,
|
||||
router,
|
||||
)
|
||||
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner1.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
configService := backups_config.GetBackupConfigService()
|
||||
config, err := configService.GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
config.IsBackupsEnabled = true
|
||||
config.StorageID = &storage.ID
|
||||
config.Storage = storage
|
||||
_, err = configService.SaveBackupConfig(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backup1 := createTestBackup(database, owner1)
|
||||
backup2 := createTestBackup(database, owner2)
|
||||
|
||||
var token1, token2 backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup1.ID.String()),
|
||||
"Bearer "+owner1.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token1,
|
||||
)
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup2.ID.String()),
|
||||
"Bearer "+owner2.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token2,
|
||||
)
|
||||
|
||||
bandwidthManager := backups_download.GetBandwidthManager()
|
||||
initialCount := bandwidthManager.GetActiveDownloadCount()
|
||||
|
||||
complete1 := make(chan bool, 1)
|
||||
complete2 := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup1.ID.String(), token1.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
complete1 <- true
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup2.ID.String(), token2.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
complete2 <- true
|
||||
}()
|
||||
|
||||
<-complete1
|
||||
<-complete2
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
finalCount := bandwidthManager.GetActiveDownloadCount()
|
||||
assert.Equal(t, initialCount, finalCount, "All downloads completed and unregistered")
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package backups
|
||||
package backups_core
|
||||
|
||||
type BackupStatus string
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package backups
|
||||
package backups_core
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -8,8 +8,6 @@ import (
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type NotificationSender interface {
|
||||
@@ -23,7 +21,7 @@ type NotificationSender interface {
|
||||
type CreateBackupUsecase interface {
|
||||
Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
@@ -1,4 +1,4 @@
|
||||
package backups
|
||||
package backups_core
|
||||
|
||||
import (
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
@@ -8,13 +8,15 @@ import (
|
||||
)
|
||||
|
||||
type Backup struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
|
||||
FileName string `json:"fileName" gorm:"column:file_name;type:text;not null"`
|
||||
|
||||
DatabaseID uuid.UUID `json:"databaseId" gorm:"column:database_id;type:uuid;not null"`
|
||||
StorageID uuid.UUID `json:"storageId" gorm:"column:storage_id;type:uuid;not null"`
|
||||
|
||||
Status BackupStatus `json:"status" gorm:"column:status;not null"`
|
||||
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`
|
||||
IsSkipRetry bool `json:"isSkipRetry" gorm:"column:is_skip_retry;type:boolean;not null"`
|
||||
|
||||
BackupSizeMb float64 `json:"backupSizeMb" gorm:"column:backup_size_mb;default:0"`
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package backups
|
||||
package backups_core
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/storage"
|
||||
@@ -212,3 +212,36 @@ func (r *BackupRepository) CountByDatabaseID(databaseID uuid.UUID) (int64, error
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *BackupRepository) GetTotalSizeByDatabase(databaseID uuid.UUID) (float64, error) {
|
||||
var totalSize float64
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Model(&Backup{}).
|
||||
Select("COALESCE(SUM(backup_size_mb), 0)").
|
||||
Where("database_id = ? AND status != ?", databaseID, BackupStatusInProgress).
|
||||
Scan(&totalSize).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return totalSize, nil
|
||||
}
|
||||
|
||||
func (r *BackupRepository) FindOldestByDatabaseExcludingInProgress(
|
||||
databaseID uuid.UUID,
|
||||
limit int,
|
||||
) ([]*Backup, error) {
|
||||
var backups []*Backup
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Where("database_id = ? AND status != ?", databaseID, BackupStatusInProgress).
|
||||
Order("created_at ASC").
|
||||
Limit(limit).
|
||||
Find(&backups).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return backups, nil
|
||||
}
|
||||
@@ -1,24 +1,28 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"time"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups/download_token"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
"databasus-backend/internal/features/backups/backups/usecases"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var backupRepository = &BackupRepository{}
|
||||
var backupRepository = &backups_core.BackupRepository{}
|
||||
|
||||
var backupContextManager = NewBackupContextManager()
|
||||
var taskCancelManager = task_cancellation.GetTaskCancelManager()
|
||||
|
||||
var backupService = &BackupService{
|
||||
databases.GetDatabaseService(),
|
||||
@@ -31,35 +35,17 @@ var backupService = &BackupService{
|
||||
encryption.GetFieldEncryptor(),
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
[]backups_core.BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
audit_logs.GetAuditLogService(),
|
||||
backupContextManager,
|
||||
download_token.GetDownloadTokenService(),
|
||||
}
|
||||
|
||||
var backupBackgroundService = &BackupBackgroundService{
|
||||
backupService,
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
time.Now().UTC(),
|
||||
logger.GetLogger(),
|
||||
taskCancelManager,
|
||||
backups_download.GetDownloadTokenService(),
|
||||
backuping.GetBackupsScheduler(),
|
||||
backuping.GetBackupCleaner(),
|
||||
}
|
||||
|
||||
var backupController = &BackupController{
|
||||
backupService,
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
backups_config.
|
||||
GetBackupConfigService().
|
||||
SetDatabaseStorageChangeListener(backupService)
|
||||
|
||||
databases.GetDatabaseService().AddDbRemoveListener(backupService)
|
||||
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
|
||||
|
||||
backupContextManager.StartSubscription()
|
||||
backupService: backupService,
|
||||
}
|
||||
|
||||
func GetBackupService() *BackupService {
|
||||
@@ -70,10 +56,26 @@ func GetBackupController() *BackupController {
|
||||
return backupController
|
||||
}
|
||||
|
||||
func GetBackupBackgroundService() *BackupBackgroundService {
|
||||
return backupBackgroundService
|
||||
}
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func GetDownloadTokenBackgroundService() *download_token.DownloadTokenBackgroundService {
|
||||
return download_token.GetDownloadTokenBackgroundService()
|
||||
func SetupDependencies() {
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
backups_config.
|
||||
GetBackupConfigService().
|
||||
SetDatabaseStorageChangeListener(backupService)
|
||||
|
||||
databases.GetDatabaseService().AddDbRemoveListener(backupService)
|
||||
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type DownloadTokenBackgroundService struct {
|
||||
downloadTokenService *DownloadTokenService
|
||||
logger *slog.Logger
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (s *DownloadTokenBackgroundService) Run(ctx context.Context) {
|
||||
wasAlreadyRun := s.hasRun.Load()
|
||||
|
||||
s.runOnce.Do(func() {
|
||||
s.hasRun.Store(true)
|
||||
|
||||
s.logger.Info("Starting download token cleanup background service")
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
|
||||
s.logger.Error("Failed to clean expired download tokens", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type BandwidthManager struct {
|
||||
mu sync.RWMutex
|
||||
activeDownloads map[uuid.UUID]*activeDownload
|
||||
maxTotalBytesPerSecond int64
|
||||
bytesPerSecondPerDownload int64
|
||||
}
|
||||
|
||||
type activeDownload struct {
|
||||
userID uuid.UUID
|
||||
rateLimiter *RateLimiter
|
||||
}
|
||||
|
||||
func NewBandwidthManager(throughputMBs int) *BandwidthManager {
|
||||
// Use 75% of total throughput
|
||||
maxBytes := int64(throughputMBs) * 1024 * 1024 * 75 / 100
|
||||
|
||||
return &BandwidthManager{
|
||||
activeDownloads: make(map[uuid.UUID]*activeDownload),
|
||||
maxTotalBytesPerSecond: maxBytes,
|
||||
bytesPerSecondPerDownload: maxBytes,
|
||||
}
|
||||
}
|
||||
|
||||
func (bm *BandwidthManager) RegisterDownload(userID uuid.UUID) (*RateLimiter, error) {
|
||||
bm.mu.Lock()
|
||||
defer bm.mu.Unlock()
|
||||
|
||||
if _, exists := bm.activeDownloads[userID]; exists {
|
||||
return nil, fmt.Errorf("download already registered for user %s", userID)
|
||||
}
|
||||
|
||||
rateLimiter := NewRateLimiter(bm.bytesPerSecondPerDownload)
|
||||
|
||||
bm.activeDownloads[userID] = &activeDownload{
|
||||
userID: userID,
|
||||
rateLimiter: rateLimiter,
|
||||
}
|
||||
|
||||
bm.recalculateRates()
|
||||
|
||||
return rateLimiter, nil
|
||||
}
|
||||
|
||||
func (bm *BandwidthManager) UnregisterDownload(userID uuid.UUID) {
|
||||
bm.mu.Lock()
|
||||
defer bm.mu.Unlock()
|
||||
|
||||
delete(bm.activeDownloads, userID)
|
||||
bm.recalculateRates()
|
||||
}
|
||||
|
||||
func (bm *BandwidthManager) GetActiveDownloadCount() int {
|
||||
bm.mu.RLock()
|
||||
defer bm.mu.RUnlock()
|
||||
return len(bm.activeDownloads)
|
||||
}
|
||||
|
||||
func (bm *BandwidthManager) recalculateRates() {
|
||||
activeCount := len(bm.activeDownloads)
|
||||
|
||||
if activeCount == 0 {
|
||||
bm.bytesPerSecondPerDownload = bm.maxTotalBytesPerSecond
|
||||
return
|
||||
}
|
||||
|
||||
newRate := bm.maxTotalBytesPerSecond / int64(activeCount)
|
||||
bm.bytesPerSecondPerDownload = newRate
|
||||
|
||||
for _, download := range bm.activeDownloads {
|
||||
download.rateLimiter.UpdateRate(newRate)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,150 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_BandwidthManager_RegisterSingleDownload(t *testing.T) {
|
||||
throughputMBs := 100
|
||||
manager := NewBandwidthManager(throughputMBs)
|
||||
|
||||
expectedBytesPerSec := int64(100 * 1024 * 1024 * 75 / 100)
|
||||
assert.Equal(t, expectedBytesPerSec, manager.maxTotalBytesPerSecond)
|
||||
assert.Equal(t, expectedBytesPerSec, manager.bytesPerSecondPerDownload)
|
||||
|
||||
userID := uuid.New()
|
||||
rateLimiter, err := manager.RegisterDownload(userID)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, rateLimiter)
|
||||
|
||||
assert.Equal(t, 1, manager.GetActiveDownloadCount())
|
||||
assert.Equal(t, expectedBytesPerSec, manager.bytesPerSecondPerDownload)
|
||||
assert.Equal(t, expectedBytesPerSec, rateLimiter.bytesPerSecond)
|
||||
}
|
||||
|
||||
func Test_BandwidthManager_RegisterMultipleDownloads_BandwidthShared(t *testing.T) {
|
||||
throughputMBs := 100
|
||||
manager := NewBandwidthManager(throughputMBs)
|
||||
|
||||
maxBytes := int64(100 * 1024 * 1024 * 75 / 100)
|
||||
|
||||
user1 := uuid.New()
|
||||
rateLimiter1, err := manager.RegisterDownload(user1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, maxBytes, rateLimiter1.bytesPerSecond)
|
||||
|
||||
user2 := uuid.New()
|
||||
rateLimiter2, err := manager.RegisterDownload(user2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
expectedPerDownload := maxBytes / 2
|
||||
assert.Equal(t, expectedPerDownload, manager.bytesPerSecondPerDownload)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter2.bytesPerSecond)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter2.bytesPerSecond)
|
||||
|
||||
user3 := uuid.New()
|
||||
rateLimiter3, err := manager.RegisterDownload(user3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
expectedPerDownload = maxBytes / 3
|
||||
assert.Equal(t, expectedPerDownload, manager.bytesPerSecondPerDownload)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter2.bytesPerSecond)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter3.bytesPerSecond)
|
||||
assert.Equal(t, 3, manager.GetActiveDownloadCount())
|
||||
}
|
||||
|
||||
func Test_BandwidthManager_UnregisterDownload_BandwidthRebalanced(t *testing.T) {
|
||||
throughputMBs := 100
|
||||
manager := NewBandwidthManager(throughputMBs)
|
||||
|
||||
maxBytes := int64(100 * 1024 * 1024 * 75 / 100)
|
||||
|
||||
user1 := uuid.New()
|
||||
rateLimiter1, _ := manager.RegisterDownload(user1)
|
||||
|
||||
user2 := uuid.New()
|
||||
_, _ = manager.RegisterDownload(user2)
|
||||
|
||||
user3 := uuid.New()
|
||||
rateLimiter3, _ := manager.RegisterDownload(user3)
|
||||
|
||||
assert.Equal(t, 3, manager.GetActiveDownloadCount())
|
||||
expectedPerDownload := maxBytes / 3
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
|
||||
|
||||
manager.UnregisterDownload(user2)
|
||||
|
||||
assert.Equal(t, 2, manager.GetActiveDownloadCount())
|
||||
expectedPerDownload = maxBytes / 2
|
||||
assert.Equal(t, expectedPerDownload, manager.bytesPerSecondPerDownload)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter3.bytesPerSecond)
|
||||
|
||||
manager.UnregisterDownload(user1)
|
||||
|
||||
assert.Equal(t, 1, manager.GetActiveDownloadCount())
|
||||
assert.Equal(t, maxBytes, manager.bytesPerSecondPerDownload)
|
||||
assert.Equal(t, maxBytes, rateLimiter3.bytesPerSecond)
|
||||
|
||||
manager.UnregisterDownload(user3)
|
||||
assert.Equal(t, 0, manager.GetActiveDownloadCount())
|
||||
assert.Equal(t, maxBytes, manager.bytesPerSecondPerDownload)
|
||||
}
|
||||
|
||||
func Test_BandwidthManager_RegisterDuplicateUser_ReturnsError(t *testing.T) {
|
||||
manager := NewBandwidthManager(100)
|
||||
|
||||
userID := uuid.New()
|
||||
_, err := manager.RegisterDownload(userID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = manager.RegisterDownload(userID)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "download already registered")
|
||||
}
|
||||
|
||||
func Test_RateLimiter_TokenBucketBasic(t *testing.T) {
|
||||
bytesPerSec := int64(1024 * 1024)
|
||||
limiter := NewRateLimiter(bytesPerSec)
|
||||
|
||||
assert.Equal(t, bytesPerSec, limiter.bytesPerSecond)
|
||||
assert.Equal(t, bytesPerSec*2, limiter.bucketSize)
|
||||
|
||||
start := time.Now()
|
||||
limiter.Wait(512 * 1024)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.Less(t, elapsed, 100*time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_RateLimiter_UpdateRate(t *testing.T) {
|
||||
limiter := NewRateLimiter(1024 * 1024)
|
||||
|
||||
assert.Equal(t, int64(1024*1024), limiter.bytesPerSecond)
|
||||
|
||||
newRate := int64(2 * 1024 * 1024)
|
||||
limiter.UpdateRate(newRate)
|
||||
|
||||
assert.Equal(t, newRate, limiter.bytesPerSecond)
|
||||
assert.Equal(t, newRate*2, limiter.bucketSize)
|
||||
}
|
||||
|
||||
func Test_RateLimiter_ThrottlesCorrectly(t *testing.T) {
|
||||
bytesPerSec := int64(1024 * 1024)
|
||||
limiter := NewRateLimiter(bytesPerSec)
|
||||
|
||||
limiter.availableTokens = 0
|
||||
|
||||
start := time.Now()
|
||||
limiter.Wait(bytesPerSec / 2)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.GreaterOrEqual(t, elapsed, 400*time.Millisecond)
|
||||
assert.LessOrEqual(t, elapsed, 700*time.Millisecond)
|
||||
}
|
||||
53
backend/internal/features/backups/backups/download/di.go
Normal file
53
backend/internal/features/backups/backups/download/di.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var downloadTokenRepository = &DownloadTokenRepository{}
|
||||
|
||||
var downloadTracker = NewDownloadTracker(cache_utils.GetValkeyClient())
|
||||
|
||||
var bandwidthManager *BandwidthManager
|
||||
var downloadTokenService *DownloadTokenService
|
||||
var downloadTokenBackgroundService *DownloadTokenBackgroundService
|
||||
|
||||
func init() {
|
||||
env := config.GetEnv()
|
||||
throughputMBs := env.NodeNetworkThroughputMBs
|
||||
if throughputMBs == 0 {
|
||||
throughputMBs = 125
|
||||
}
|
||||
bandwidthManager = NewBandwidthManager(throughputMBs)
|
||||
|
||||
downloadTokenService = &DownloadTokenService{
|
||||
downloadTokenRepository,
|
||||
logger.GetLogger(),
|
||||
downloadTracker,
|
||||
bandwidthManager,
|
||||
}
|
||||
|
||||
downloadTokenBackgroundService = &DownloadTokenBackgroundService{
|
||||
downloadTokenService: downloadTokenService,
|
||||
logger: logger.GetLogger(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func GetDownloadTokenService() *DownloadTokenService {
|
||||
return downloadTokenService
|
||||
}
|
||||
|
||||
func GetDownloadTokenBackgroundService() *DownloadTokenBackgroundService {
|
||||
return downloadTokenBackgroundService
|
||||
}
|
||||
|
||||
func GetBandwidthManager() *BandwidthManager {
|
||||
return bandwidthManager
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
package backups_download
|
||||
|
||||
import "github.com/google/uuid"
|
||||
|
||||
type GenerateDownloadTokenResponse struct {
|
||||
Token string `json:"token"`
|
||||
Filename string `json:"filename"`
|
||||
BackupID uuid.UUID `json:"backupId"`
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package download_token
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"time"
|
||||
@@ -0,0 +1,101 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RateLimiter struct {
|
||||
mu sync.Mutex
|
||||
bytesPerSecond int64
|
||||
bucketSize int64
|
||||
availableTokens float64
|
||||
lastRefill time.Time
|
||||
}
|
||||
|
||||
func NewRateLimiter(bytesPerSecond int64) *RateLimiter {
|
||||
if bytesPerSecond <= 0 {
|
||||
bytesPerSecond = 1024 * 1024 * 100
|
||||
}
|
||||
|
||||
return &RateLimiter{
|
||||
bytesPerSecond: bytesPerSecond,
|
||||
bucketSize: bytesPerSecond * 2,
|
||||
availableTokens: float64(bytesPerSecond * 2),
|
||||
lastRefill: time.Now().UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) UpdateRate(bytesPerSecond int64) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
if bytesPerSecond <= 0 {
|
||||
bytesPerSecond = 1024 * 1024 * 100
|
||||
}
|
||||
|
||||
rl.bytesPerSecond = bytesPerSecond
|
||||
rl.bucketSize = bytesPerSecond * 2
|
||||
|
||||
if rl.availableTokens > float64(rl.bucketSize) {
|
||||
rl.availableTokens = float64(rl.bucketSize)
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) Wait(bytes int64) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
for {
|
||||
now := time.Now().UTC()
|
||||
elapsed := now.Sub(rl.lastRefill).Seconds()
|
||||
|
||||
tokensToAdd := elapsed * float64(rl.bytesPerSecond)
|
||||
rl.availableTokens += tokensToAdd
|
||||
if rl.availableTokens > float64(rl.bucketSize) {
|
||||
rl.availableTokens = float64(rl.bucketSize)
|
||||
}
|
||||
rl.lastRefill = now
|
||||
|
||||
if rl.availableTokens >= float64(bytes) {
|
||||
rl.availableTokens -= float64(bytes)
|
||||
return
|
||||
}
|
||||
|
||||
tokensNeeded := float64(bytes) - rl.availableTokens
|
||||
waitTime := time.Duration(tokensNeeded/float64(rl.bytesPerSecond)*1000) * time.Millisecond
|
||||
|
||||
if waitTime < time.Millisecond {
|
||||
waitTime = time.Millisecond
|
||||
}
|
||||
|
||||
rl.mu.Unlock()
|
||||
time.Sleep(waitTime)
|
||||
rl.mu.Lock()
|
||||
}
|
||||
}
|
||||
|
||||
type RateLimitedReader struct {
|
||||
reader io.ReadCloser
|
||||
rateLimiter *RateLimiter
|
||||
}
|
||||
|
||||
func NewRateLimitedReader(reader io.ReadCloser, limiter *RateLimiter) *RateLimitedReader {
|
||||
return &RateLimitedReader{
|
||||
reader: reader,
|
||||
rateLimiter: limiter,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RateLimitedReader) Read(p []byte) (n int, err error) {
|
||||
n, err = r.reader.Read(p)
|
||||
if n > 0 {
|
||||
r.rateLimiter.Wait(int64(n))
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *RateLimitedReader) Close() error {
|
||||
return r.reader.Close()
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package download_token
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
105
backend/internal/features/backups/backups/download/service.go
Normal file
105
backend/internal/features/backups/backups/download/service.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type DownloadTokenService struct {
|
||||
repository *DownloadTokenRepository
|
||||
logger *slog.Logger
|
||||
downloadTracker *DownloadTracker
|
||||
bandwidthManager *BandwidthManager
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) Generate(backupID, userID uuid.UUID) (string, error) {
|
||||
if s.downloadTracker.IsDownloadInProgress(userID) {
|
||||
return "", ErrDownloadAlreadyInProgress
|
||||
}
|
||||
|
||||
token := GenerateSecureToken()
|
||||
|
||||
downloadToken := &DownloadToken{
|
||||
Token: token,
|
||||
BackupID: backupID,
|
||||
UserID: userID,
|
||||
ExpiresAt: time.Now().UTC().Add(5 * time.Minute),
|
||||
Used: false,
|
||||
}
|
||||
|
||||
if err := s.repository.Create(downloadToken); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
s.logger.Info("Generated download token", "backupId", backupID, "userId", userID)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) ValidateAndConsume(
|
||||
token string,
|
||||
) (*DownloadToken, *RateLimiter, error) {
|
||||
dt, err := s.repository.FindByToken(token)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if dt == nil {
|
||||
return nil, nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
if dt.Used {
|
||||
return nil, nil, errors.New("token already used")
|
||||
}
|
||||
|
||||
if time.Now().UTC().After(dt.ExpiresAt) {
|
||||
return nil, nil, errors.New("token expired")
|
||||
}
|
||||
|
||||
if err := s.downloadTracker.AcquireDownloadLock(dt.UserID); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
rateLimiter, err := s.bandwidthManager.RegisterDownload(dt.UserID)
|
||||
if err != nil {
|
||||
s.downloadTracker.ReleaseDownloadLock(dt.UserID)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
dt.Used = true
|
||||
if err := s.repository.Update(dt); err != nil {
|
||||
s.logger.Error("Failed to mark token as used", "error", err)
|
||||
}
|
||||
|
||||
s.logger.Info("Token validated and consumed", "backupId", dt.BackupID, "userId", dt.UserID)
|
||||
return dt, rateLimiter, nil
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) RefreshDownloadLock(userID uuid.UUID) {
|
||||
s.downloadTracker.RefreshDownloadLock(userID)
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) ReleaseDownloadLock(userID uuid.UUID) {
|
||||
s.downloadTracker.ReleaseDownloadLock(userID)
|
||||
s.logger.Info("Released download lock", "userId", userID)
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) IsDownloadInProgress(userID uuid.UUID) bool {
|
||||
return s.downloadTracker.IsDownloadInProgress(userID)
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) UnregisterDownload(userID uuid.UUID) {
|
||||
s.bandwidthManager.UnregisterDownload(userID)
|
||||
s.logger.Info("Unregistered from bandwidth manager", "userId", userID)
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) CleanExpiredTokens() error {
|
||||
now := time.Now().UTC()
|
||||
if err := s.repository.DeleteExpired(now); err != nil {
|
||||
return err
|
||||
}
|
||||
s.logger.Debug("Cleaned expired download tokens")
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/valkey-io/valkey-go"
|
||||
)
|
||||
|
||||
const (
|
||||
downloadLockPrefix = "backup_download_lock:"
|
||||
downloadLockTTL = 5 * time.Second
|
||||
downloadLockValue = "1"
|
||||
downloadHeartbeatDelay = 3 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDownloadAlreadyInProgress = errors.New("download already in progress for this user")
|
||||
)
|
||||
|
||||
type DownloadTracker struct {
|
||||
cache *cache_utils.CacheUtil[string]
|
||||
}
|
||||
|
||||
func NewDownloadTracker(client valkey.Client) *DownloadTracker {
|
||||
return &DownloadTracker{
|
||||
cache: cache_utils.NewCacheUtil[string](client, downloadLockPrefix),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *DownloadTracker) AcquireDownloadLock(userID uuid.UUID) error {
|
||||
key := userID.String()
|
||||
|
||||
existingLock := t.cache.Get(key)
|
||||
if existingLock != nil {
|
||||
return ErrDownloadAlreadyInProgress
|
||||
}
|
||||
|
||||
value := downloadLockValue
|
||||
t.cache.Set(key, &value)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *DownloadTracker) RefreshDownloadLock(userID uuid.UUID) {
|
||||
key := userID.String()
|
||||
value := downloadLockValue
|
||||
t.cache.Set(key, &value)
|
||||
}
|
||||
|
||||
func (t *DownloadTracker) ReleaseDownloadLock(userID uuid.UUID) {
|
||||
key := userID.String()
|
||||
t.cache.Invalidate(key)
|
||||
}
|
||||
|
||||
func (t *DownloadTracker) IsDownloadInProgress(userID uuid.UUID) bool {
|
||||
key := userID.String()
|
||||
existingLock := t.cache.Get(key)
|
||||
return existingLock != nil
|
||||
}
|
||||
|
||||
func GetDownloadHeartbeatInterval() time.Duration {
|
||||
return downloadHeartbeatDelay
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
package download_token
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
"log/slog"
|
||||
"time"
|
||||
)
|
||||
|
||||
type DownloadTokenBackgroundService struct {
|
||||
downloadTokenService *DownloadTokenService
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *DownloadTokenBackgroundService) Run() {
|
||||
s.logger.Info("Starting download token cleanup background service")
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
if config.IsShouldShutdown() {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
|
||||
s.logger.Error("Failed to clean expired download tokens", "error", err)
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Minute)
|
||||
}
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package download_token
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var downloadTokenRepository = &DownloadTokenRepository{}
|
||||
|
||||
var downloadTokenService = &DownloadTokenService{
|
||||
downloadTokenRepository,
|
||||
logger.GetLogger(),
|
||||
}
|
||||
|
||||
var downloadTokenBackgroundService = &DownloadTokenBackgroundService{
|
||||
downloadTokenService,
|
||||
logger.GetLogger(),
|
||||
}
|
||||
|
||||
func GetDownloadTokenService() *DownloadTokenService {
|
||||
return downloadTokenService
|
||||
}
|
||||
|
||||
func GetDownloadTokenBackgroundService() *DownloadTokenBackgroundService {
|
||||
return downloadTokenBackgroundService
|
||||
}
|
||||
@@ -1,69 +0,0 @@
|
||||
package download_token
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type DownloadTokenService struct {
|
||||
repository *DownloadTokenRepository
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) Generate(backupID, userID uuid.UUID) (string, error) {
|
||||
token := GenerateSecureToken()
|
||||
|
||||
downloadToken := &DownloadToken{
|
||||
Token: token,
|
||||
BackupID: backupID,
|
||||
UserID: userID,
|
||||
ExpiresAt: time.Now().UTC().Add(5 * time.Minute),
|
||||
Used: false,
|
||||
}
|
||||
|
||||
if err := s.repository.Create(downloadToken); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
s.logger.Info("Generated download token", "backupId", backupID, "userId", userID)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) ValidateAndConsume(token string) (*DownloadToken, error) {
|
||||
dt, err := s.repository.FindByToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if dt == nil {
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
if dt.Used {
|
||||
return nil, errors.New("token already used")
|
||||
}
|
||||
|
||||
if time.Now().UTC().After(dt.ExpiresAt) {
|
||||
return nil, errors.New("token expired")
|
||||
}
|
||||
|
||||
dt.Used = true
|
||||
if err := s.repository.Update(dt); err != nil {
|
||||
s.logger.Error("Failed to mark token as used", "error", err)
|
||||
}
|
||||
|
||||
s.logger.Info("Token validated and consumed", "backupId", dt.BackupID)
|
||||
return dt, nil
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) CleanExpiredTokens() error {
|
||||
now := time.Now().UTC()
|
||||
if err := s.repository.DeleteExpired(now); err != nil {
|
||||
return err
|
||||
}
|
||||
s.logger.Debug("Cleaned expired download tokens")
|
||||
return nil
|
||||
}
|
||||
@@ -1,10 +1,9 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/encryption"
|
||||
"io"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type GetBackupsRequest struct {
|
||||
@@ -14,23 +13,17 @@ type GetBackupsRequest struct {
|
||||
}
|
||||
|
||||
type GetBackupsResponse struct {
|
||||
Backups []*Backup `json:"backups"`
|
||||
Total int64 `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
Backups []*backups_core.Backup `json:"backups"`
|
||||
Total int64 `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
}
|
||||
|
||||
type GenerateDownloadTokenResponse struct {
|
||||
Token string `json:"token"`
|
||||
Filename string `json:"filename"`
|
||||
BackupID uuid.UUID `json:"backupId"`
|
||||
}
|
||||
|
||||
type decryptionReaderCloser struct {
|
||||
type DecryptionReaderCloser struct {
|
||||
*encryption.DecryptionReader
|
||||
baseReader io.ReadCloser
|
||||
BaseReader io.ReadCloser
|
||||
}
|
||||
|
||||
func (r *decryptionReaderCloser) Close() error {
|
||||
return r.baseReader.Close()
|
||||
func (r *DecryptionReaderCloser) Close() error {
|
||||
return r.BaseReader.Close()
|
||||
}
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type MockNotificationSender struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockNotificationSender) SendNotification(
|
||||
notifier *notifiers.Notifier,
|
||||
title string,
|
||||
message string,
|
||||
) {
|
||||
m.Called(notifier, title, message)
|
||||
}
|
||||
@@ -1,27 +1,27 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups/download_token"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
"databasus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
users_models "databasus-backend/internal/features/users/models"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
files_utils "databasus-backend/internal/util/files"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -29,26 +29,28 @@ import (
|
||||
type BackupService struct {
|
||||
databaseService *databases.DatabaseService
|
||||
storageService *storages.StorageService
|
||||
backupRepository *BackupRepository
|
||||
backupRepository *backups_core.BackupRepository
|
||||
notifierService *notifiers.NotifierService
|
||||
notificationSender NotificationSender
|
||||
notificationSender backups_core.NotificationSender
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
secretKeyService *encryption_secrets.SecretKeyService
|
||||
fieldEncryptor util_encryption.FieldEncryptor
|
||||
|
||||
createBackupUseCase CreateBackupUsecase
|
||||
createBackupUseCase backups_core.CreateBackupUsecase
|
||||
|
||||
logger *slog.Logger
|
||||
|
||||
backupRemoveListeners []BackupRemoveListener
|
||||
backupRemoveListeners []backups_core.BackupRemoveListener
|
||||
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
auditLogService *audit_logs.AuditLogService
|
||||
backupContextManager *BackupContextManager
|
||||
downloadTokenService *download_token.DownloadTokenService
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
auditLogService *audit_logs.AuditLogService
|
||||
taskCancelManager *task_cancellation.TaskCancelManager
|
||||
downloadTokenService *backups_download.DownloadTokenService
|
||||
backupSchedulerService *backuping.BackupsScheduler
|
||||
backupCleaner *backuping.BackupCleaner
|
||||
}
|
||||
|
||||
func (s *BackupService) AddBackupRemoveListener(listener BackupRemoveListener) {
|
||||
func (s *BackupService) AddBackupRemoveListener(listener backups_core.BackupRemoveListener) {
|
||||
s.backupRemoveListeners = append(s.backupRemoveListeners, listener)
|
||||
}
|
||||
|
||||
@@ -91,7 +93,7 @@ func (s *BackupService) MakeBackupWithAuth(
|
||||
return errors.New("insufficient permissions to create backup for this database")
|
||||
}
|
||||
|
||||
go s.MakeBackup(databaseID, true)
|
||||
s.backupSchedulerService.StartBackup(database, true)
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Backup manually initiated for database: %s", database.Name),
|
||||
@@ -175,277 +177,20 @@ func (s *BackupService) DeleteBackup(
|
||||
return errors.New("insufficient permissions to delete backup for this database")
|
||||
}
|
||||
|
||||
if backup.Status == BackupStatusInProgress {
|
||||
if backup.Status == backups_core.BackupStatusInProgress {
|
||||
return errors.New("backup is in progress")
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf(
|
||||
"Backup deleted for database: %s (ID: %s)",
|
||||
database.Name,
|
||||
backupID.String(),
|
||||
),
|
||||
fmt.Sprintf("Backup deleted for database: %s", database.Name),
|
||||
&user.ID,
|
||||
database.WorkspaceID,
|
||||
)
|
||||
|
||||
return s.deleteBackup(backup)
|
||||
return s.backupCleaner.DeleteBackup(backup)
|
||||
}
|
||||
|
||||
func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
|
||||
database, err := s.databaseService.GetDatabaseByID(databaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get database by ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
lastBackup, err := s.backupRepository.FindLastByDatabaseID(databaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to find last backup by database ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if lastBackup != nil && lastBackup.Status == BackupStatusInProgress {
|
||||
s.logger.Error("Backup is in progress")
|
||||
return
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(databaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if backupConfig.StorageID == nil {
|
||||
s.logger.Error("Backup config storage ID is not defined")
|
||||
return
|
||||
}
|
||||
|
||||
storage, err := s.storageService.GetStorageByID(*backupConfig.StorageID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get storage by ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
backup := &Backup{
|
||||
DatabaseID: databaseID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusInProgress,
|
||||
|
||||
BackupSizeMb: 0,
|
||||
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error("Failed to save backup", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now().UTC()
|
||||
|
||||
backupProgressListener := func(
|
||||
completedMBs float64,
|
||||
) {
|
||||
backup.BackupSizeMb = completedMBs
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error("Failed to update backup progress", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s.backupContextManager.RegisterBackup(backup.ID, cancel)
|
||||
defer s.backupContextManager.UnregisterBackup(backup.ID)
|
||||
|
||||
backupMetadata, err := s.createBackupUseCase.Execute(
|
||||
ctx,
|
||||
backup.ID,
|
||||
backupConfig,
|
||||
database,
|
||||
storage,
|
||||
backupProgressListener,
|
||||
)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
|
||||
// Check if backup was cancelled (not due to shutdown)
|
||||
isCancelled := strings.Contains(errMsg, "backup cancelled") ||
|
||||
strings.Contains(errMsg, "context canceled") ||
|
||||
errors.Is(err, context.Canceled)
|
||||
isShutdown := strings.Contains(errMsg, "shutdown")
|
||||
|
||||
if isCancelled && !isShutdown {
|
||||
backup.Status = BackupStatusCanceled
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error("Failed to save cancelled backup", "error", err)
|
||||
}
|
||||
|
||||
// Delete partial backup from storage
|
||||
storage, storageErr := s.storageService.GetStorageByID(backup.StorageID)
|
||||
if storageErr == nil {
|
||||
if deleteErr := storage.DeleteFile(s.fieldEncryptor, backup.ID); deleteErr != nil {
|
||||
s.logger.Error(
|
||||
"Failed to delete partial backup file",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
deleteErr,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
backup.FailMessage = &errMsg
|
||||
backup.Status = BackupStatusFailed
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
if updateErr := s.databaseService.SetBackupError(databaseID, errMsg); updateErr != nil {
|
||||
s.logger.Error(
|
||||
"Failed to update database last backup time",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"error",
|
||||
updateErr,
|
||||
)
|
||||
}
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error("Failed to save backup", "error", err)
|
||||
}
|
||||
|
||||
s.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupFailed,
|
||||
&errMsg,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
backup.Status = BackupStatusCompleted
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
// Update backup with encryption metadata if provided
|
||||
if backupMetadata != nil {
|
||||
backup.EncryptionSalt = backupMetadata.EncryptionSalt
|
||||
backup.EncryptionIV = backupMetadata.EncryptionIV
|
||||
backup.Encryption = backupMetadata.Encryption
|
||||
}
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error("Failed to save backup", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Update database last backup time
|
||||
now := time.Now().UTC()
|
||||
if updateErr := s.databaseService.SetLastBackupTime(databaseID, now); updateErr != nil {
|
||||
s.logger.Error(
|
||||
"Failed to update database last backup time",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"error",
|
||||
updateErr,
|
||||
)
|
||||
}
|
||||
|
||||
if backup.Status != BackupStatusCompleted && !isLastTry {
|
||||
return
|
||||
}
|
||||
|
||||
s.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupSuccess,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
func (s *BackupService) SendBackupNotification(
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
backup *Backup,
|
||||
notificationType backups_config.BackupNotificationType,
|
||||
errorMessage *string,
|
||||
) {
|
||||
database, err := s.databaseService.GetDatabaseByID(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
workspace, err := s.workspaceService.GetWorkspaceByID(*database.WorkspaceID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, notifier := range database.Notifiers {
|
||||
if !slices.Contains(
|
||||
backupConfig.SendNotificationsOn,
|
||||
notificationType,
|
||||
) {
|
||||
continue
|
||||
}
|
||||
|
||||
title := ""
|
||||
switch notificationType {
|
||||
case backups_config.NotificationBackupFailed:
|
||||
title = fmt.Sprintf(
|
||||
"❌ Backup failed for database \"%s\" (workspace \"%s\")",
|
||||
database.Name,
|
||||
workspace.Name,
|
||||
)
|
||||
case backups_config.NotificationBackupSuccess:
|
||||
title = fmt.Sprintf(
|
||||
"✅ Backup completed for database \"%s\" (workspace \"%s\")",
|
||||
database.Name,
|
||||
workspace.Name,
|
||||
)
|
||||
}
|
||||
|
||||
message := ""
|
||||
if errorMessage != nil {
|
||||
message = *errorMessage
|
||||
} else {
|
||||
// Format size conditionally
|
||||
var sizeStr string
|
||||
if backup.BackupSizeMb < 1024 {
|
||||
sizeStr = fmt.Sprintf("%.2f MB", backup.BackupSizeMb)
|
||||
} else {
|
||||
sizeGB := backup.BackupSizeMb / 1024
|
||||
sizeStr = fmt.Sprintf("%.2f GB", sizeGB)
|
||||
}
|
||||
|
||||
// Format duration as "0m 0s 0ms"
|
||||
totalMs := backup.BackupDurationMs
|
||||
minutes := totalMs / (1000 * 60)
|
||||
seconds := (totalMs % (1000 * 60)) / 1000
|
||||
durationStr := fmt.Sprintf("%dm %ds", minutes, seconds)
|
||||
|
||||
message = fmt.Sprintf(
|
||||
"Backup completed successfully in %s.\nCompressed backup size: %s",
|
||||
durationStr,
|
||||
sizeStr,
|
||||
)
|
||||
}
|
||||
|
||||
s.notificationSender.SendNotification(
|
||||
¬ifier,
|
||||
title,
|
||||
message,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupService) GetBackup(backupID uuid.UUID) (*Backup, error) {
|
||||
func (s *BackupService) GetBackup(backupID uuid.UUID) (*backups_core.Backup, error) {
|
||||
return s.backupRepository.FindByID(backupID)
|
||||
}
|
||||
|
||||
@@ -475,20 +220,16 @@ func (s *BackupService) CancelBackup(
|
||||
return errors.New("insufficient permissions to cancel backup for this database")
|
||||
}
|
||||
|
||||
if backup.Status != BackupStatusInProgress {
|
||||
if backup.Status != backups_core.BackupStatusInProgress {
|
||||
return errors.New("backup is not in progress")
|
||||
}
|
||||
|
||||
if err := s.backupContextManager.CancelBackup(backupID); err != nil {
|
||||
if err := s.taskCancelManager.CancelTask(backupID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf(
|
||||
"Backup cancelled for database: %s (ID: %s)",
|
||||
database.Name,
|
||||
backupID.String(),
|
||||
),
|
||||
fmt.Sprintf("Backup cancelled for database: %s", database.Name),
|
||||
&user.ID,
|
||||
database.WorkspaceID,
|
||||
)
|
||||
@@ -499,7 +240,7 @@ func (s *BackupService) CancelBackup(
|
||||
func (s *BackupService) GetBackupFile(
|
||||
user *users_models.User,
|
||||
backupID uuid.UUID,
|
||||
) (io.ReadCloser, *Backup, *databases.Database, error) {
|
||||
) (io.ReadCloser, *backups_core.Backup, *databases.Database, error) {
|
||||
backup, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
@@ -528,11 +269,7 @@ func (s *BackupService) GetBackupFile(
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf(
|
||||
"Backup file downloaded for database: %s (ID: %s)",
|
||||
database.Name,
|
||||
backupID.String(),
|
||||
),
|
||||
fmt.Sprintf("Backup file downloaded for database: %s", database.Name),
|
||||
&user.ID,
|
||||
database.WorkspaceID,
|
||||
)
|
||||
@@ -545,33 +282,10 @@ func (s *BackupService) GetBackupFile(
|
||||
return reader, backup, database, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) deleteBackup(backup *Backup) error {
|
||||
for _, listener := range s.backupRemoveListeners {
|
||||
if err := listener.OnBeforeBackupRemove(backup); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
storage, err := s.storageService.GetStorageByID(backup.StorageID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = storage.DeleteFile(s.fieldEncryptor, backup.ID)
|
||||
if err != nil {
|
||||
// we do not return error here, because sometimes clean up performed
|
||||
// before unavailable storage removal or change - therefore we should
|
||||
// proceed even in case of error
|
||||
s.logger.Error("Failed to delete backup file", "error", err)
|
||||
}
|
||||
|
||||
return s.backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
|
||||
dbBackupsInProgress, err := s.backupRepository.FindByDatabaseIdAndStatus(
|
||||
databaseID,
|
||||
BackupStatusInProgress,
|
||||
backups_core.BackupStatusInProgress,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -589,7 +303,7 @@ func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
|
||||
}
|
||||
|
||||
for _, dbBackup := range dbBackups {
|
||||
err := s.deleteBackup(dbBackup)
|
||||
err := s.backupCleaner.DeleteBackup(dbBackup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -611,7 +325,7 @@ func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, erro
|
||||
return nil, fmt.Errorf("failed to get storage: %w", err)
|
||||
}
|
||||
|
||||
fileReader, err := storage.GetFile(s.fieldEncryptor, backup.ID)
|
||||
fileReader, err := storage.GetFile(s.fieldEncryptor, backup.ID.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get backup file: %w", err)
|
||||
}
|
||||
@@ -680,16 +394,16 @@ func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, erro
|
||||
|
||||
s.logger.Info("Returning encrypted backup with decryption", "backupId", backupID)
|
||||
|
||||
return &decryptionReaderCloser{
|
||||
decryptionReader,
|
||||
fileReader,
|
||||
return &DecryptionReaderCloser{
|
||||
DecryptionReader: decryptionReader,
|
||||
BaseReader: fileReader,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) GenerateDownloadToken(
|
||||
user *users_models.User,
|
||||
backupID uuid.UUID,
|
||||
) (*GenerateDownloadTokenResponse, error) {
|
||||
) (*backups_download.GenerateDownloadTokenResponse, error) {
|
||||
backup, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -725,20 +439,22 @@ func (s *BackupService) GenerateDownloadToken(
|
||||
database.WorkspaceID,
|
||||
)
|
||||
|
||||
return &GenerateDownloadTokenResponse{
|
||||
return &backups_download.GenerateDownloadTokenResponse{
|
||||
Token: token,
|
||||
Filename: filename,
|
||||
BackupID: backupID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) ValidateDownloadToken(token string) (*download_token.DownloadToken, error) {
|
||||
func (s *BackupService) ValidateDownloadToken(
|
||||
token string,
|
||||
) (*backups_download.DownloadToken, *backups_download.RateLimiter, error) {
|
||||
return s.downloadTokenService.ValidateAndConsume(token)
|
||||
}
|
||||
|
||||
func (s *BackupService) GetBackupFileWithoutAuth(
|
||||
backupID uuid.UUID,
|
||||
) (io.ReadCloser, *Backup, *databases.Database, error) {
|
||||
) (io.ReadCloser, *backups_core.Backup, *databases.Database, error) {
|
||||
backup, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
@@ -759,26 +475,38 @@ func (s *BackupService) GetBackupFileWithoutAuth(
|
||||
|
||||
func (s *BackupService) WriteAuditLogForDownload(
|
||||
userID uuid.UUID,
|
||||
backup *Backup,
|
||||
backup *backups_core.Backup,
|
||||
database *databases.Database,
|
||||
) {
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf(
|
||||
"Backup file downloaded for database: %s (ID: %s)",
|
||||
database.Name,
|
||||
backup.ID.String(),
|
||||
),
|
||||
fmt.Sprintf("Backup file downloaded for database: %s", database.Name),
|
||||
&userID,
|
||||
database.WorkspaceID,
|
||||
)
|
||||
}
|
||||
|
||||
func (s *BackupService) RefreshDownloadLock(userID uuid.UUID) {
|
||||
s.downloadTokenService.RefreshDownloadLock(userID)
|
||||
}
|
||||
|
||||
func (s *BackupService) ReleaseDownloadLock(userID uuid.UUID) {
|
||||
s.downloadTokenService.ReleaseDownloadLock(userID)
|
||||
}
|
||||
|
||||
func (s *BackupService) IsDownloadInProgress(userID uuid.UUID) bool {
|
||||
return s.downloadTokenService.IsDownloadInProgress(userID)
|
||||
}
|
||||
|
||||
func (s *BackupService) UnregisterDownload(userID uuid.UUID) {
|
||||
s.downloadTokenService.UnregisterDownload(userID)
|
||||
}
|
||||
|
||||
func (s *BackupService) generateBackupFilename(
|
||||
backup *Backup,
|
||||
backup *backups_core.Backup,
|
||||
database *databases.Database,
|
||||
) string {
|
||||
timestamp := backup.CreatedAt.Format("2006-01-02_15-04-05")
|
||||
safeName := sanitizeFilename(database.Name)
|
||||
safeName := files_utils.SanitizeFilename(database.Name)
|
||||
extension := s.getBackupExtension(database.Type)
|
||||
return fmt.Sprintf("%s_backup_%s%s", safeName, timestamp, extension)
|
||||
}
|
||||
|
||||
@@ -1,206 +0,0 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
common "databasus-backend/internal/features/backups/backups/common"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
t.Run("BackupFailed_FailNotificationSent", func(t *testing.T) {
|
||||
mockNotificationSender := &MockNotificationSender{}
|
||||
backupService := &BackupService{
|
||||
databases.GetDatabaseService(),
|
||||
storages.GetStorageService(),
|
||||
backupRepository,
|
||||
notifiers.GetNotifierService(),
|
||||
mockNotificationSender,
|
||||
backups_config.GetBackupConfigService(),
|
||||
encryption_secrets.GetSecretKeyService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
&CreateFailedBackupUsecase{},
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
nil,
|
||||
NewBackupContextManager(),
|
||||
nil,
|
||||
}
|
||||
|
||||
// Set up expectations
|
||||
mockNotificationSender.On("SendNotification",
|
||||
mock.Anything,
|
||||
mock.MatchedBy(func(title string) bool {
|
||||
return strings.Contains(title, "❌ Backup failed")
|
||||
}),
|
||||
mock.MatchedBy(func(message string) bool {
|
||||
return strings.Contains(message, "backup failed")
|
||||
}),
|
||||
).Once()
|
||||
|
||||
backupService.MakeBackup(database.ID, true)
|
||||
|
||||
// Verify all expectations were met
|
||||
mockNotificationSender.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("BackupSuccess_SuccessNotificationSent", func(t *testing.T) {
|
||||
mockNotificationSender := &MockNotificationSender{}
|
||||
|
||||
// Set up expectations
|
||||
mockNotificationSender.On("SendNotification",
|
||||
mock.Anything,
|
||||
mock.MatchedBy(func(title string) bool {
|
||||
return strings.Contains(title, "✅ Backup completed")
|
||||
}),
|
||||
mock.MatchedBy(func(message string) bool {
|
||||
return strings.Contains(message, "Backup completed successfully")
|
||||
}),
|
||||
).Once()
|
||||
|
||||
backupService := &BackupService{
|
||||
databases.GetDatabaseService(),
|
||||
storages.GetStorageService(),
|
||||
backupRepository,
|
||||
notifiers.GetNotifierService(),
|
||||
mockNotificationSender,
|
||||
backups_config.GetBackupConfigService(),
|
||||
encryption_secrets.GetSecretKeyService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
&CreateSuccessBackupUsecase{},
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
nil,
|
||||
NewBackupContextManager(),
|
||||
nil,
|
||||
}
|
||||
|
||||
backupService.MakeBackup(database.ID, true)
|
||||
|
||||
// Verify all expectations were met
|
||||
mockNotificationSender.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("BackupSuccess_VerifyNotificationContent", func(t *testing.T) {
|
||||
mockNotificationSender := &MockNotificationSender{}
|
||||
backupService := &BackupService{
|
||||
databases.GetDatabaseService(),
|
||||
storages.GetStorageService(),
|
||||
backupRepository,
|
||||
notifiers.GetNotifierService(),
|
||||
mockNotificationSender,
|
||||
backups_config.GetBackupConfigService(),
|
||||
encryption_secrets.GetSecretKeyService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
&CreateSuccessBackupUsecase{},
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
nil,
|
||||
NewBackupContextManager(),
|
||||
nil,
|
||||
}
|
||||
|
||||
// capture arguments
|
||||
var capturedNotifier *notifiers.Notifier
|
||||
var capturedTitle string
|
||||
var capturedMessage string
|
||||
|
||||
mockNotificationSender.On("SendNotification",
|
||||
mock.Anything,
|
||||
mock.AnythingOfType("string"),
|
||||
mock.AnythingOfType("string"),
|
||||
).Run(func(args mock.Arguments) {
|
||||
capturedNotifier = args.Get(0).(*notifiers.Notifier)
|
||||
capturedTitle = args.Get(1).(string)
|
||||
capturedMessage = args.Get(2).(string)
|
||||
}).Once()
|
||||
|
||||
backupService.MakeBackup(database.ID, true)
|
||||
|
||||
// Verify expectations were met
|
||||
mockNotificationSender.AssertExpectations(t)
|
||||
|
||||
// Additional detailed assertions
|
||||
assert.Contains(t, capturedTitle, "✅ Backup completed")
|
||||
assert.Contains(t, capturedTitle, database.Name)
|
||||
assert.Contains(t, capturedMessage, "Backup completed successfully")
|
||||
assert.Contains(t, capturedMessage, "10.00 MB")
|
||||
assert.Equal(t, notifier.ID, capturedNotifier.ID)
|
||||
})
|
||||
}
|
||||
|
||||
type CreateFailedBackupUsecase struct {
|
||||
}
|
||||
|
||||
func (uc *CreateFailedBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*common.BackupMetadata, error) {
|
||||
backupProgressListener(10)
|
||||
return nil, errors.New("backup failed")
|
||||
}
|
||||
|
||||
type CreateSuccessBackupUsecase struct{}
|
||||
|
||||
func (uc *CreateSuccessBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*common.BackupMetadata, error) {
|
||||
backupProgressListener(10)
|
||||
return &common.BackupMetadata{
|
||||
EncryptionSalt: nil,
|
||||
EncryptionIV: nil,
|
||||
Encryption: backups_config.BackupEncryptionNone,
|
||||
}, nil
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
@@ -58,9 +59,9 @@ func WaitForBackupCompletion(
|
||||
newestBackup := backups[0]
|
||||
t.Logf("WaitForBackupCompletion: newest backup status: %s", newestBackup.Status)
|
||||
|
||||
if newestBackup.Status == BackupStatusCompleted ||
|
||||
newestBackup.Status == BackupStatusFailed ||
|
||||
newestBackup.Status == BackupStatusCanceled {
|
||||
if newestBackup.Status == backups_core.BackupStatusCompleted ||
|
||||
newestBackup.Status == backups_core.BackupStatusFailed ||
|
||||
newestBackup.Status == backups_core.BackupStatusCanceled {
|
||||
t.Logf(
|
||||
"WaitForBackupCompletion: backup finished with status %s",
|
||||
newestBackup.Status,
|
||||
@@ -74,3 +75,23 @@ func WaitForBackupCompletion(
|
||||
|
||||
t.Logf("WaitForBackupCompletion: timeout waiting for backup to complete")
|
||||
}
|
||||
|
||||
// CreateTestBackup creates a simple test backup record for testing purposes
|
||||
func CreateTestBackup(databaseID, storageID uuid.UUID) *backups_core.Backup {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: databaseID,
|
||||
StorageID: storageID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10.5,
|
||||
BackupDurationMs: 1000,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
repo := &backups_core.BackupRepository{}
|
||||
if err := repo.Save(backup); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return backup
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
|
||||
common "databasus-backend/internal/features/backups/backups/common"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
usecases_mariadb "databasus-backend/internal/features/backups/backups/usecases/mariadb"
|
||||
usecases_mongodb "databasus-backend/internal/features/backups/backups/usecases/mongodb"
|
||||
usecases_mysql "databasus-backend/internal/features/backups/backups/usecases/mysql"
|
||||
@@ -12,8 +13,6 @@ import (
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/storages"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type CreateBackupUsecase struct {
|
||||
@@ -25,7 +24,7 @@ type CreateBackupUsecase struct {
|
||||
|
||||
func (uc *CreateBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
@@ -35,7 +34,7 @@ func (uc *CreateBackupUsecase) Execute(
|
||||
case databases.DatabaseTypePostgres:
|
||||
return uc.CreatePostgresqlBackupUsecase.Execute(
|
||||
ctx,
|
||||
backupID,
|
||||
backup,
|
||||
backupConfig,
|
||||
database,
|
||||
storage,
|
||||
@@ -45,7 +44,7 @@ func (uc *CreateBackupUsecase) Execute(
|
||||
case databases.DatabaseTypeMysql:
|
||||
return uc.CreateMysqlBackupUsecase.Execute(
|
||||
ctx,
|
||||
backupID,
|
||||
backup,
|
||||
backupConfig,
|
||||
database,
|
||||
storage,
|
||||
@@ -55,7 +54,7 @@ func (uc *CreateBackupUsecase) Execute(
|
||||
case databases.DatabaseTypeMariadb:
|
||||
return uc.CreateMariadbBackupUsecase.Execute(
|
||||
ctx,
|
||||
backupID,
|
||||
backup,
|
||||
backupConfig,
|
||||
database,
|
||||
storage,
|
||||
@@ -65,7 +64,7 @@ func (uc *CreateBackupUsecase) Execute(
|
||||
case databases.DatabaseTypeMongodb:
|
||||
return uc.CreateMongodbBackupUsecase.Execute(
|
||||
ctx,
|
||||
backupID,
|
||||
backup,
|
||||
backupConfig,
|
||||
database,
|
||||
storage,
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
common "databasus-backend/internal/features/backups/backups/common"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -52,7 +53,7 @@ type writeResult struct {
|
||||
|
||||
func (uc *CreateMariadbBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
db *databases.Database,
|
||||
storage *storages.Storage,
|
||||
@@ -82,7 +83,7 @@ func (uc *CreateMariadbBackupUsecase) Execute(
|
||||
|
||||
return uc.streamToStorage(
|
||||
ctx,
|
||||
backupID,
|
||||
backup,
|
||||
backupConfig,
|
||||
tools.GetMariadbExecutable(
|
||||
tools.MariadbExecutableMariadbDump,
|
||||
@@ -108,13 +109,15 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
|
||||
"--single-transaction",
|
||||
"--routines",
|
||||
"--quick",
|
||||
"--skip-extended-insert",
|
||||
"--verbose",
|
||||
}
|
||||
|
||||
if mdb.HasPrivilege("TRIGGER") {
|
||||
args = append(args, "--triggers")
|
||||
}
|
||||
if mdb.HasPrivilege("EVENT") {
|
||||
|
||||
if mdb.HasPrivilege("EVENT") && !mdb.IsExcludeEvents {
|
||||
args = append(args, "--events")
|
||||
}
|
||||
|
||||
@@ -134,7 +137,7 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
|
||||
|
||||
func (uc *CreateMariadbBackupUsecase) streamToStorage(
|
||||
parentCtx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
mariadbBin string,
|
||||
args []string,
|
||||
@@ -185,7 +188,7 @@ func (uc *CreateMariadbBackupUsecase) streamToStorage(
|
||||
storageReader, storageWriter := io.Pipe()
|
||||
|
||||
finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
|
||||
backupID,
|
||||
backup.ID,
|
||||
backupConfig,
|
||||
storageWriter,
|
||||
)
|
||||
@@ -202,7 +205,13 @@ func (uc *CreateMariadbBackupUsecase) streamToStorage(
|
||||
|
||||
saveErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
|
||||
saveErr := storage.SaveFile(
|
||||
ctx,
|
||||
uc.fieldEncryptor,
|
||||
uc.logger,
|
||||
backup.FileName,
|
||||
storageReader,
|
||||
)
|
||||
saveErrCh <- saveErr
|
||||
}()
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
common "databasus-backend/internal/features/backups/backups/common"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -46,7 +47,7 @@ type writeResult struct {
|
||||
|
||||
func (uc *CreateMongodbBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
db *databases.Database,
|
||||
storage *storages.Storage,
|
||||
@@ -76,7 +77,7 @@ func (uc *CreateMongodbBackupUsecase) Execute(
|
||||
|
||||
return uc.streamToStorage(
|
||||
ctx,
|
||||
backupID,
|
||||
backup,
|
||||
backupConfig,
|
||||
tools.GetMongodbExecutable(
|
||||
tools.MongodbExecutableMongodump,
|
||||
@@ -114,7 +115,7 @@ func (uc *CreateMongodbBackupUsecase) buildMongodumpArgs(
|
||||
|
||||
func (uc *CreateMongodbBackupUsecase) streamToStorage(
|
||||
parentCtx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
mongodumpBin string,
|
||||
args []string,
|
||||
@@ -163,7 +164,7 @@ func (uc *CreateMongodbBackupUsecase) streamToStorage(
|
||||
storageReader, storageWriter := io.Pipe()
|
||||
|
||||
finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
|
||||
backupID,
|
||||
backup.ID,
|
||||
backupConfig,
|
||||
storageWriter,
|
||||
)
|
||||
@@ -175,7 +176,13 @@ func (uc *CreateMongodbBackupUsecase) streamToStorage(
|
||||
|
||||
saveErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
|
||||
saveErr := storage.SaveFile(
|
||||
ctx,
|
||||
uc.fieldEncryptor,
|
||||
uc.logger,
|
||||
backup.FileName,
|
||||
storageReader,
|
||||
)
|
||||
saveErrCh <- saveErr
|
||||
}()
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
common "databasus-backend/internal/features/backups/backups/common"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -52,7 +53,7 @@ type writeResult struct {
|
||||
|
||||
func (uc *CreateMysqlBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
db *databases.Database,
|
||||
storage *storages.Storage,
|
||||
@@ -82,7 +83,7 @@ func (uc *CreateMysqlBackupUsecase) Execute(
|
||||
|
||||
return uc.streamToStorage(
|
||||
ctx,
|
||||
backupID,
|
||||
backup,
|
||||
backupConfig,
|
||||
tools.GetMysqlExecutable(
|
||||
my.Version,
|
||||
@@ -107,6 +108,7 @@ func (uc *CreateMysqlBackupUsecase) buildMysqldumpArgs(my *mysqltypes.MysqlDatab
|
||||
"--routines",
|
||||
"--set-gtid-purged=OFF",
|
||||
"--quick",
|
||||
"--skip-extended-insert",
|
||||
"--verbose",
|
||||
}
|
||||
|
||||
@@ -148,7 +150,7 @@ func (uc *CreateMysqlBackupUsecase) getNetworkCompressionArgs(version tools.Mysq
|
||||
|
||||
func (uc *CreateMysqlBackupUsecase) streamToStorage(
|
||||
parentCtx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
mysqlBin string,
|
||||
args []string,
|
||||
@@ -199,7 +201,7 @@ func (uc *CreateMysqlBackupUsecase) streamToStorage(
|
||||
storageReader, storageWriter := io.Pipe()
|
||||
|
||||
finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
|
||||
backupID,
|
||||
backup.ID,
|
||||
backupConfig,
|
||||
storageWriter,
|
||||
)
|
||||
@@ -216,7 +218,13 @@ func (uc *CreateMysqlBackupUsecase) streamToStorage(
|
||||
|
||||
saveErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
|
||||
saveErr := storage.SaveFile(
|
||||
ctx,
|
||||
uc.fieldEncryptor,
|
||||
uc.logger,
|
||||
backup.FileName,
|
||||
storageReader,
|
||||
)
|
||||
saveErrCh <- saveErr
|
||||
}()
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
common "databasus-backend/internal/features/backups/backups/common"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -53,7 +54,7 @@ type writeResult struct {
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
db *databases.Database,
|
||||
storage *storages.Storage,
|
||||
@@ -88,7 +89,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
|
||||
|
||||
return uc.streamToStorage(
|
||||
ctx,
|
||||
backupID,
|
||||
backup,
|
||||
backupConfig,
|
||||
tools.GetPostgresqlExecutable(
|
||||
pg.Version,
|
||||
@@ -107,7 +108,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
|
||||
// streamToStorage streams pg_dump output directly to storage
|
||||
func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
|
||||
parentCtx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
pgBin string,
|
||||
args []string,
|
||||
@@ -166,7 +167,7 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
|
||||
storageReader, storageWriter := io.Pipe()
|
||||
|
||||
finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
|
||||
backupID,
|
||||
backup.ID,
|
||||
backupConfig,
|
||||
storageWriter,
|
||||
)
|
||||
@@ -181,7 +182,13 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
|
||||
// Start streaming into storage in its own goroutine
|
||||
saveErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
|
||||
saveErr := storage.SaveFile(
|
||||
ctx,
|
||||
uc.fieldEncryptor,
|
||||
uc.logger,
|
||||
backup.FileName,
|
||||
storageReader,
|
||||
)
|
||||
saveErrCh <- saveErr
|
||||
}()
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ type BackupConfigController struct {
|
||||
|
||||
func (c *BackupConfigController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
router.POST("/backup-configs/save", c.SaveBackupConfig)
|
||||
router.GET("/backup-configs/database/:id/plan", c.GetDatabasePlan)
|
||||
router.GET("/backup-configs/database/:id", c.GetBackupConfigByDbID)
|
||||
router.GET("/backup-configs/storage/:id/is-using", c.IsStorageUsing)
|
||||
router.GET("/backup-configs/storage/:id/databases-count", c.CountDatabasesForStorage)
|
||||
@@ -92,6 +93,39 @@ func (c *BackupConfigController) GetBackupConfigByDbID(ctx *gin.Context) {
|
||||
ctx.JSON(http.StatusOK, backupConfig)
|
||||
}
|
||||
|
||||
// GetDatabasePlan
|
||||
// @Summary Get database plan by database ID
|
||||
// @Description Get the plan limits for a specific database (max backup size, max total size, max storage period)
|
||||
// @Tags backup-configs
|
||||
// @Produce json
|
||||
// @Param id path string true "Database ID"
|
||||
// @Success 200 {object} plans.DatabasePlan
|
||||
// @Failure 400 {object} map[string]string "Invalid database ID"
|
||||
// @Failure 401 {object} map[string]string "User not authenticated"
|
||||
// @Failure 404 {object} map[string]string "Database not found or access denied"
|
||||
// @Router /backup-configs/database/{id}/plan [get]
|
||||
func (c *BackupConfigController) GetDatabasePlan(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
|
||||
return
|
||||
}
|
||||
|
||||
plan, err := c.backupConfigService.GetDatabasePlan(user, id)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusNotFound, gin.H{"error": "database plan not found"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, plan)
|
||||
}
|
||||
|
||||
// IsStorageUsing
|
||||
// @Summary Check if storage is being used
|
||||
// @Description Check if a storage is currently being used by any backup configuration
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -16,11 +17,14 @@ import (
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
plans "databasus-backend/internal/features/plan"
|
||||
"databasus-backend/internal/features/storages"
|
||||
local_storage "databasus-backend/internal/features/storages/models/local"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/storage"
|
||||
"databasus-backend/internal/util/period"
|
||||
test_utils "databasus-backend/internal/util/testing"
|
||||
"databasus-backend/internal/util/tools"
|
||||
@@ -89,6 +93,11 @@ func Test_SaveBackupConfig_PermissionsEnforced(t *testing.T) {
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
@@ -152,6 +161,11 @@ func Test_SaveBackupConfig_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *test
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
@@ -242,6 +256,11 @@ func Test_GetBackupConfigByDbID_PermissionsEnforced(t *testing.T) {
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
@@ -290,6 +309,11 @@ func Test_GetBackupConfigByDbID_ReturnsDefaultConfigForNewDatabase(t *testing.T)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
var response BackupConfig
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
@@ -300,14 +324,214 @@ func Test_GetBackupConfigByDbID_ReturnsDefaultConfigForNewDatabase(t *testing.T)
|
||||
&response,
|
||||
)
|
||||
|
||||
var plan plans.DatabasePlan
|
||||
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&plan,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, response.DatabaseID)
|
||||
assert.False(t, response.IsBackupsEnabled)
|
||||
assert.Equal(t, period.PeriodWeek, response.StorePeriod)
|
||||
assert.Equal(t, plan.MaxStoragePeriod, response.StorePeriod)
|
||||
assert.Equal(t, plan.MaxBackupSizeMB, response.MaxBackupSizeMB)
|
||||
assert.Equal(t, plan.MaxBackupsTotalSizeMB, response.MaxBackupsTotalSizeMB)
|
||||
assert.True(t, response.IsRetryIfFailed)
|
||||
assert.Equal(t, 3, response.MaxFailedTriesCount)
|
||||
assert.NotNil(t, response.BackupInterval)
|
||||
}
|
||||
|
||||
func Test_GetDatabasePlan_ForNewDatabase_PlanAlwaysReturned(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
var response plans.DatabasePlan
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, response.DatabaseID)
|
||||
assert.NotNil(t, response.MaxBackupSizeMB)
|
||||
assert.NotNil(t, response.MaxBackupsTotalSizeMB)
|
||||
assert.NotEmpty(t, response.MaxStoragePeriod)
|
||||
}
|
||||
|
||||
func Test_SaveBackupConfig_WhenPlanLimitsAreAdjusted_ValidationEnforced(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Get plan via API (triggers auto-creation)
|
||||
var plan plans.DatabasePlan
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&plan,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, plan.DatabaseID)
|
||||
|
||||
// Adjust plan limits directly in database to fixed restrictive values
|
||||
err := storage.GetDb().Model(&plans.DatabasePlan{}).
|
||||
Where("database_id = ?", database.ID).
|
||||
Updates(map[string]any{
|
||||
"max_backup_size_mb": 100,
|
||||
"max_backups_total_size_mb": 1000,
|
||||
"max_storage_period": period.PeriodMonth,
|
||||
}).Error
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test 1: Try to save backup config with exceeded backup size limit
|
||||
timeOfDay := "04:00"
|
||||
backupConfigExceededSize := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodWeek,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
MaxBackupSizeMB: 200, // Exceeds limit of 100
|
||||
MaxBackupsTotalSizeMB: 800,
|
||||
}
|
||||
|
||||
respExceededSize := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+owner.Token,
|
||||
backupConfigExceededSize,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
assert.Contains(t, string(respExceededSize.Body), "max backup size exceeds plan limit")
|
||||
|
||||
// Test 2: Try to save backup config with exceeded total size limit
|
||||
backupConfigExceededTotal := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodWeek,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
MaxBackupSizeMB: 50,
|
||||
MaxBackupsTotalSizeMB: 2000, // Exceeds limit of 1000
|
||||
}
|
||||
|
||||
respExceededTotal := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+owner.Token,
|
||||
backupConfigExceededTotal,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
assert.Contains(t, string(respExceededTotal.Body), "max total backups size exceeds plan limit")
|
||||
|
||||
// Test 3: Try to save backup config with exceeded storage period limit
|
||||
backupConfigExceededPeriod := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodYear, // Exceeds limit of Month
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
MaxBackupSizeMB: 80,
|
||||
MaxBackupsTotalSizeMB: 800,
|
||||
}
|
||||
|
||||
respExceededPeriod := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+owner.Token,
|
||||
backupConfigExceededPeriod,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
assert.Contains(t, string(respExceededPeriod.Body), "storage period exceeds plan limit")
|
||||
|
||||
// Test 4: Save backup config within all limits - should succeed
|
||||
backupConfigValid := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodWeek, // Within Month limit
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
MaxBackupSizeMB: 80, // Within 100 limit
|
||||
MaxBackupsTotalSizeMB: 800, // Within 1000 limit
|
||||
}
|
||||
|
||||
var responseValid BackupConfig
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+owner.Token,
|
||||
backupConfigValid,
|
||||
http.StatusOK,
|
||||
&responseValid,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, responseValid.DatabaseID)
|
||||
assert.Equal(t, int64(80), responseValid.MaxBackupSizeMB)
|
||||
assert.Equal(t, int64(800), responseValid.MaxBackupsTotalSizeMB)
|
||||
assert.Equal(t, period.PeriodWeek, responseValid.StorePeriod)
|
||||
}
|
||||
|
||||
func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -340,6 +564,10 @@ func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
|
||||
)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
defer func() {
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
var testUserToken string
|
||||
if tt.isStorageOwner {
|
||||
testUserToken = storageOwner.Token
|
||||
@@ -372,10 +600,6 @@ func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
|
||||
)
|
||||
assert.Contains(t, string(testResp.Body), "error")
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -387,6 +611,11 @@ func Test_SaveBackupConfig_WithEncryptionNone_ConfigSaved(t *testing.T) {
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
timeOfDay := "04:00"
|
||||
request := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
@@ -426,6 +655,11 @@ func Test_SaveBackupConfig_WithEncryptionEncrypted_ConfigSaved(t *testing.T) {
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
timeOfDay := "04:00"
|
||||
request := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
@@ -536,6 +770,15 @@ func Test_TransferDatabase_PermissionsEnforced(t *testing.T) {
|
||||
|
||||
targetStorage := createTestStorage(targetWorkspace.ID)
|
||||
|
||||
defer func() {
|
||||
// Cleanup in correct order to avoid foreign key violations
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascade delete of backup_config
|
||||
storages.RemoveTestStorage(targetStorage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
|
||||
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
|
||||
}()
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
@@ -628,6 +871,12 @@ func Test_TransferDatabase_NonMemberInSourceWorkspace_CannotTransfer(t *testing.
|
||||
router,
|
||||
)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
|
||||
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
|
||||
}()
|
||||
|
||||
request := TransferDatabaseRequest{
|
||||
TargetWorkspaceID: targetWorkspace.ID,
|
||||
}
|
||||
@@ -668,6 +917,12 @@ func Test_TransferDatabase_NonMemberInTargetWorkspace_CannotTransfer(t *testing.
|
||||
router,
|
||||
)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
|
||||
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
|
||||
}()
|
||||
|
||||
request := TransferDatabaseRequest{
|
||||
TargetWorkspaceID: targetWorkspace.ID,
|
||||
}
|
||||
@@ -695,6 +950,13 @@ func Test_TransferDatabase_ToNewStorage_DatabaseTransferd(t *testing.T) {
|
||||
sourceStorage := createTestStorage(sourceWorkspace.ID)
|
||||
targetStorage := createTestStorage(targetWorkspace.ID)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(200 * time.Millisecond) // Wait for cascading deletes
|
||||
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
|
||||
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
|
||||
}()
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfigRequest := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
@@ -774,6 +1036,13 @@ func Test_TransferDatabase_WithExistingStorage_DatabaseAndStorageTransferd(t *te
|
||||
database := createTestDatabaseViaAPI("Test Database", sourceWorkspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(sourceWorkspace.ID)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(200 * time.Millisecond) // Wait for cascading deletes
|
||||
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
|
||||
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
|
||||
}()
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfigRequest := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
@@ -863,6 +1132,14 @@ func Test_TransferDatabase_StorageHasOtherDBs_CannotTransfer(t *testing.T) {
|
||||
)
|
||||
storage := createTestStorage(sourceWorkspace.ID)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database1)
|
||||
databases.RemoveTestDatabase(database2)
|
||||
time.Sleep(200 * time.Millisecond) // Wait for cascading deletes
|
||||
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
|
||||
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
|
||||
}()
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfigRequest1 := BackupConfig{
|
||||
DatabaseID: database1.ID,
|
||||
@@ -945,6 +1222,14 @@ func Test_TransferDatabase_WithNotifiers_NotifiersTransferred(t *testing.T) {
|
||||
targetStorage := createTestStorage(targetWorkspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(sourceWorkspace.ID)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
|
||||
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
|
||||
}()
|
||||
|
||||
database.Notifiers = []notifiers.Notifier{*notifier}
|
||||
var updatedDatabase databases.Database
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
@@ -1048,6 +1333,15 @@ func Test_TransferDatabase_NotifierHasOtherDBs_NotifierSkipped(t *testing.T) {
|
||||
targetStorage := createTestStorage(targetWorkspace.ID)
|
||||
sharedNotifier := notifiers.CreateTestNotifier(sourceWorkspace.ID)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database1)
|
||||
databases.RemoveTestDatabase(database2)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(sharedNotifier)
|
||||
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
|
||||
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
|
||||
}()
|
||||
|
||||
database1.Notifiers = []notifiers.Notifier{*sharedNotifier}
|
||||
test_utils.MakePostRequest(
|
||||
t,
|
||||
@@ -1160,6 +1454,16 @@ func Test_TransferDatabase_WithMultipleNotifiers_OnlyExclusiveOnesTransferred(t
|
||||
exclusiveNotifier := notifiers.CreateTestNotifier(sourceWorkspace.ID)
|
||||
sharedNotifier := notifiers.CreateTestNotifier(sourceWorkspace.ID)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database1)
|
||||
databases.RemoveTestDatabase(database2)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(exclusiveNotifier)
|
||||
notifiers.RemoveTestNotifier(sharedNotifier)
|
||||
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
|
||||
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
|
||||
}()
|
||||
|
||||
database1.Notifiers = []notifiers.Notifier{*exclusiveNotifier, *sharedNotifier}
|
||||
test_utils.MakePostRequest(
|
||||
t,
|
||||
@@ -1271,6 +1575,14 @@ func Test_TransferDatabase_WithTargetNotifiers_NotifiersAssigned(t *testing.T) {
|
||||
targetStorage := createTestStorage(targetWorkspace.ID)
|
||||
targetNotifier := notifiers.CreateTestNotifier(targetWorkspace.ID)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(targetNotifier)
|
||||
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
|
||||
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
|
||||
}()
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfigRequest := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
@@ -1342,6 +1654,15 @@ func Test_TransferDatabase_TargetNotifierFromDifferentWorkspace_ReturnsBadReques
|
||||
targetStorage := createTestStorage(targetWorkspace.ID)
|
||||
wrongNotifier := notifiers.CreateTestNotifier(otherWorkspace.ID)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(wrongNotifier)
|
||||
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
|
||||
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
|
||||
workspaces_testing.RemoveTestWorkspace(otherWorkspace, router)
|
||||
}()
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfigRequest := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
@@ -1399,6 +1720,14 @@ func Test_TransferDatabase_TargetStorageFromDifferentWorkspace_ReturnsBadRequest
|
||||
sourceStorage := createTestStorage(sourceWorkspace.ID)
|
||||
wrongStorage := createTestStorage(otherWorkspace.ID)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
|
||||
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
|
||||
workspaces_testing.RemoveTestWorkspace(otherWorkspace, router)
|
||||
}()
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfigRequest := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
@@ -1443,6 +1772,115 @@ func Test_TransferDatabase_TargetStorageFromDifferentWorkspace_ReturnsBadRequest
|
||||
assert.Contains(t, string(testResp.Body), "target storage does not belong to target workspace")
|
||||
}
|
||||
|
||||
func Test_SaveBackupConfig_WithSystemStorage_CanBeUsedByAnyDatabase(t *testing.T) {
|
||||
router := createTestRouterWithStorageForTransfer()
|
||||
|
||||
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
|
||||
workspaceA := workspaces_testing.CreateTestWorkspace("Workspace A", owner1, router)
|
||||
workspaceB := workspaces_testing.CreateTestWorkspace("Workspace B", owner2, router)
|
||||
|
||||
databaseA := createTestDatabaseViaAPI("Database A", workspaceA.ID, owner1.Token, router)
|
||||
|
||||
// Test 1: Regular storage from workspace B cannot be used by database in workspace A
|
||||
regularStorageB := createTestStorage(workspaceB.ID)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfigWithRegularStorage := BackupConfig{
|
||||
DatabaseID: databaseA.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodWeek,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
StorageID: ®ularStorageB.ID,
|
||||
Storage: regularStorageB,
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
}
|
||||
|
||||
respRegular := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+owner1.Token,
|
||||
backupConfigWithRegularStorage,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(respRegular.Body), "storage does not belong to the same workspace")
|
||||
|
||||
// Test 2: System storage from workspace B CAN be used by database in workspace A
|
||||
systemStorageB := &storages.Storage{
|
||||
WorkspaceID: workspaceB.ID,
|
||||
Type: storages.StorageTypeLocal,
|
||||
Name: "Test System Storage " + uuid.New().String(),
|
||||
IsSystem: true,
|
||||
LocalStorage: &local_storage.LocalStorage{},
|
||||
}
|
||||
|
||||
var savedSystemStorage storages.Storage
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/storages",
|
||||
"Bearer "+admin.Token,
|
||||
*systemStorageB,
|
||||
http.StatusOK,
|
||||
&savedSystemStorage,
|
||||
)
|
||||
|
||||
assert.True(t, savedSystemStorage.IsSystem)
|
||||
|
||||
backupConfigWithSystemStorage := BackupConfig{
|
||||
DatabaseID: databaseA.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodWeek,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
StorageID: &savedSystemStorage.ID,
|
||||
Storage: &savedSystemStorage,
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
}
|
||||
|
||||
var savedConfig BackupConfig
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+owner1.Token,
|
||||
backupConfigWithSystemStorage,
|
||||
http.StatusOK,
|
||||
&savedConfig,
|
||||
)
|
||||
|
||||
assert.Equal(t, databaseA.ID, savedConfig.DatabaseID)
|
||||
assert.NotNil(t, savedConfig.StorageID)
|
||||
assert.Equal(t, savedSystemStorage.ID, *savedConfig.StorageID)
|
||||
assert.True(t, savedConfig.IsBackupsEnabled)
|
||||
|
||||
// Cleanup: database first (cascades to backup_config), then storages, then workspaces
|
||||
databases.RemoveTestDatabase(databaseA)
|
||||
storages.RemoveTestStorage(regularStorageB.ID)
|
||||
storages.RemoveTestStorage(savedSystemStorage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspaceA, router)
|
||||
workspaces_testing.RemoveTestWorkspace(workspaceB, router)
|
||||
}
|
||||
|
||||
func createTestDatabaseViaAPI(
|
||||
name string,
|
||||
workspaceID uuid.UUID,
|
||||
@@ -1462,7 +1900,7 @@ func createTestDatabaseViaAPI(
|
||||
Type: databases.DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
package backups_config
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
plans "databasus-backend/internal/features/plan"
|
||||
"databasus-backend/internal/features/storages"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var backupConfigRepository = &BackupConfigRepository{}
|
||||
@@ -14,6 +19,7 @@ var backupConfigService = &BackupConfigService{
|
||||
storages.GetStorageService(),
|
||||
notifiers.GetNotifierService(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
plans.GetDatabasePlanService(),
|
||||
nil,
|
||||
}
|
||||
var backupConfigController = &BackupConfigController{
|
||||
@@ -28,6 +34,21 @@ func GetBackupConfigService() *BackupConfigService {
|
||||
return backupConfigService
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService)
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package backups_config
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
plans "databasus-backend/internal/features/plan"
|
||||
"databasus-backend/internal/features/storages"
|
||||
"databasus-backend/internal/util/period"
|
||||
"errors"
|
||||
@@ -31,6 +33,11 @@ type BackupConfig struct {
|
||||
MaxFailedTriesCount int `json:"maxFailedTriesCount" gorm:"column:max_failed_tries_count;type:int;not null"`
|
||||
|
||||
Encryption BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
|
||||
|
||||
// MaxBackupSizeMB limits individual backup size. 0 = unlimited.
|
||||
MaxBackupSizeMB int64 `json:"maxBackupSizeMb" gorm:"column:max_backup_size_mb;type:int;not null"`
|
||||
// MaxBackupsTotalSizeMB limits total size of all backups. 0 = unlimited.
|
||||
MaxBackupsTotalSizeMB int64 `json:"maxBackupsTotalSizeMb" gorm:"column:max_backups_total_size_mb;type:int;not null"`
|
||||
}
|
||||
|
||||
func (h *BackupConfig) TableName() string {
|
||||
@@ -70,7 +77,7 @@ func (b *BackupConfig) AfterFind(tx *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *BackupConfig) Validate() error {
|
||||
func (b *BackupConfig) Validate(plan *plans.DatabasePlan) error {
|
||||
// Backup interval is required either as ID or as object
|
||||
if b.BackupIntervalID == uuid.Nil && b.BackupInterval == nil {
|
||||
return errors.New("backup interval is required")
|
||||
@@ -89,20 +96,59 @@ func (b *BackupConfig) Validate() error {
|
||||
return errors.New("encryption must be NONE or ENCRYPTED")
|
||||
}
|
||||
|
||||
if config.GetEnv().IsCloud {
|
||||
if b.Encryption != BackupEncryptionEncrypted {
|
||||
return errors.New("encryption is mandatory for cloud storage")
|
||||
}
|
||||
}
|
||||
|
||||
if b.MaxBackupSizeMB < 0 {
|
||||
return errors.New("max backup size must be non-negative")
|
||||
}
|
||||
|
||||
if b.MaxBackupsTotalSizeMB < 0 {
|
||||
return errors.New("max backups total size must be non-negative")
|
||||
}
|
||||
|
||||
// Validate against plan limits
|
||||
// Check storage period limit
|
||||
if plan.MaxStoragePeriod != period.PeriodForever {
|
||||
if b.StorePeriod.CompareTo(plan.MaxStoragePeriod) > 0 {
|
||||
return errors.New("storage period exceeds plan limit")
|
||||
}
|
||||
}
|
||||
|
||||
// Check max backup size limit (0 in plan means unlimited)
|
||||
if plan.MaxBackupSizeMB > 0 {
|
||||
if b.MaxBackupSizeMB == 0 || b.MaxBackupSizeMB > plan.MaxBackupSizeMB {
|
||||
return errors.New("max backup size exceeds plan limit")
|
||||
}
|
||||
}
|
||||
|
||||
// Check max total backups size limit (0 in plan means unlimited)
|
||||
if plan.MaxBackupsTotalSizeMB > 0 {
|
||||
if b.MaxBackupsTotalSizeMB == 0 ||
|
||||
b.MaxBackupsTotalSizeMB > plan.MaxBackupsTotalSizeMB {
|
||||
return errors.New("max total backups size exceeds plan limit")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *BackupConfig) Copy(newDatabaseID uuid.UUID) *BackupConfig {
|
||||
return &BackupConfig{
|
||||
DatabaseID: newDatabaseID,
|
||||
IsBackupsEnabled: b.IsBackupsEnabled,
|
||||
StorePeriod: b.StorePeriod,
|
||||
BackupIntervalID: uuid.Nil,
|
||||
BackupInterval: b.BackupInterval.Copy(),
|
||||
StorageID: b.StorageID,
|
||||
SendNotificationsOn: b.SendNotificationsOn,
|
||||
IsRetryIfFailed: b.IsRetryIfFailed,
|
||||
MaxFailedTriesCount: b.MaxFailedTriesCount,
|
||||
Encryption: b.Encryption,
|
||||
DatabaseID: newDatabaseID,
|
||||
IsBackupsEnabled: b.IsBackupsEnabled,
|
||||
StorePeriod: b.StorePeriod,
|
||||
BackupIntervalID: uuid.Nil,
|
||||
BackupInterval: b.BackupInterval.Copy(),
|
||||
StorageID: b.StorageID,
|
||||
SendNotificationsOn: b.SendNotificationsOn,
|
||||
IsRetryIfFailed: b.IsRetryIfFailed,
|
||||
MaxFailedTriesCount: b.MaxFailedTriesCount,
|
||||
Encryption: b.Encryption,
|
||||
MaxBackupSizeMB: b.MaxBackupSizeMB,
|
||||
MaxBackupsTotalSizeMB: b.MaxBackupsTotalSizeMB,
|
||||
}
|
||||
}
|
||||
|
||||
391
backend/internal/features/backups/config/model_test.go
Normal file
391
backend/internal/features/backups/config/model_test.go
Normal file
@@ -0,0 +1,391 @@
|
||||
package backups_config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"databasus-backend/internal/features/intervals"
|
||||
plans "databasus-backend/internal/features/plan"
|
||||
"databasus-backend/internal/util/period"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_Validate_WhenStoragePeriodIsWeekAndPlanAllowsMonth_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.StorePeriod = period.PeriodWeek
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = period.PeriodMonth
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenStoragePeriodIsYearAndPlanAllowsMonth_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.StorePeriod = period.PeriodYear
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = period.PeriodMonth
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "storage period exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenStoragePeriodIsForeverAndPlanAllowsForever_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.StorePeriod = period.PeriodForever
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = period.PeriodForever
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenStoragePeriodIsForeverAndPlanAllowsYear_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.StorePeriod = period.PeriodForever
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = period.PeriodYear
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "storage period exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenStoragePeriodEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.StorePeriod = period.PeriodMonth
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = period.PeriodMonth
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenBackupSize100MBAndPlanAllows500MB_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupSizeMB = 100
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupSizeMB = 500
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenBackupSize500MBAndPlanAllows100MB_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupSizeMB = 500
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupSizeMB = 100
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "max backup size exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenBackupSizeIsUnlimitedAndPlanAllowsUnlimited_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupSizeMB = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupSizeMB = 0
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenBackupSizeIsUnlimitedAndPlanHas500MBLimit_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupSizeMB = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupSizeMB = 500
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "max backup size exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenBackupSizeEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupSizeMB = 500
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupSizeMB = 500
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenTotalSize1GBAndPlanAllows5GB_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupsTotalSizeMB = 1000
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupsTotalSizeMB = 5000
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenTotalSize5GBAndPlanAllows1GB_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupsTotalSizeMB = 5000
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupsTotalSizeMB = 1000
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "max total backups size exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenTotalSizeIsUnlimitedAndPlanAllowsUnlimited_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupsTotalSizeMB = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupsTotalSizeMB = 0
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenTotalSizeIsUnlimitedAndPlanHas1GBLimit_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupsTotalSizeMB = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupsTotalSizeMB = 1000
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "max total backups size exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenTotalSizeEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupsTotalSizeMB = 5000
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxBackupsTotalSizeMB = 5000
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenAllLimitsAreUnlimitedInPlan_AnyConfigurationPasses(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.StorePeriod = period.PeriodForever
|
||||
config.MaxBackupSizeMB = 0
|
||||
config.MaxBackupsTotalSizeMB = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenMultipleLimitsExceeded_ValidationFailsWithFirstError(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.StorePeriod = period.PeriodYear
|
||||
config.MaxBackupSizeMB = 500
|
||||
config.MaxBackupsTotalSizeMB = 5000
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = period.PeriodMonth
|
||||
plan.MaxBackupSizeMB = 100
|
||||
plan.MaxBackupsTotalSizeMB = 1000
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.Error(t, err)
|
||||
assert.EqualError(t, err, "storage period exceeds plan limit")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenConfigHasInvalidIntervalButPlanIsValid_ValidationFailsOnInterval(
|
||||
t *testing.T,
|
||||
) {
|
||||
config := createValidBackupConfig()
|
||||
config.BackupIntervalID = uuid.Nil
|
||||
config.BackupInterval = nil
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "backup interval is required")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenIntervalIsMissing_ValidationFailsRegardlessOfPlan(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.BackupIntervalID = uuid.Nil
|
||||
config.BackupInterval = nil
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "backup interval is required")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenRetryEnabledButMaxTriesIsZero_ValidationFailsRegardlessOfPlan(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.IsRetryIfFailed = true
|
||||
config.MaxFailedTriesCount = 0
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "max failed tries count must be greater than 0")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenEncryptionIsInvalid_ValidationFailsRegardlessOfPlan(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.Encryption = "INVALID"
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "encryption must be NONE or ENCRYPTED")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenStoragePeriodIsEmpty_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.StorePeriod = ""
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "store period is required")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenMaxBackupSizeIsNegative_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupSizeMB = -100
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "max backup size must be non-negative")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenMaxTotalSizeIsNegative_ValidationFails(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.MaxBackupsTotalSizeMB = -1000
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
|
||||
err := config.Validate(plan)
|
||||
assert.EqualError(t, err, "max backups total size must be non-negative")
|
||||
}
|
||||
|
||||
func Test_Validate_WhenPlanLimitsAreAtBoundary_ValidationWorks(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
configPeriod period.Period
|
||||
planPeriod period.Period
|
||||
configSize int64
|
||||
planSize int64
|
||||
configTotal int64
|
||||
planTotal int64
|
||||
shouldSucceed bool
|
||||
}{
|
||||
{
|
||||
name: "all values just under limit",
|
||||
configPeriod: period.PeriodWeek,
|
||||
planPeriod: period.PeriodMonth,
|
||||
configSize: 99,
|
||||
planSize: 100,
|
||||
configTotal: 999,
|
||||
planTotal: 1000,
|
||||
shouldSucceed: true,
|
||||
},
|
||||
{
|
||||
name: "all values equal to limit",
|
||||
configPeriod: period.PeriodMonth,
|
||||
planPeriod: period.PeriodMonth,
|
||||
configSize: 100,
|
||||
planSize: 100,
|
||||
configTotal: 1000,
|
||||
planTotal: 1000,
|
||||
shouldSucceed: true,
|
||||
},
|
||||
{
|
||||
name: "period just over limit",
|
||||
configPeriod: period.Period3Month,
|
||||
planPeriod: period.PeriodMonth,
|
||||
configSize: 100,
|
||||
planSize: 100,
|
||||
configTotal: 1000,
|
||||
planTotal: 1000,
|
||||
shouldSucceed: false,
|
||||
},
|
||||
{
|
||||
name: "size just over limit",
|
||||
configPeriod: period.PeriodMonth,
|
||||
planPeriod: period.PeriodMonth,
|
||||
configSize: 101,
|
||||
planSize: 100,
|
||||
configTotal: 1000,
|
||||
planTotal: 1000,
|
||||
shouldSucceed: false,
|
||||
},
|
||||
{
|
||||
name: "total size just over limit",
|
||||
configPeriod: period.PeriodMonth,
|
||||
planPeriod: period.PeriodMonth,
|
||||
configSize: 100,
|
||||
planSize: 100,
|
||||
configTotal: 1001,
|
||||
planTotal: 1000,
|
||||
shouldSucceed: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := createValidBackupConfig()
|
||||
config.StorePeriod = tt.configPeriod
|
||||
config.MaxBackupSizeMB = tt.configSize
|
||||
config.MaxBackupsTotalSizeMB = tt.configTotal
|
||||
|
||||
plan := createUnlimitedPlan()
|
||||
plan.MaxStoragePeriod = tt.planPeriod
|
||||
plan.MaxBackupSizeMB = tt.planSize
|
||||
plan.MaxBackupsTotalSizeMB = tt.planTotal
|
||||
|
||||
err := config.Validate(plan)
|
||||
if tt.shouldSucceed {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createValidBackupConfig() *BackupConfig {
|
||||
intervalID := uuid.New()
|
||||
return &BackupConfig{
|
||||
DatabaseID: uuid.New(),
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodMonth,
|
||||
BackupIntervalID: intervalID,
|
||||
BackupInterval: &intervals.Interval{ID: intervalID},
|
||||
SendNotificationsOn: []BackupNotificationType{},
|
||||
IsRetryIfFailed: false,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
MaxBackupSizeMB: 100,
|
||||
MaxBackupsTotalSizeMB: 1000,
|
||||
}
|
||||
}
|
||||
|
||||
func createUnlimitedPlan() *plans.DatabasePlan {
|
||||
return &plans.DatabasePlan{
|
||||
DatabaseID: uuid.New(),
|
||||
MaxBackupSizeMB: 0,
|
||||
MaxBackupsTotalSizeMB: 0,
|
||||
MaxStoragePeriod: period.PeriodForever,
|
||||
}
|
||||
}
|
||||
@@ -26,6 +26,12 @@ func Test_AttachNotifierFromSameWorkspace_SuccessfullyAttached(t *testing.T) {
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
database.Notifiers = []notifiers.Notifier{*notifier}
|
||||
|
||||
var response databases.Database
|
||||
@@ -55,6 +61,13 @@ func Test_AttachNotifierFromDifferentWorkspace_ReturnsForbidden(t *testing.T) {
|
||||
workspace2 := workspaces_testing.CreateTestWorkspace("Workspace 2", owner2, router)
|
||||
notifier := notifiers.CreateTestNotifier(workspace2.ID)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace1, router)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace2, router)
|
||||
}()
|
||||
|
||||
database.Notifiers = []notifiers.Notifier{*notifier}
|
||||
|
||||
testResp := test_utils.MakePostRequest(
|
||||
@@ -77,6 +90,12 @@ func Test_DeleteNotifierWithAttachedDatabases_CannotDelete(t *testing.T) {
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
database.Notifiers = []notifiers.Notifier{*notifier}
|
||||
|
||||
var response databases.Database
|
||||
@@ -114,6 +133,13 @@ func Test_TransferNotifierWithAttachedDatabase_CannotTransfer(t *testing.T) {
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
|
||||
}()
|
||||
|
||||
database.Notifiers = []notifiers.Notifier{*notifier}
|
||||
|
||||
var response databases.Database
|
||||
|
||||
@@ -6,10 +6,10 @@ import (
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
plans "databasus-backend/internal/features/plan"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_models "databasus-backend/internal/features/users/models"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"databasus-backend/internal/util/period"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -20,6 +20,7 @@ type BackupConfigService struct {
|
||||
storageService *storages.StorageService
|
||||
notifierService *notifiers.NotifierService
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
databasePlanService *plans.DatabasePlanService
|
||||
|
||||
dbStorageChangeListener BackupConfigStorageChangeListener
|
||||
}
|
||||
@@ -45,7 +46,12 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
|
||||
user *users_models.User,
|
||||
backupConfig *BackupConfig,
|
||||
) (*BackupConfig, error) {
|
||||
if err := backupConfig.Validate(); err != nil {
|
||||
plan, err := s.databasePlanService.GetDatabasePlan(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := backupConfig.Validate(plan); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -71,7 +77,7 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if storage.WorkspaceID != *database.WorkspaceID {
|
||||
if storage.WorkspaceID != *database.WorkspaceID && !storage.IsSystem {
|
||||
return nil, errors.New("storage does not belong to the same workspace as the database")
|
||||
}
|
||||
}
|
||||
@@ -82,7 +88,12 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
|
||||
func (s *BackupConfigService) SaveBackupConfig(
|
||||
backupConfig *BackupConfig,
|
||||
) (*BackupConfig, error) {
|
||||
if err := backupConfig.Validate(); err != nil {
|
||||
plan, err := s.databasePlanService.GetDatabasePlan(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := backupConfig.Validate(plan); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -120,6 +131,18 @@ func (s *BackupConfigService) GetBackupConfigByDbIdWithAuth(
|
||||
return s.GetBackupConfigByDbId(databaseID)
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) GetDatabasePlan(
|
||||
user *users_models.User,
|
||||
databaseID uuid.UUID,
|
||||
) (*plans.DatabasePlan, error) {
|
||||
_, err := s.databaseService.GetDatabase(user, databaseID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.databasePlanService.GetDatabasePlan(databaseID)
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) GetBackupConfigByDbId(
|
||||
databaseID uuid.UUID,
|
||||
) (*BackupConfig, error) {
|
||||
@@ -194,12 +217,19 @@ func (s *BackupConfigService) CreateDisabledBackupConfig(databaseID uuid.UUID) e
|
||||
func (s *BackupConfigService) initializeDefaultConfig(
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
plan, err := s.databasePlanService.GetDatabasePlan(databaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
timeOfDay := "04:00"
|
||||
|
||||
_, err := s.backupConfigRepository.Save(&BackupConfig{
|
||||
DatabaseID: databaseID,
|
||||
IsBackupsEnabled: false,
|
||||
StorePeriod: period.PeriodWeek,
|
||||
_, err = s.backupConfigRepository.Save(&BackupConfig{
|
||||
DatabaseID: databaseID,
|
||||
IsBackupsEnabled: false,
|
||||
StorePeriod: plan.MaxStoragePeriod,
|
||||
MaxBackupSizeMB: plan.MaxBackupSizeMB,
|
||||
MaxBackupsTotalSizeMB: plan.MaxBackupsTotalSizeMB,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
|
||||
@@ -27,6 +27,12 @@ func Test_AttachStorageFromSameWorkspace_SuccessfullyAttached(t *testing.T) {
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
timeOfDay := "04:00"
|
||||
request := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
@@ -72,6 +78,13 @@ func Test_AttachStorageFromDifferentWorkspace_ReturnsForbidden(t *testing.T) {
|
||||
workspace2 := workspaces_testing.CreateTestWorkspace("Workspace 2", owner2, router)
|
||||
storage := createTestStorage(workspace2.ID)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace1, router)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace2, router)
|
||||
}()
|
||||
|
||||
timeOfDay := "04:00"
|
||||
request := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
@@ -110,6 +123,12 @@ func Test_DeleteStorageWithAttachedDatabases_CannotDelete(t *testing.T) {
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
timeOfDay := "04:00"
|
||||
request := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
@@ -163,6 +182,13 @@ func Test_TransferStorageWithAttachedDatabase_CannotTransfer(t *testing.T) {
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
|
||||
}()
|
||||
|
||||
timeOfDay := "04:00"
|
||||
request := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
|
||||
@@ -25,80 +25,6 @@ import (
|
||||
"databasus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
func createTestRouter() *gin.Engine {
|
||||
router := workspaces_testing.CreateTestRouter(
|
||||
workspaces_controllers.GetWorkspaceController(),
|
||||
workspaces_controllers.GetMembershipController(),
|
||||
GetDatabaseController(),
|
||||
)
|
||||
return router
|
||||
}
|
||||
|
||||
func getTestPostgresConfig() *postgresql.PostgresqlDatabase {
|
||||
env := config.GetEnv()
|
||||
port, err := strconv.Atoi(env.TestPostgres16Port)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
|
||||
}
|
||||
|
||||
testDbName := "testdb"
|
||||
return &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func getTestMariadbConfig() *mariadb.MariadbDatabase {
|
||||
env := config.GetEnv()
|
||||
portStr := env.TestMariadb1011Port
|
||||
if portStr == "" {
|
||||
portStr = "33111"
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse TEST_MARIADB_1011_PORT: %v", err))
|
||||
}
|
||||
|
||||
testDbName := "testdb"
|
||||
return &mariadb.MariadbDatabase{
|
||||
Version: tools.MariadbVersion1011,
|
||||
Host: "localhost",
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
Database: &testDbName,
|
||||
}
|
||||
}
|
||||
|
||||
func getTestMongodbConfig() *mongodb.MongodbDatabase {
|
||||
env := config.GetEnv()
|
||||
portStr := env.TestMongodb70Port
|
||||
if portStr == "" {
|
||||
portStr = "27070"
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse TEST_MONGODB_70_PORT: %v", err))
|
||||
}
|
||||
|
||||
return &mongodb.MongodbDatabase{
|
||||
Version: tools.MongodbVersion7,
|
||||
Host: "localhost",
|
||||
Port: port,
|
||||
Username: "root",
|
||||
Password: "rootpassword",
|
||||
Database: "testdb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CreateDatabase_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -142,6 +68,7 @@ func Test_CreateDatabase_PermissionsEnforced(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
@@ -180,6 +107,7 @@ func Test_CreateDatabase_PermissionsEnforced(t *testing.T) {
|
||||
)
|
||||
|
||||
if tt.expectSuccess {
|
||||
defer RemoveTestDatabase(&response)
|
||||
assert.Equal(t, "Test Database", response.Name)
|
||||
assert.NotEqual(t, uuid.Nil, response.ID)
|
||||
} else {
|
||||
@@ -193,6 +121,7 @@ func Test_CreateDatabase_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testin
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
|
||||
@@ -258,8 +187,10 @@ func Test_UpdateDatabase_PermissionsEnforced(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
defer RemoveTestDatabase(database)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
@@ -305,8 +236,10 @@ func Test_UpdateDatabase_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testin
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
defer RemoveTestDatabase(database)
|
||||
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
database.Name = "Hacked Name"
|
||||
@@ -366,6 +299,7 @@ func Test_DeleteDatabase_PermissionsEnforced(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
@@ -396,6 +330,7 @@ func Test_DeleteDatabase_PermissionsEnforced(t *testing.T) {
|
||||
)
|
||||
|
||||
if !tt.expectSuccess {
|
||||
defer RemoveTestDatabase(database)
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
})
|
||||
@@ -439,8 +374,10 @@ func Test_GetDatabase_PermissionsEnforced(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
defer RemoveTestDatabase(database)
|
||||
|
||||
var testUser string
|
||||
if tt.isGlobalAdmin {
|
||||
@@ -517,9 +454,12 @@ func Test_GetDatabasesByWorkspace_PermissionsEnforced(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
createTestDatabaseViaAPI("Database 1", workspace.ID, owner.Token, router)
|
||||
createTestDatabaseViaAPI("Database 2", workspace.ID, owner.Token, router)
|
||||
db1 := createTestDatabaseViaAPI("Database 1", workspace.ID, owner.Token, router)
|
||||
defer RemoveTestDatabase(db1)
|
||||
db2 := createTestDatabaseViaAPI("Database 2", workspace.ID, owner.Token, router)
|
||||
defer RemoveTestDatabase(db2)
|
||||
|
||||
var testUser string
|
||||
if tt.isGlobalAdmin {
|
||||
@@ -561,10 +501,14 @@ func Test_GetDatabasesByWorkspace_WhenMultipleDatabasesExist_ReturnsCorrectCount
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
createTestDatabaseViaAPI("Database 1", workspace.ID, owner.Token, router)
|
||||
createTestDatabaseViaAPI("Database 2", workspace.ID, owner.Token, router)
|
||||
createTestDatabaseViaAPI("Database 3", workspace.ID, owner.Token, router)
|
||||
db1 := createTestDatabaseViaAPI("Database 1", workspace.ID, owner.Token, router)
|
||||
defer RemoveTestDatabase(db1)
|
||||
db2 := createTestDatabaseViaAPI("Database 2", workspace.ID, owner.Token, router)
|
||||
defer RemoveTestDatabase(db2)
|
||||
db3 := createTestDatabaseViaAPI("Database 3", workspace.ID, owner.Token, router)
|
||||
defer RemoveTestDatabase(db3)
|
||||
|
||||
var response []Database
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
@@ -583,14 +527,19 @@ func Test_GetDatabasesByWorkspace_EnsuresCrossWorkspaceIsolation(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace1 := workspaces_testing.CreateTestWorkspace("Workspace 1", owner1, router)
|
||||
defer workspaces_testing.RemoveTestWorkspace(workspace1, router)
|
||||
|
||||
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace2 := workspaces_testing.CreateTestWorkspace("Workspace 2", owner2, router)
|
||||
defer workspaces_testing.RemoveTestWorkspace(workspace2, router)
|
||||
|
||||
createTestDatabaseViaAPI("Workspace1 DB1", workspace1.ID, owner1.Token, router)
|
||||
createTestDatabaseViaAPI("Workspace1 DB2", workspace1.ID, owner1.Token, router)
|
||||
workspace1Db1 := createTestDatabaseViaAPI("Workspace1 DB1", workspace1.ID, owner1.Token, router)
|
||||
defer RemoveTestDatabase(workspace1Db1)
|
||||
workspace1Db2 := createTestDatabaseViaAPI("Workspace1 DB2", workspace1.ID, owner1.Token, router)
|
||||
defer RemoveTestDatabase(workspace1Db2)
|
||||
|
||||
createTestDatabaseViaAPI("Workspace2 DB1", workspace2.ID, owner2.Token, router)
|
||||
workspace2Db1 := createTestDatabaseViaAPI("Workspace2 DB1", workspace2.ID, owner2.Token, router)
|
||||
defer RemoveTestDatabase(workspace2Db1)
|
||||
|
||||
var workspace1Dbs []Database
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
@@ -667,8 +616,10 @@ func Test_CopyDatabase_PermissionsEnforced(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
defer RemoveTestDatabase(database)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
@@ -700,6 +651,7 @@ func Test_CopyDatabase_PermissionsEnforced(t *testing.T) {
|
||||
)
|
||||
|
||||
if tt.expectSuccess {
|
||||
defer RemoveTestDatabase(&response)
|
||||
assert.NotEqual(t, database.ID, response.ID)
|
||||
assert.Contains(t, response.Name, "(Copy)")
|
||||
} else {
|
||||
@@ -713,8 +665,10 @@ func Test_CopyDatabase_CopyStaysInSameWorkspace(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
defer RemoveTestDatabase(database)
|
||||
|
||||
var response Database
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
@@ -727,139 +681,14 @@ func Test_CopyDatabase_CopyStaysInSameWorkspace(t *testing.T) {
|
||||
&response,
|
||||
)
|
||||
|
||||
defer RemoveTestDatabase(&response)
|
||||
|
||||
assert.NotEqual(t, database.ID, response.ID)
|
||||
assert.Equal(t, "Test Database (Copy)", response.Name)
|
||||
assert.Equal(t, workspace.ID, *response.WorkspaceID)
|
||||
assert.Equal(t, database.Type, response.Type)
|
||||
}
|
||||
|
||||
func Test_TestConnection_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
isMember bool
|
||||
isGlobalAdmin bool
|
||||
expectAccessGranted bool
|
||||
expectedStatusCodeOnErr int
|
||||
}{
|
||||
{
|
||||
name: "workspace member can test connection",
|
||||
isMember: true,
|
||||
isGlobalAdmin: false,
|
||||
expectAccessGranted: true,
|
||||
expectedStatusCodeOnErr: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "non-member cannot test connection",
|
||||
isMember: false,
|
||||
isGlobalAdmin: false,
|
||||
expectAccessGranted: false,
|
||||
expectedStatusCodeOnErr: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "global admin can test connection",
|
||||
isMember: false,
|
||||
isGlobalAdmin: true,
|
||||
expectAccessGranted: true,
|
||||
expectedStatusCodeOnErr: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
var testUser string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
testUser = admin.Token
|
||||
} else if tt.isMember {
|
||||
testUser = owner.Token
|
||||
} else {
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
testUser = nonMember.Token
|
||||
}
|
||||
|
||||
w := workspaces_testing.MakeAPIRequest(
|
||||
router,
|
||||
"POST",
|
||||
"/api/v1/databases/"+database.ID.String()+"/test-connection",
|
||||
"Bearer "+testUser,
|
||||
nil,
|
||||
)
|
||||
|
||||
body := w.Body.String()
|
||||
|
||||
if tt.expectAccessGranted {
|
||||
assert.True(
|
||||
t,
|
||||
w.Code == http.StatusOK ||
|
||||
(w.Code == http.StatusBadRequest && strings.Contains(body, "connect")),
|
||||
"Expected 200 OK or 400 with connection error, got %d: %s",
|
||||
w.Code,
|
||||
body,
|
||||
)
|
||||
} else {
|
||||
assert.Equal(t, tt.expectedStatusCodeOnErr, w.Code)
|
||||
assert.Contains(t, body, "insufficient permissions")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createTestDatabaseViaAPI(
|
||||
name string,
|
||||
workspaceID uuid.UUID,
|
||||
token string,
|
||||
router *gin.Engine,
|
||||
) *Database {
|
||||
env := config.GetEnv()
|
||||
port, err := strconv.Atoi(env.TestPostgres16Port)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
|
||||
}
|
||||
|
||||
testDbName := "testdb"
|
||||
request := Database{
|
||||
Name: name,
|
||||
WorkspaceID: &workspaceID,
|
||||
Type: DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
},
|
||||
}
|
||||
|
||||
w := workspaces_testing.MakeAPIRequest(
|
||||
router,
|
||||
"POST",
|
||||
"/api/v1/databases/create",
|
||||
"Bearer "+token,
|
||||
request,
|
||||
)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
panic(
|
||||
fmt.Sprintf("Failed to create database. Status: %d, Body: %s", w.Code, w.Body.String()),
|
||||
)
|
||||
}
|
||||
|
||||
var database Database
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return &database
|
||||
}
|
||||
|
||||
func Test_CreateDatabase_PasswordIsEncryptedInDB(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
@@ -1141,3 +970,207 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_TestConnection_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
isMember bool
|
||||
isGlobalAdmin bool
|
||||
expectAccessGranted bool
|
||||
expectedStatusCodeOnErr int
|
||||
}{
|
||||
{
|
||||
name: "workspace member can test connection",
|
||||
isMember: true,
|
||||
isGlobalAdmin: false,
|
||||
expectAccessGranted: true,
|
||||
expectedStatusCodeOnErr: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "non-member cannot test connection",
|
||||
isMember: false,
|
||||
isGlobalAdmin: false,
|
||||
expectAccessGranted: false,
|
||||
expectedStatusCodeOnErr: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "global admin can test connection",
|
||||
isMember: false,
|
||||
isGlobalAdmin: true,
|
||||
expectAccessGranted: true,
|
||||
expectedStatusCodeOnErr: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
defer RemoveTestDatabase(database)
|
||||
|
||||
var testUser string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
testUser = admin.Token
|
||||
} else if tt.isMember {
|
||||
testUser = owner.Token
|
||||
} else {
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
testUser = nonMember.Token
|
||||
}
|
||||
|
||||
w := workspaces_testing.MakeAPIRequest(
|
||||
router,
|
||||
"POST",
|
||||
"/api/v1/databases/"+database.ID.String()+"/test-connection",
|
||||
"Bearer "+testUser,
|
||||
nil,
|
||||
)
|
||||
|
||||
body := w.Body.String()
|
||||
|
||||
if tt.expectAccessGranted {
|
||||
assert.True(
|
||||
t,
|
||||
w.Code == http.StatusOK ||
|
||||
(w.Code == http.StatusBadRequest && strings.Contains(body, "connect")),
|
||||
"Expected 200 OK or 400 with connection error, got %d: %s",
|
||||
w.Code,
|
||||
body,
|
||||
)
|
||||
} else {
|
||||
assert.Equal(t, tt.expectedStatusCodeOnErr, w.Code)
|
||||
assert.Contains(t, body, "insufficient permissions")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createTestDatabaseViaAPI(
|
||||
name string,
|
||||
workspaceID uuid.UUID,
|
||||
token string,
|
||||
router *gin.Engine,
|
||||
) *Database {
|
||||
env := config.GetEnv()
|
||||
port, err := strconv.Atoi(env.TestPostgres16Port)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
|
||||
}
|
||||
|
||||
testDbName := "testdb"
|
||||
request := Database{
|
||||
Name: name,
|
||||
WorkspaceID: &workspaceID,
|
||||
Type: DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
},
|
||||
}
|
||||
|
||||
w := workspaces_testing.MakeAPIRequest(
|
||||
router,
|
||||
"POST",
|
||||
"/api/v1/databases/create",
|
||||
"Bearer "+token,
|
||||
request,
|
||||
)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
panic(
|
||||
fmt.Sprintf("Failed to create database. Status: %d, Body: %s", w.Code, w.Body.String()),
|
||||
)
|
||||
}
|
||||
|
||||
var database Database
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return &database
|
||||
}
|
||||
|
||||
func createTestRouter() *gin.Engine {
|
||||
router := workspaces_testing.CreateTestRouter(
|
||||
workspaces_controllers.GetWorkspaceController(),
|
||||
workspaces_controllers.GetMembershipController(),
|
||||
GetDatabaseController(),
|
||||
)
|
||||
return router
|
||||
}
|
||||
|
||||
func getTestPostgresConfig() *postgresql.PostgresqlDatabase {
|
||||
env := config.GetEnv()
|
||||
port, err := strconv.Atoi(env.TestPostgres16Port)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
|
||||
}
|
||||
|
||||
testDbName := "testdb"
|
||||
return &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func getTestMariadbConfig() *mariadb.MariadbDatabase {
|
||||
env := config.GetEnv()
|
||||
portStr := env.TestMariadb1011Port
|
||||
if portStr == "" {
|
||||
portStr = "33111"
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse TEST_MARIADB_1011_PORT: %v", err))
|
||||
}
|
||||
|
||||
testDbName := "testdb"
|
||||
return &mariadb.MariadbDatabase{
|
||||
Version: tools.MariadbVersion1011,
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
Database: &testDbName,
|
||||
}
|
||||
}
|
||||
|
||||
func getTestMongodbConfig() *mongodb.MongodbDatabase {
|
||||
env := config.GetEnv()
|
||||
portStr := env.TestMongodb70Port
|
||||
if portStr == "" {
|
||||
portStr = "27070"
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse TEST_MONGODB_70_PORT: %v", err))
|
||||
}
|
||||
|
||||
return &mongodb.MongodbDatabase{
|
||||
Version: tools.MongodbVersion7,
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: &port,
|
||||
Username: "root",
|
||||
Password: "rootpassword",
|
||||
Database: "testdb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,13 +25,14 @@ type MariadbDatabase struct {
|
||||
|
||||
Version tools.MariadbVersion `json:"version" gorm:"type:text;not null"`
|
||||
|
||||
Host string `json:"host" gorm:"type:text;not null"`
|
||||
Port int `json:"port" gorm:"type:int;not null"`
|
||||
Username string `json:"username" gorm:"type:text;not null"`
|
||||
Password string `json:"password" gorm:"type:text;not null"`
|
||||
Database *string `json:"database" gorm:"type:text"`
|
||||
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
|
||||
Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
|
||||
Host string `json:"host" gorm:"type:text;not null"`
|
||||
Port int `json:"port" gorm:"type:int;not null"`
|
||||
Username string `json:"username" gorm:"type:text;not null"`
|
||||
Password string `json:"password" gorm:"type:text;not null"`
|
||||
Database *string `json:"database" gorm:"type:text"`
|
||||
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
|
||||
IsExcludeEvents bool `json:"isExcludeEvents" gorm:"type:boolean;default:false"`
|
||||
Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
|
||||
}
|
||||
|
||||
func (m *MariadbDatabase) TableName() string {
|
||||
@@ -124,6 +125,7 @@ func (m *MariadbDatabase) Update(incoming *MariadbDatabase) {
|
||||
m.Username = incoming.Username
|
||||
m.Database = incoming.Database
|
||||
m.IsHttps = incoming.IsHttps
|
||||
m.IsExcludeEvents = incoming.IsExcludeEvents
|
||||
m.Privileges = incoming.Privileges
|
||||
|
||||
if incoming.Password != "" {
|
||||
@@ -515,11 +517,17 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
|
||||
hasProcess := false
|
||||
hasAllPrivileges := false
|
||||
|
||||
escapedDB := strings.ReplaceAll(database, "_", "\\_")
|
||||
dbPattern := regexp.MustCompile(
|
||||
fmt.Sprintf("(?i)ON\\s+[`'\"]?(%s|\\*)[`'\"]?\\.\\*", regexp.QuoteMeta(escapedDB)),
|
||||
// Escape underscores to match MariaDB's grant output format
|
||||
// MariaDB escapes _ as \_ in SHOW GRANTS output
|
||||
// Pattern matches either literal _ or escaped \_
|
||||
escapedDbName := strings.ReplaceAll(regexp.QuoteMeta(database), "_", `(_|\\_)`)
|
||||
dbPatternStr := fmt.Sprintf(
|
||||
`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
|
||||
escapedDbName,
|
||||
)
|
||||
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\.\*`)
|
||||
dbPattern := regexp.MustCompile(dbPatternStr)
|
||||
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)
|
||||
allPrivilegesPattern := regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`)
|
||||
|
||||
for rows.Next() {
|
||||
var grant string
|
||||
@@ -527,23 +535,26 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
|
||||
return "", fmt.Errorf("failed to scan grant: %w", err)
|
||||
}
|
||||
|
||||
if regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`).MatchString(grant) {
|
||||
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
|
||||
hasAllPrivileges = true
|
||||
}
|
||||
isRelevantGrant := globalPattern.MatchString(grant) || dbPattern.MatchString(grant)
|
||||
|
||||
if allPrivilegesPattern.MatchString(grant) && isRelevantGrant {
|
||||
hasAllPrivileges = true
|
||||
}
|
||||
|
||||
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
|
||||
if isRelevantGrant {
|
||||
for _, priv := range backupPrivileges {
|
||||
if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
|
||||
privPattern := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(priv) + `\b`)
|
||||
if privPattern.MatchString(grant) {
|
||||
detectedPrivileges[priv] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if globalPattern.MatchString(grant) &&
|
||||
regexp.MustCompile(`(?i)\bPROCESS\b`).MatchString(grant) {
|
||||
hasProcess = true
|
||||
if globalPattern.MatchString(grant) {
|
||||
processPattern := regexp.MustCompile(`(?i)\bPROCESS\b`)
|
||||
if processPattern.MatchString(grant) {
|
||||
hasProcess = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -537,6 +537,272 @@ func Test_ReadOnlyUser_CannotDropOrAlterTables(t *testing.T) {
|
||||
dropUserSafe(container.DB, username)
|
||||
}
|
||||
|
||||
func Test_TestConnection_DatabaseSpecificPrivilegesWithGlobalProcess_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MariadbVersion
|
||||
port string
|
||||
}{
|
||||
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
|
||||
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
|
||||
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
|
||||
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
|
||||
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
|
||||
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
|
||||
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
|
||||
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
|
||||
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
|
||||
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
|
||||
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMariadbContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS privilege_test`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`CREATE TABLE privilege_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`INSERT INTO privilege_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
specificUsername := fmt.Sprintf("spec_%s", uuid.New().String()[:8])
|
||||
specificPassword := "specificpass123"
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
specificUsername,
|
||||
specificPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT SELECT, SHOW VIEW ON %s.* TO '%s'@'%%'",
|
||||
container.Database,
|
||||
specificUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT PROCESS ON *.* TO '%s'@'%%'",
|
||||
specificUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer dropUserSafe(container.DB, specificUsername)
|
||||
|
||||
mariadbModel := &MariadbDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: specificUsername,
|
||||
Password: specificPassword,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mariadbModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMariadbContainer(t, env.TestMariadb1011Port, tools.MariadbVersion1011)
|
||||
defer container.DB.Close()
|
||||
|
||||
underscoreDbName := "test_db_name"
|
||||
|
||||
_, err := container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
|
||||
}()
|
||||
|
||||
underscoreDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
container.Username, container.Password, container.Host, container.Port, underscoreDbName)
|
||||
underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
|
||||
assert.NoError(t, err)
|
||||
defer underscoreDB.Close()
|
||||
|
||||
_, err = underscoreDB.Exec(`
|
||||
CREATE TABLE underscore_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(`INSERT INTO underscore_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
underscoreUsername := fmt.Sprintf("under%s", uuid.New().String()[:8])
|
||||
underscorePassword := "underscorepass123"
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
underscoreUsername,
|
||||
underscorePassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"GRANT SELECT, SHOW VIEW ON `%s`.* TO '%s'@'%%'",
|
||||
underscoreDbName,
|
||||
underscoreUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer dropUserSafe(underscoreDB, underscoreUsername)
|
||||
|
||||
mariadbModel := &MariadbDatabase{
|
||||
Version: tools.MariadbVersion1011,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: underscoreUsername,
|
||||
Password: underscorePassword,
|
||||
Database: &underscoreDbName,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mariadbModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_TestConnection_DatabaseWithUnderscoresAndAllPrivileges_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MariadbVersion
|
||||
port string
|
||||
}{
|
||||
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
|
||||
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
|
||||
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
|
||||
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
|
||||
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
|
||||
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
|
||||
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
|
||||
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
|
||||
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
|
||||
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
|
||||
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMariadbContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
underscoreDbName := "test_all_db"
|
||||
|
||||
_, err := container.DB.Exec(
|
||||
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(
|
||||
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
|
||||
)
|
||||
}()
|
||||
|
||||
underscoreDSN := fmt.Sprintf(
|
||||
"%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
container.Username,
|
||||
container.Password,
|
||||
container.Host,
|
||||
container.Port,
|
||||
underscoreDbName,
|
||||
)
|
||||
underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
|
||||
assert.NoError(t, err)
|
||||
defer underscoreDB.Close()
|
||||
|
||||
_, err = underscoreDB.Exec(`
|
||||
CREATE TABLE all_priv_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(`INSERT INTO all_priv_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
allPrivUsername := fmt.Sprintf("allpriv%s", uuid.New().String()[:8])
|
||||
allPrivPassword := "allprivpass123"
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
allPrivUsername,
|
||||
allPrivPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"GRANT ALL PRIVILEGES ON `%s`.* TO '%s'@'%%'",
|
||||
underscoreDbName,
|
||||
allPrivUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer dropUserSafe(underscoreDB, allPrivUsername)
|
||||
|
||||
mariadbModel := &MariadbDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: allPrivUsername,
|
||||
Password: allPrivPassword,
|
||||
Database: &underscoreDbName,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mariadbModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, mariadbModel.Privileges)
|
||||
assert.Contains(t, mariadbModel.Privileges, "SELECT")
|
||||
assert.Contains(t, mariadbModel.Privileges, "SHOW VIEW")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type MariadbContainer struct {
|
||||
Host string
|
||||
Port int
|
||||
@@ -557,7 +823,7 @@ func connectToMariadbContainer(
|
||||
}
|
||||
|
||||
dbName := "testdb"
|
||||
host := "127.0.0.1"
|
||||
host := config.GetEnv().TestLocalhost
|
||||
username := "root"
|
||||
password := "rootpassword"
|
||||
|
||||
|
||||
@@ -26,12 +26,13 @@ type MongodbDatabase struct {
|
||||
Version tools.MongodbVersion `json:"version" gorm:"type:text;not null"`
|
||||
|
||||
Host string `json:"host" gorm:"type:text;not null"`
|
||||
Port int `json:"port" gorm:"type:int;not null"`
|
||||
Port *int `json:"port" gorm:"type:int"`
|
||||
Username string `json:"username" gorm:"type:text;not null"`
|
||||
Password string `json:"password" gorm:"type:text;not null"`
|
||||
Database string `json:"database" gorm:"type:text;not null"`
|
||||
AuthDatabase string `json:"authDatabase" gorm:"type:text;not null;default:'admin'"`
|
||||
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
|
||||
IsSrv bool `json:"isSrv" gorm:"column:is_srv;type:boolean;not null;default:false"`
|
||||
CpuCount int `json:"cpuCount" gorm:"column:cpu_count;type:int;not null;default:1"`
|
||||
}
|
||||
|
||||
@@ -43,9 +44,13 @@ func (m *MongodbDatabase) Validate() error {
|
||||
if m.Host == "" {
|
||||
return errors.New("host is required")
|
||||
}
|
||||
if m.Port == 0 {
|
||||
return errors.New("port is required")
|
||||
|
||||
if !m.IsSrv {
|
||||
if m.Port == nil || *m.Port == 0 {
|
||||
return errors.New("port is required for standard connections")
|
||||
}
|
||||
}
|
||||
|
||||
if m.Username == "" {
|
||||
return errors.New("username is required")
|
||||
}
|
||||
@@ -58,6 +63,7 @@ func (m *MongodbDatabase) Validate() error {
|
||||
if m.CpuCount <= 0 {
|
||||
return errors.New("cpu count must be greater than 0")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -125,6 +131,7 @@ func (m *MongodbDatabase) Update(incoming *MongodbDatabase) {
|
||||
m.Database = incoming.Database
|
||||
m.AuthDatabase = incoming.AuthDatabase
|
||||
m.IsHttps = incoming.IsHttps
|
||||
m.IsSrv = incoming.IsSrv
|
||||
m.CpuCount = incoming.CpuCount
|
||||
|
||||
if incoming.Password != "" {
|
||||
@@ -455,12 +462,29 @@ func (m *MongodbDatabase) buildConnectionURI(password string) string {
|
||||
tlsParams = "&tls=true&tlsInsecure=true"
|
||||
}
|
||||
|
||||
if m.IsSrv {
|
||||
return fmt.Sprintf(
|
||||
"mongodb+srv://%s:%s@%s/%s?authSource=%s&connectTimeoutMS=15000%s",
|
||||
url.QueryEscape(m.Username),
|
||||
url.QueryEscape(password),
|
||||
m.Host,
|
||||
m.Database,
|
||||
authDB,
|
||||
tlsParams,
|
||||
)
|
||||
}
|
||||
|
||||
port := 27017
|
||||
if m.Port != nil {
|
||||
port = *m.Port
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"mongodb://%s:%s@%s:%d/%s?authSource=%s&connectTimeoutMS=15000%s",
|
||||
url.QueryEscape(m.Username),
|
||||
url.QueryEscape(password),
|
||||
m.Host,
|
||||
m.Port,
|
||||
port,
|
||||
m.Database,
|
||||
authDB,
|
||||
tlsParams,
|
||||
@@ -479,12 +503,28 @@ func (m *MongodbDatabase) BuildMongodumpURI(password string) string {
|
||||
tlsParams = "&tls=true&tlsInsecure=true"
|
||||
}
|
||||
|
||||
if m.IsSrv {
|
||||
return fmt.Sprintf(
|
||||
"mongodb+srv://%s:%s@%s/?authSource=%s&connectTimeoutMS=15000%s",
|
||||
url.QueryEscape(m.Username),
|
||||
url.QueryEscape(password),
|
||||
m.Host,
|
||||
authDB,
|
||||
tlsParams,
|
||||
)
|
||||
}
|
||||
|
||||
port := 27017
|
||||
if m.Port != nil {
|
||||
port = *m.Port
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"mongodb://%s:%s@%s:%d/?authSource=%s&connectTimeoutMS=15000%s",
|
||||
url.QueryEscape(m.Username),
|
||||
url.QueryEscape(password),
|
||||
m.Host,
|
||||
m.Port,
|
||||
port,
|
||||
authDB,
|
||||
tlsParams,
|
||||
)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -63,15 +64,17 @@ func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
|
||||
|
||||
defer dropUserSafe(container.Client, limitedUsername, container.AuthDatabase)
|
||||
|
||||
port := container.Port
|
||||
mongodbModel := &MongodbDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Port: &port,
|
||||
Username: limitedUsername,
|
||||
Password: limitedPassword,
|
||||
Database: container.Database,
|
||||
AuthDatabase: container.AuthDatabase,
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
@@ -132,15 +135,17 @@ func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
|
||||
|
||||
defer dropUserSafe(container.Client, backupUsername, container.AuthDatabase)
|
||||
|
||||
port := container.Port
|
||||
mongodbModel := &MongodbDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Port: &port,
|
||||
Username: backupUsername,
|
||||
Password: backupPassword,
|
||||
Database: container.Database,
|
||||
AuthDatabase: container.AuthDatabase,
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
@@ -397,7 +402,7 @@ func connectToMongodbContainer(
|
||||
}
|
||||
|
||||
dbName := "testdb"
|
||||
host := "127.0.0.1"
|
||||
host := config.GetEnv().TestLocalhost
|
||||
username := "root"
|
||||
password := "rootpassword"
|
||||
authDatabase := "admin"
|
||||
@@ -406,11 +411,18 @@ func connectToMongodbContainer(
|
||||
assert.NoError(t, err)
|
||||
|
||||
uri := fmt.Sprintf(
|
||||
"mongodb://%s:%s@%s:%d/%s?authSource=%s",
|
||||
username, password, host, portInt, dbName, authDatabase,
|
||||
"mongodb://%s:%s@%s:%d/%s?authSource=%s&serverSelectionTimeoutMS=5000&connectTimeoutMS=5000",
|
||||
username,
|
||||
password,
|
||||
host,
|
||||
portInt,
|
||||
dbName,
|
||||
authDatabase,
|
||||
)
|
||||
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
clientOptions := options.Client().ApplyURI(uri)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
if err != nil {
|
||||
@@ -434,15 +446,17 @@ func connectToMongodbContainer(
|
||||
}
|
||||
|
||||
func createMongodbModel(container *MongodbContainer) *MongodbDatabase {
|
||||
port := container.Port
|
||||
return &MongodbDatabase{
|
||||
Version: container.Version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Port: &port,
|
||||
Username: container.Username,
|
||||
Password: container.Password,
|
||||
Database: container.Database,
|
||||
AuthDatabase: container.AuthDatabase,
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
}
|
||||
@@ -481,3 +495,157 @@ func assertWriteDenied(t *testing.T, err error) {
|
||||
strings.Contains(errStr, "permission denied"),
|
||||
"Expected authorization error, got: %v", err)
|
||||
}
|
||||
|
||||
func Test_BuildConnectionURI_WithSrvFormat_ReturnsCorrectUri(t *testing.T) {
|
||||
port := 27017
|
||||
model := &MongodbDatabase{
|
||||
Host: "cluster0.example.mongodb.net",
|
||||
Port: &port,
|
||||
Username: "testuser",
|
||||
Password: "testpass123",
|
||||
Database: "mydb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
IsSrv: true,
|
||||
}
|
||||
|
||||
uri := model.buildConnectionURI("testpass123")
|
||||
|
||||
assert.Contains(t, uri, "mongodb+srv://")
|
||||
assert.Contains(t, uri, "testuser")
|
||||
assert.Contains(t, uri, "testpass123")
|
||||
assert.Contains(t, uri, "cluster0.example.mongodb.net")
|
||||
assert.Contains(t, uri, "/mydb")
|
||||
assert.Contains(t, uri, "authSource=admin")
|
||||
assert.Contains(t, uri, "connectTimeoutMS=15000")
|
||||
assert.NotContains(t, uri, ":27017")
|
||||
}
|
||||
|
||||
func Test_BuildConnectionURI_WithStandardFormat_ReturnsCorrectUri(t *testing.T) {
|
||||
port := 27017
|
||||
model := &MongodbDatabase{
|
||||
Host: "localhost",
|
||||
Port: &port,
|
||||
Username: "testuser",
|
||||
Password: "testpass123",
|
||||
Database: "mydb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
}
|
||||
|
||||
uri := model.buildConnectionURI("testpass123")
|
||||
|
||||
assert.Contains(t, uri, "mongodb://")
|
||||
assert.Contains(t, uri, "testuser")
|
||||
assert.Contains(t, uri, "testpass123")
|
||||
assert.Contains(t, uri, "localhost:27017")
|
||||
assert.Contains(t, uri, "/mydb")
|
||||
assert.Contains(t, uri, "authSource=admin")
|
||||
assert.Contains(t, uri, "connectTimeoutMS=15000")
|
||||
assert.NotContains(t, uri, "mongodb+srv://")
|
||||
}
|
||||
|
||||
func Test_BuildConnectionURI_WithNullPort_UsesDefault(t *testing.T) {
|
||||
model := &MongodbDatabase{
|
||||
Host: "localhost",
|
||||
Port: nil,
|
||||
Username: "testuser",
|
||||
Password: "testpass123",
|
||||
Database: "mydb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
}
|
||||
|
||||
uri := model.buildConnectionURI("testpass123")
|
||||
|
||||
assert.Contains(t, uri, "localhost:27017")
|
||||
}
|
||||
|
||||
func Test_BuildMongodumpURI_WithSrvFormat_ReturnsCorrectUri(t *testing.T) {
|
||||
port := 27017
|
||||
model := &MongodbDatabase{
|
||||
Host: "cluster0.example.mongodb.net",
|
||||
Port: &port,
|
||||
Username: "testuser",
|
||||
Password: "testpass123",
|
||||
Database: "mydb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
IsSrv: true,
|
||||
}
|
||||
|
||||
uri := model.BuildMongodumpURI("testpass123")
|
||||
|
||||
assert.Contains(t, uri, "mongodb+srv://")
|
||||
assert.Contains(t, uri, "testuser")
|
||||
assert.Contains(t, uri, "testpass123")
|
||||
assert.Contains(t, uri, "cluster0.example.mongodb.net")
|
||||
assert.Contains(t, uri, "/?authSource=admin")
|
||||
assert.Contains(t, uri, "connectTimeoutMS=15000")
|
||||
assert.NotContains(t, uri, ":27017")
|
||||
assert.NotContains(t, uri, "/mydb")
|
||||
}
|
||||
|
||||
func Test_BuildMongodumpURI_WithStandardFormat_ReturnsCorrectUri(t *testing.T) {
|
||||
port := 27017
|
||||
model := &MongodbDatabase{
|
||||
Host: "localhost",
|
||||
Port: &port,
|
||||
Username: "testuser",
|
||||
Password: "testpass123",
|
||||
Database: "mydb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
}
|
||||
|
||||
uri := model.BuildMongodumpURI("testpass123")
|
||||
|
||||
assert.Contains(t, uri, "mongodb://")
|
||||
assert.Contains(t, uri, "testuser")
|
||||
assert.Contains(t, uri, "testpass123")
|
||||
assert.Contains(t, uri, "localhost:27017")
|
||||
assert.Contains(t, uri, "/?authSource=admin")
|
||||
assert.Contains(t, uri, "connectTimeoutMS=15000")
|
||||
assert.NotContains(t, uri, "mongodb+srv://")
|
||||
assert.NotContains(t, uri, "/mydb")
|
||||
}
|
||||
|
||||
func Test_Validate_SrvConnection_AllowsNullPort(t *testing.T) {
|
||||
model := &MongodbDatabase{
|
||||
Host: "cluster0.example.mongodb.net",
|
||||
Port: nil,
|
||||
Username: "testuser",
|
||||
Password: "testpass123",
|
||||
Database: "mydb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
IsSrv: true,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
err := model.Validate()
|
||||
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_StandardConnection_RequiresPort(t *testing.T) {
|
||||
model := &MongodbDatabase{
|
||||
Host: "localhost",
|
||||
Port: nil,
|
||||
Username: "testuser",
|
||||
Password: "testpass123",
|
||||
Database: "mydb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
err := model.Validate()
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "port is required for standard connections")
|
||||
}
|
||||
|
||||
@@ -400,6 +400,7 @@ func HasPrivilege(privileges, priv string) bool {
|
||||
|
||||
func (m *MysqlDatabase) buildDSN(password string, database string) string {
|
||||
tlsConfig := "false"
|
||||
allowCleartext := ""
|
||||
|
||||
if m.IsHttps {
|
||||
err := mysql.RegisterTLSConfig("mysql-skip-verify", &tls.Config{
|
||||
@@ -411,16 +412,18 @@ func (m *MysqlDatabase) buildDSN(password string, database string) string {
|
||||
}
|
||||
|
||||
tlsConfig = "mysql-skip-verify"
|
||||
allowCleartext = "&allowCleartextPasswords=1"
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"%s:%s@tcp(%s:%d)/%s?parseTime=true&timeout=15s&tls=%s&charset=utf8mb4",
|
||||
"%s:%s@tcp(%s:%d)/%s?parseTime=true&timeout=15s&tls=%s&charset=utf8mb4%s",
|
||||
m.Username,
|
||||
password,
|
||||
m.Host,
|
||||
m.Port,
|
||||
database,
|
||||
tlsConfig,
|
||||
allowCleartext,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -486,11 +489,17 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
|
||||
hasProcess := false
|
||||
hasAllPrivileges := false
|
||||
|
||||
escapedDB := strings.ReplaceAll(database, "_", "\\_")
|
||||
dbPattern := regexp.MustCompile(
|
||||
fmt.Sprintf("(?i)ON\\s+[`'\"]?(%s|\\*)[`'\"]?\\.\\*", regexp.QuoteMeta(escapedDB)),
|
||||
// Escape underscores to match MySQL's grant output format
|
||||
// MySQL escapes _ as \_ in SHOW GRANTS output
|
||||
// Pattern matches either literal _ or escaped \_
|
||||
escapedDbName := strings.ReplaceAll(regexp.QuoteMeta(database), "_", `(_|\\_)`)
|
||||
dbPatternStr := fmt.Sprintf(
|
||||
`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
|
||||
escapedDbName,
|
||||
)
|
||||
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\.\*`)
|
||||
dbPattern := regexp.MustCompile(dbPatternStr)
|
||||
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)
|
||||
allPrivilegesPattern := regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`)
|
||||
|
||||
for rows.Next() {
|
||||
var grant string
|
||||
@@ -498,23 +507,26 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
|
||||
return "", fmt.Errorf("failed to scan grant: %w", err)
|
||||
}
|
||||
|
||||
if regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`).MatchString(grant) {
|
||||
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
|
||||
hasAllPrivileges = true
|
||||
}
|
||||
isRelevantGrant := globalPattern.MatchString(grant) || dbPattern.MatchString(grant)
|
||||
|
||||
if allPrivilegesPattern.MatchString(grant) && isRelevantGrant {
|
||||
hasAllPrivileges = true
|
||||
}
|
||||
|
||||
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
|
||||
if isRelevantGrant {
|
||||
for _, priv := range backupPrivileges {
|
||||
if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
|
||||
privPattern := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(priv) + `\b`)
|
||||
if privPattern.MatchString(grant) {
|
||||
detectedPrivileges[priv] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if globalPattern.MatchString(grant) &&
|
||||
regexp.MustCompile(`(?i)\bPROCESS\b`).MatchString(grant) {
|
||||
hasProcess = true
|
||||
if globalPattern.MatchString(grant) {
|
||||
processPattern := regexp.MustCompile(`(?i)\bPROCESS\b`)
|
||||
if processPattern.MatchString(grant) {
|
||||
hasProcess = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -518,6 +518,268 @@ func Test_ReadOnlyUser_CannotDropOrAlterTables(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_TestConnection_DatabaseSpecificPrivilegesWithGlobalProcess_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MysqlVersion
|
||||
port string
|
||||
}{
|
||||
{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
|
||||
{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
|
||||
{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
|
||||
{"MySQL 9", tools.MysqlVersion9, env.TestMysql90Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMysqlContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS privilege_test`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`CREATE TABLE privilege_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`INSERT INTO privilege_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
specificUsername := fmt.Sprintf("specific_%s", uuid.New().String()[:8])
|
||||
specificPassword := "specificpass123"
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
specificUsername,
|
||||
specificPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT SELECT, SHOW VIEW ON %s.* TO '%s'@'%%'",
|
||||
container.Database,
|
||||
specificUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT PROCESS ON *.* TO '%s'@'%%'",
|
||||
specificUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(
|
||||
fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", specificUsername),
|
||||
)
|
||||
}()
|
||||
|
||||
mysqlModel := &MysqlDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: specificUsername,
|
||||
Password: specificPassword,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mysqlModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMysqlContainer(t, env.TestMysql80Port, tools.MysqlVersion80)
|
||||
defer container.DB.Close()
|
||||
|
||||
underscoreDbName := "test_db_name"
|
||||
|
||||
_, err := container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
|
||||
}()
|
||||
|
||||
underscoreDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
container.Username, container.Password, container.Host, container.Port, underscoreDbName)
|
||||
underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
|
||||
assert.NoError(t, err)
|
||||
defer underscoreDB.Close()
|
||||
|
||||
_, err = underscoreDB.Exec(`
|
||||
CREATE TABLE underscore_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(`INSERT INTO underscore_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
underscoreUsername := fmt.Sprintf("under_%s", uuid.New().String()[:8])
|
||||
underscorePassword := "underscorepass123"
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
underscoreUsername,
|
||||
underscorePassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"GRANT SELECT, SHOW VIEW ON `%s`.* TO '%s'@'%%'",
|
||||
underscoreDbName,
|
||||
underscoreUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = underscoreDB.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", underscoreUsername))
|
||||
}()
|
||||
|
||||
mysqlModel := &MysqlDatabase{
|
||||
Version: tools.MysqlVersion80,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: underscoreUsername,
|
||||
Password: underscorePassword,
|
||||
Database: &underscoreDbName,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mysqlModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_TestConnection_DatabaseWithUnderscoresAndAllPrivileges_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MysqlVersion
|
||||
port string
|
||||
}{
|
||||
{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
|
||||
{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
|
||||
{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
|
||||
{"MySQL 9", tools.MysqlVersion9, env.TestMysql90Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMysqlContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
underscoreDbName := "test_all_db"
|
||||
|
||||
_, err := container.DB.Exec(
|
||||
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(
|
||||
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
|
||||
)
|
||||
}()
|
||||
|
||||
underscoreDSN := fmt.Sprintf(
|
||||
"%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
container.Username,
|
||||
container.Password,
|
||||
container.Host,
|
||||
container.Port,
|
||||
underscoreDbName,
|
||||
)
|
||||
underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
|
||||
assert.NoError(t, err)
|
||||
defer underscoreDB.Close()
|
||||
|
||||
_, err = underscoreDB.Exec(`
|
||||
CREATE TABLE all_priv_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(`INSERT INTO all_priv_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
allPrivUsername := fmt.Sprintf("allpriv_%s", uuid.New().String()[:8])
|
||||
allPrivPassword := "allprivpass123"
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
allPrivUsername,
|
||||
allPrivPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"GRANT ALL PRIVILEGES ON `%s`.* TO '%s'@'%%'",
|
||||
underscoreDbName,
|
||||
allPrivUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(
|
||||
fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", allPrivUsername),
|
||||
)
|
||||
}()
|
||||
|
||||
mysqlModel := &MysqlDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: allPrivUsername,
|
||||
Password: allPrivPassword,
|
||||
Database: &underscoreDbName,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mysqlModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, mysqlModel.Privileges)
|
||||
assert.Contains(t, mysqlModel.Privileges, "SELECT")
|
||||
assert.Contains(t, mysqlModel.Privileges, "SHOW VIEW")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type MysqlContainer struct {
|
||||
Host string
|
||||
Port int
|
||||
@@ -538,7 +800,7 @@ func connectToMysqlContainer(
|
||||
}
|
||||
|
||||
dbName := "testdb"
|
||||
host := "127.0.0.1"
|
||||
host := config.GetEnv().TestLocalhost
|
||||
username := "root"
|
||||
password := "rootpassword"
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -85,6 +86,42 @@ func (p *PostgresqlDatabase) Validate() error {
|
||||
return errors.New("cpu count must be greater than 0")
|
||||
}
|
||||
|
||||
// Prevent Databasus from backing up itself
|
||||
// Databasus runs an internal PostgreSQL instance that should not be backed up through the UI
|
||||
// because it would expose internal metadata to non-system administrators.
|
||||
// To properly backup Databasus, see: https://databasus.com/faq#backup-databasus
|
||||
if p.Database != nil && *p.Database != "" {
|
||||
localhostHosts := []string{
|
||||
"localhost",
|
||||
"127.0.0.1",
|
||||
"172.17.0.1",
|
||||
"host.docker.internal",
|
||||
"::1", // IPv6 loopback (equivalent to 127.0.0.1)
|
||||
"::", // IPv6 all interfaces (equivalent to 0.0.0.0)
|
||||
"0.0.0.0", // IPv4 all interfaces
|
||||
}
|
||||
|
||||
isLocalhost := false
|
||||
|
||||
for _, host := range localhostHosts {
|
||||
if strings.EqualFold(p.Host, host) {
|
||||
isLocalhost = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Also check if the host is in the entire 127.0.0.0/8 loopback range
|
||||
if strings.HasPrefix(p.Host, "127.") {
|
||||
isLocalhost = true
|
||||
}
|
||||
|
||||
if isLocalhost && strings.EqualFold(*p.Database, "databasus") {
|
||||
return errors.New(
|
||||
"backing up Databasus internal database is not allowed. To backup Databasus itself, see https://databasus.com/faq#backup-databasus",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -358,10 +395,13 @@ func (p *PostgresqlDatabase) IsUserReadOnly(
|
||||
//
|
||||
// This method performs the following operations atomically in a single transaction:
|
||||
// 1. Creates a PostgreSQL user with a UUID-based password
|
||||
// 2. Grants CONNECT privilege on the database
|
||||
// 3. Grants USAGE on all non-system schemas
|
||||
// 4. Grants SELECT on all existing tables and sequences
|
||||
// 5. Sets default privileges for future tables and sequences
|
||||
// 2. Revokes CREATE privilege on public schema from PUBLIC role
|
||||
// 3. Grants CONNECT privilege on the database
|
||||
// 4. Discovers all user-created schemas
|
||||
// 5. Grants USAGE on all non-system schemas
|
||||
// 6. Grants SELECT on all existing tables and sequences
|
||||
// 7. Sets default privileges for future tables and sequences
|
||||
// 8. Verifies user creation before committing
|
||||
//
|
||||
// Security features:
|
||||
// - Username format: "databasus-{8-char-uuid}" for uniqueness
|
||||
@@ -451,33 +491,56 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
return "", "", fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
// Step 1.5: Revoke CREATE privilege from PUBLIC role on public schema
|
||||
// Step 2: Check if public schema exists and revoke CREATE privilege if it does
|
||||
// This is necessary because all PostgreSQL users inherit CREATE privilege on the
|
||||
// public schema through the PUBLIC role. This is a one-time operation that affects
|
||||
// the entire database, making it more secure by default.
|
||||
// Note: This only affects the public schema; other schemas are unaffected.
|
||||
_, err = tx.Exec(ctx, `REVOKE CREATE ON SCHEMA public FROM PUBLIC`)
|
||||
if err != nil {
|
||||
logger.Error("Failed to revoke CREATE on public from PUBLIC", "error", err)
|
||||
if !strings.Contains(err.Error(), "schema \"public\" does not exist") &&
|
||||
!strings.Contains(err.Error(), "permission denied") {
|
||||
return "", "", fmt.Errorf("failed to revoke CREATE from PUBLIC: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Now revoke from the specific user as well (belt and suspenders)
|
||||
_, err = tx.Exec(ctx, fmt.Sprintf(`REVOKE CREATE ON SCHEMA public FROM "%s"`, baseUsername))
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
"Failed to revoke CREATE on public schema from user",
|
||||
"error",
|
||||
err,
|
||||
"username",
|
||||
baseUsername,
|
||||
var publicSchemaExists bool
|
||||
err = tx.QueryRow(ctx, `
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM information_schema.schemata
|
||||
WHERE schema_name = 'public'
|
||||
)
|
||||
`).Scan(&publicSchemaExists)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to check if public schema exists: %w", err)
|
||||
}
|
||||
|
||||
// Step 2: Grant database connection privilege and revoke TEMP
|
||||
if publicSchemaExists {
|
||||
// Revoke CREATE from PUBLIC role (affects all users)
|
||||
_, err = tx.Exec(ctx, `REVOKE CREATE ON SCHEMA public FROM PUBLIC`)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "permission denied") {
|
||||
logger.Warn(
|
||||
"Failed to revoke CREATE on public from PUBLIC (permission denied)",
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
} else {
|
||||
return "", "", fmt.Errorf("failed to revoke CREATE from PUBLIC on existing public schema: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Now revoke from the specific user as well (belt and suspenders)
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(`REVOKE CREATE ON SCHEMA public FROM "%s"`, baseUsername),
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn(
|
||||
"Failed to revoke CREATE on public schema from user",
|
||||
"error",
|
||||
err,
|
||||
"username",
|
||||
baseUsername,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
logger.Info("Public schema does not exist, skipping CREATE privilege revocation")
|
||||
}
|
||||
|
||||
// Step 3: Grant database connection privilege and revoke TEMP
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(`GRANT CONNECT ON DATABASE "%s" TO "%s"`, *p.Database, baseUsername),
|
||||
@@ -501,12 +564,23 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
logger.Warn("Failed to revoke TEMP privilege", "error", err, "username", baseUsername)
|
||||
}
|
||||
|
||||
// Step 3: Discover all user-created schemas
|
||||
rows, err := tx.Query(ctx, `
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
|
||||
`)
|
||||
// Step 4: Discover schemas to grant privileges on
|
||||
// If IncludeSchemas is specified, only use those schemas; otherwise use all non-system schemas
|
||||
var rows pgx.Rows
|
||||
if len(p.IncludeSchemas) > 0 {
|
||||
rows, err = tx.Query(ctx, `
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
|
||||
AND schema_name = ANY($1::text[])
|
||||
`, p.IncludeSchemas)
|
||||
} else {
|
||||
rows, err = tx.Query(ctx, `
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
|
||||
`)
|
||||
}
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to get schemas: %w", err)
|
||||
}
|
||||
@@ -526,7 +600,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
return "", "", fmt.Errorf("error iterating schemas: %w", err)
|
||||
}
|
||||
|
||||
// Step 4: Grant USAGE on each schema and explicitly prevent CREATE
|
||||
// Step 5: Grant USAGE on each schema and explicitly prevent CREATE
|
||||
for _, schema := range schemas {
|
||||
// Revoke CREATE specifically (handles inheritance from PUBLIC role)
|
||||
_, err = tx.Exec(
|
||||
@@ -555,51 +629,198 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
}
|
||||
}
|
||||
|
||||
// Step 5: Grant SELECT on ALL existing tables and sequences
|
||||
grantSelectSQL := fmt.Sprintf(`
|
||||
DO $$
|
||||
DECLARE
|
||||
schema_rec RECORD;
|
||||
BEGIN
|
||||
FOR schema_rec IN
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
|
||||
LOOP
|
||||
EXECUTE format('GRANT SELECT ON ALL TABLES IN SCHEMA %%I TO "%s"', schema_rec.schema_name);
|
||||
EXECUTE format('GRANT SELECT ON ALL SEQUENCES IN SCHEMA %%I TO "%s"', schema_rec.schema_name);
|
||||
END LOOP;
|
||||
END $$;
|
||||
`, baseUsername, baseUsername)
|
||||
// Step 6: Grant SELECT on ALL existing tables and sequences
|
||||
// Use the already-filtered schemas list from Step 4
|
||||
for _, schema := range schemas {
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(
|
||||
`GRANT SELECT ON ALL TABLES IN SCHEMA "%s" TO "%s"`,
|
||||
schema,
|
||||
baseUsername,
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf(
|
||||
"failed to grant select on tables in schema %s: %w",
|
||||
schema,
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(ctx, grantSelectSQL)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to grant select on tables: %w", err)
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(
|
||||
`GRANT SELECT ON ALL SEQUENCES IN SCHEMA "%s" TO "%s"`,
|
||||
schema,
|
||||
baseUsername,
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf(
|
||||
"failed to grant select on sequences in schema %s: %w",
|
||||
schema,
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 6: Set default privileges for FUTURE tables and sequences
|
||||
defaultPrivilegesSQL := fmt.Sprintf(`
|
||||
DO $$
|
||||
DECLARE
|
||||
schema_rec RECORD;
|
||||
BEGIN
|
||||
FOR schema_rec IN
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
|
||||
LOOP
|
||||
EXECUTE format('ALTER DEFAULT PRIVILEGES IN SCHEMA %%I GRANT SELECT ON TABLES TO "%s"', schema_rec.schema_name);
|
||||
EXECUTE format('ALTER DEFAULT PRIVILEGES IN SCHEMA %%I GRANT SELECT ON SEQUENCES TO "%s"', schema_rec.schema_name);
|
||||
END LOOP;
|
||||
END $$;
|
||||
`, baseUsername, baseUsername)
|
||||
// Step 7: Set default privileges for FUTURE tables and sequences
|
||||
// First, set default privileges for objects created by the current user
|
||||
// Use the already-filtered schemas list from Step 4
|
||||
for _, schema := range schemas {
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(
|
||||
`ALTER DEFAULT PRIVILEGES IN SCHEMA "%s" GRANT SELECT ON TABLES TO "%s"`,
|
||||
schema,
|
||||
baseUsername,
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf(
|
||||
"failed to set default privileges for tables in schema %s: %w",
|
||||
schema,
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(ctx, defaultPrivilegesSQL)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to set default privileges: %w", err)
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(
|
||||
`ALTER DEFAULT PRIVILEGES IN SCHEMA "%s" GRANT SELECT ON SEQUENCES TO "%s"`,
|
||||
schema,
|
||||
baseUsername,
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf(
|
||||
"failed to set default privileges for sequences in schema %s: %w",
|
||||
schema,
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 7: Verify user creation before committing
|
||||
// Step 8: Discover all roles that own objects in each schema
|
||||
// This is needed because ALTER DEFAULT PRIVILEGES only applies to objects created by the current role.
|
||||
// To handle tables created by OTHER users (like the GitHub issue with partitioned tables),
|
||||
// we need to set "ALTER DEFAULT PRIVILEGES FOR ROLE <owner>" for each object owner.
|
||||
// Filter by IncludeSchemas if specified.
|
||||
type SchemaOwner struct {
|
||||
SchemaName string
|
||||
RoleName string
|
||||
}
|
||||
|
||||
var ownerRows pgx.Rows
|
||||
if len(p.IncludeSchemas) > 0 {
|
||||
ownerRows, err = tx.Query(ctx, `
|
||||
SELECT DISTINCT n.nspname as schema_name, pg_get_userbyid(c.relowner) as role_name
|
||||
FROM pg_class c
|
||||
JOIN pg_namespace n ON c.relnamespace = n.oid
|
||||
WHERE n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
|
||||
AND n.nspname = ANY($1::text[])
|
||||
AND c.relkind IN ('r', 'p', 'v', 'm', 'f')
|
||||
AND pg_get_userbyid(c.relowner) != current_user
|
||||
ORDER BY n.nspname, role_name
|
||||
`, p.IncludeSchemas)
|
||||
} else {
|
||||
ownerRows, err = tx.Query(ctx, `
|
||||
SELECT DISTINCT n.nspname as schema_name, pg_get_userbyid(c.relowner) as role_name
|
||||
FROM pg_class c
|
||||
JOIN pg_namespace n ON c.relnamespace = n.oid
|
||||
WHERE n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
|
||||
AND c.relkind IN ('r', 'p', 'v', 'm', 'f')
|
||||
AND pg_get_userbyid(c.relowner) != current_user
|
||||
ORDER BY n.nspname, role_name
|
||||
`)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// Log warning but continue - this is a best-effort enhancement
|
||||
logger.Warn("Failed to query object owners for default privileges", "error", err)
|
||||
} else {
|
||||
var schemaOwners []SchemaOwner
|
||||
for ownerRows.Next() {
|
||||
var so SchemaOwner
|
||||
if err := ownerRows.Scan(&so.SchemaName, &so.RoleName); err != nil {
|
||||
ownerRows.Close()
|
||||
logger.Warn("Failed to scan schema owner", "error", err)
|
||||
break
|
||||
}
|
||||
schemaOwners = append(schemaOwners, so)
|
||||
}
|
||||
ownerRows.Close()
|
||||
|
||||
if err := ownerRows.Err(); err != nil {
|
||||
logger.Warn("Error iterating schema owners", "error", err)
|
||||
}
|
||||
|
||||
// Step 9: Set default privileges FOR ROLE for each object owner
|
||||
// Note: This may fail for some roles due to permission issues (e.g., roles owned by other superusers)
|
||||
// We log warnings but continue - user creation should succeed even if some roles can't be configured
|
||||
for _, so := range schemaOwners {
|
||||
// Try to set default privileges for tables
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(
|
||||
`ALTER DEFAULT PRIVILEGES FOR ROLE "%s" IN SCHEMA "%s" GRANT SELECT ON TABLES TO "%s"`,
|
||||
so.RoleName,
|
||||
so.SchemaName,
|
||||
baseUsername,
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn(
|
||||
"Failed to set default privileges for role (tables)",
|
||||
"error",
|
||||
err,
|
||||
"role",
|
||||
so.RoleName,
|
||||
"schema",
|
||||
so.SchemaName,
|
||||
"readonly_user",
|
||||
baseUsername,
|
||||
)
|
||||
}
|
||||
|
||||
// Try to set default privileges for sequences
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(
|
||||
`ALTER DEFAULT PRIVILEGES FOR ROLE "%s" IN SCHEMA "%s" GRANT SELECT ON SEQUENCES TO "%s"`,
|
||||
so.RoleName,
|
||||
so.SchemaName,
|
||||
baseUsername,
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn(
|
||||
"Failed to set default privileges for role (sequences)",
|
||||
"error",
|
||||
err,
|
||||
"role",
|
||||
so.RoleName,
|
||||
"schema",
|
||||
so.SchemaName,
|
||||
"readonly_user",
|
||||
baseUsername,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if len(schemaOwners) > 0 {
|
||||
logger.Info(
|
||||
"Set default privileges for existing object owners",
|
||||
"readonly_user",
|
||||
baseUsername,
|
||||
"owner_count",
|
||||
len(schemaOwners),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 10: Verify user creation before committing
|
||||
var verifyUsername string
|
||||
err = tx.QueryRow(ctx, fmt.Sprintf(`SELECT rolname FROM pg_roles WHERE rolname = '%s'`, baseUsername)).
|
||||
Scan(&verifyUsername)
|
||||
@@ -671,7 +892,7 @@ func testSingleDatabaseConnection(
|
||||
postgresDb.Version = detectedVersion
|
||||
|
||||
// Verify user has sufficient permissions for backup operations
|
||||
if err := checkBackupPermissions(ctx, conn, *postgresDb.Database); err != nil {
|
||||
if err := checkBackupPermissions(ctx, conn, postgresDb.IncludeSchemas); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -707,7 +928,12 @@ func detectDatabaseVersion(ctx context.Context, conn *pgx.Conn) (tools.Postgresq
|
||||
|
||||
// checkBackupPermissions verifies the user has sufficient privileges for pg_dump backup.
|
||||
// Required privileges: CONNECT on database, USAGE on schemas, SELECT on tables.
|
||||
func checkBackupPermissions(ctx context.Context, conn *pgx.Conn, dbName string) error {
|
||||
// If includeSchemas is specified, only checks permissions on those schemas.
|
||||
func checkBackupPermissions(
|
||||
ctx context.Context,
|
||||
conn *pgx.Conn,
|
||||
includeSchemas []string,
|
||||
) error {
|
||||
var missingPrivileges []string
|
||||
|
||||
// Check CONNECT privilege on database
|
||||
@@ -723,14 +949,29 @@ func checkBackupPermissions(ctx context.Context, conn *pgx.Conn, dbName string)
|
||||
|
||||
// Check USAGE privilege on at least one non-system schema
|
||||
var schemaCount int
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_namespace n
|
||||
WHERE has_schema_privilege(current_user, n.nspname, 'USAGE')
|
||||
AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
|
||||
AND n.nspname NOT LIKE 'pg_temp_%'
|
||||
AND n.nspname NOT LIKE 'pg_toast_temp_%'
|
||||
`).Scan(&schemaCount)
|
||||
if len(includeSchemas) > 0 {
|
||||
// Check only the specified schemas
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_namespace n
|
||||
WHERE has_schema_privilege(current_user, n.nspname, 'USAGE')
|
||||
AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
|
||||
AND n.nspname NOT LIKE 'pg_temp_%'
|
||||
AND n.nspname NOT LIKE 'pg_toast_temp_%'
|
||||
AND n.nspname = ANY($1::text[])
|
||||
`, includeSchemas).Scan(&schemaCount)
|
||||
} else {
|
||||
// Check all non-system schemas
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_namespace n
|
||||
WHERE has_schema_privilege(current_user, n.nspname, 'USAGE')
|
||||
AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
|
||||
AND n.nspname NOT LIKE 'pg_temp_%'
|
||||
AND n.nspname NOT LIKE 'pg_toast_temp_%'
|
||||
`).Scan(&schemaCount)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot check schema privileges: %w", err)
|
||||
}
|
||||
@@ -741,11 +982,28 @@ func checkBackupPermissions(ctx context.Context, conn *pgx.Conn, dbName string)
|
||||
// Check SELECT privilege on at least one table (if tables exist)
|
||||
// Use pg_tables from pg_catalog which shows all tables regardless of user privileges
|
||||
var tableCount int
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_catalog.pg_tables t
|
||||
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
|
||||
`).Scan(&tableCount)
|
||||
|
||||
if len(includeSchemas) > 0 {
|
||||
// Check only tables in the specified schemas
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_catalog.pg_tables t
|
||||
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
|
||||
AND t.schemaname NOT LIKE 'pg_temp_%'
|
||||
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
|
||||
AND t.schemaname = ANY($1::text[])
|
||||
`, includeSchemas).Scan(&tableCount)
|
||||
} else {
|
||||
// Check all tables in non-system schemas
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_catalog.pg_tables t
|
||||
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
|
||||
AND t.schemaname NOT LIKE 'pg_temp_%'
|
||||
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
|
||||
`).Scan(&tableCount)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot check table count: %w", err)
|
||||
}
|
||||
@@ -753,14 +1011,40 @@ func checkBackupPermissions(ctx context.Context, conn *pgx.Conn, dbName string)
|
||||
if tableCount > 0 {
|
||||
// Check if user has SELECT on at least one of these tables
|
||||
var selectableTableCount int
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_catalog.pg_tables t
|
||||
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
|
||||
AND has_table_privilege(current_user, quote_ident(t.schemaname) || '.' || quote_ident(t.tablename), 'SELECT')
|
||||
`).Scan(&selectableTableCount)
|
||||
|
||||
if len(includeSchemas) > 0 {
|
||||
// Check only tables in the specified schemas
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_catalog.pg_tables t
|
||||
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
|
||||
AND t.schemaname NOT LIKE 'pg_temp_%'
|
||||
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
|
||||
AND t.schemaname = ANY($1::text[])
|
||||
AND has_table_privilege(current_user, quote_ident(t.schemaname) || '.' || quote_ident(t.tablename), 'SELECT')
|
||||
`, includeSchemas).Scan(&selectableTableCount)
|
||||
} else {
|
||||
// Check all tables in non-system schemas
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_catalog.pg_tables t
|
||||
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
|
||||
AND t.schemaname NOT LIKE 'pg_temp_%'
|
||||
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
|
||||
AND has_table_privilege(current_user, quote_ident(t.schemaname) || '.' || quote_ident(t.tablename), 'SELECT')
|
||||
`).Scan(&selectableTableCount)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot check SELECT privileges: %w", err)
|
||||
// If the user doesn't have USAGE on the schema, has_table_privilege will fail
|
||||
// with "permission denied for schema". This means they definitely don't have
|
||||
// SELECT privileges, so treat this as missing permissions rather than an error.
|
||||
var pgErr *pgconn.PgError
|
||||
if errors.As(err, &pgErr) && pgErr.Code == "42501" { // insufficient_privilege
|
||||
selectableTableCount = 0
|
||||
} else {
|
||||
return fmt.Errorf("cannot check SELECT privileges: %w", err)
|
||||
}
|
||||
}
|
||||
if selectableTableCount == 0 {
|
||||
missingPrivileges = append(missingPrivileges, "SELECT on tables")
|
||||
|
||||
@@ -599,6 +599,10 @@ func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
|
||||
if config.GetEnv().IsSkipExternalResourcesTests {
|
||||
t.Skip("Skipping Supabase test: IS_SKIP_EXTERNAL_RESOURCES_TESTS is true")
|
||||
}
|
||||
|
||||
env := config.GetEnv()
|
||||
|
||||
if env.TestSupabaseHost == "" {
|
||||
@@ -705,6 +709,607 @@ func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_WithPublicSchema_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
port string
|
||||
}{
|
||||
{"PostgreSQL 12", "12", env.TestPostgres12Port},
|
||||
{"PostgreSQL 13", "13", env.TestPostgres13Port},
|
||||
{"PostgreSQL 14", "14", env.TestPostgres14Port},
|
||||
{"PostgreSQL 15", "15", env.TestPostgres15Port},
|
||||
{"PostgreSQL 16", "16", env.TestPostgres16Port},
|
||||
{"PostgreSQL 17", "17", env.TestPostgres17Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToPostgresContainer(t, tc.port)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`
|
||||
DROP TABLE IF EXISTS public_schema_test CASCADE;
|
||||
CREATE TABLE public_schema_test (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO public_schema_test (data) VALUES ('test1'), ('test2');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, username)
|
||||
assert.NotEmpty(t, password)
|
||||
assert.True(t, strings.HasPrefix(username, "databasus-"))
|
||||
|
||||
readOnlyModel := &PostgresqlDatabase{
|
||||
Version: pgModel.Version,
|
||||
Host: pgModel.Host,
|
||||
Port: pgModel.Port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: pgModel.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
|
||||
ctx,
|
||||
logger,
|
||||
nil,
|
||||
uuid.New(),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, isReadOnly, "User should be read-only")
|
||||
assert.Empty(t, privileges, "Read-only user should have no write privileges")
|
||||
|
||||
readOnlyDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host,
|
||||
container.Port,
|
||||
username,
|
||||
password,
|
||||
container.Database,
|
||||
)
|
||||
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var count int
|
||||
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM public_schema_test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
_, err = readOnlyConn.Exec(
|
||||
"INSERT INTO public_schema_test (data) VALUES ('should-fail')",
|
||||
)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("CREATE TABLE public.hack_table (id INT)")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to drop owned objects: %v", err)
|
||||
}
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_WithoutPublicSchema_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
port string
|
||||
}{
|
||||
{"PostgreSQL 12", "12", env.TestPostgres12Port},
|
||||
{"PostgreSQL 13", "13", env.TestPostgres13Port},
|
||||
{"PostgreSQL 14", "14", env.TestPostgres14Port},
|
||||
{"PostgreSQL 15", "15", env.TestPostgres15Port},
|
||||
{"PostgreSQL 16", "16", env.TestPostgres16Port},
|
||||
{"PostgreSQL 17", "17", env.TestPostgres17Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToPostgresContainer(t, tc.port)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`
|
||||
DROP SCHEMA IF EXISTS public CASCADE;
|
||||
DROP SCHEMA IF EXISTS app_schema CASCADE;
|
||||
DROP SCHEMA IF EXISTS data_schema CASCADE;
|
||||
CREATE SCHEMA app_schema;
|
||||
CREATE SCHEMA data_schema;
|
||||
CREATE TABLE app_schema.users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
username TEXT NOT NULL
|
||||
);
|
||||
CREATE TABLE data_schema.records (
|
||||
id SERIAL PRIMARY KEY,
|
||||
info TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO app_schema.users (username) VALUES ('user1'), ('user2');
|
||||
INSERT INTO data_schema.records (info) VALUES ('record1'), ('record2');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err, "CreateReadOnlyUser should succeed without public schema")
|
||||
assert.NotEmpty(t, username)
|
||||
assert.NotEmpty(t, password)
|
||||
assert.True(t, strings.HasPrefix(username, "databasus-"))
|
||||
|
||||
readOnlyModel := &PostgresqlDatabase{
|
||||
Version: pgModel.Version,
|
||||
Host: pgModel.Host,
|
||||
Port: pgModel.Port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: pgModel.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
|
||||
ctx,
|
||||
logger,
|
||||
nil,
|
||||
uuid.New(),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, isReadOnly, "User should be read-only")
|
||||
assert.Empty(t, privileges, "Read-only user should have no write privileges")
|
||||
|
||||
readOnlyDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host,
|
||||
container.Port,
|
||||
username,
|
||||
password,
|
||||
container.Database,
|
||||
)
|
||||
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var userCount int
|
||||
err = readOnlyConn.Get(&userCount, "SELECT COUNT(*) FROM app_schema.users")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, userCount)
|
||||
|
||||
var recordCount int
|
||||
err = readOnlyConn.Get(&recordCount, "SELECT COUNT(*) FROM data_schema.records")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, recordCount)
|
||||
|
||||
_, err = readOnlyConn.Exec(
|
||||
"INSERT INTO app_schema.users (username) VALUES ('should-fail')",
|
||||
)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("CREATE TABLE app_schema.hack_table (id INT)")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("CREATE TABLE data_schema.hack_table (id INT)")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to drop owned objects: %v", err)
|
||||
}
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`
|
||||
DROP SCHEMA IF EXISTS app_schema CASCADE;
|
||||
DROP SCHEMA IF EXISTS data_schema CASCADE;
|
||||
CREATE SCHEMA IF NOT EXISTS public;
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_PublicSchemaExistsButNoPermissions_ReturnsError(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
port string
|
||||
}{
|
||||
{"PostgreSQL 12", "12", env.TestPostgres12Port},
|
||||
{"PostgreSQL 13", "13", env.TestPostgres13Port},
|
||||
{"PostgreSQL 14", "14", env.TestPostgres14Port},
|
||||
{"PostgreSQL 15", "15", env.TestPostgres15Port},
|
||||
{"PostgreSQL 16", "16", env.TestPostgres16Port},
|
||||
{"PostgreSQL 17", "17", env.TestPostgres17Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToPostgresContainer(t, tc.port)
|
||||
defer container.DB.Close()
|
||||
|
||||
limitedAdminUsername := fmt.Sprintf("limited_admin_%s", uuid.New().String()[:8])
|
||||
limitedAdminPassword := "limited_password_123"
|
||||
|
||||
_, err := container.DB.Exec(`
|
||||
CREATE SCHEMA IF NOT EXISTS public;
|
||||
DROP TABLE IF EXISTS public.permission_test_table CASCADE;
|
||||
CREATE TABLE public.permission_test_table (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO public.permission_test_table (data) VALUES ('test1');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`GRANT CREATE ON SCHEMA public TO PUBLIC`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN CREATEROLE`,
|
||||
limitedAdminUsername,
|
||||
limitedAdminPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
|
||||
container.Database,
|
||||
limitedAdminUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(
|
||||
fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, limitedAdminUsername),
|
||||
)
|
||||
_, _ = container.DB.Exec(
|
||||
fmt.Sprintf(`DROP USER IF EXISTS "%s"`, limitedAdminUsername),
|
||||
)
|
||||
_, _ = container.DB.Exec(`REVOKE CREATE ON SCHEMA public FROM PUBLIC`)
|
||||
}()
|
||||
|
||||
limitedAdminDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host,
|
||||
container.Port,
|
||||
limitedAdminUsername,
|
||||
limitedAdminPassword,
|
||||
container.Database,
|
||||
)
|
||||
limitedAdminConn, err := sqlx.Connect("postgres", limitedAdminDSN)
|
||||
assert.NoError(t, err)
|
||||
defer limitedAdminConn.Close()
|
||||
|
||||
pgModel := &PostgresqlDatabase{
|
||||
Version: tools.GetPostgresqlVersionEnum(tc.version),
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: limitedAdminUsername,
|
||||
Password: limitedAdminPassword,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.Error(
|
||||
t,
|
||||
err,
|
||||
"CreateReadOnlyUser should fail when admin lacks permissions to secure public schema",
|
||||
)
|
||||
if err != nil {
|
||||
errorMsg := err.Error()
|
||||
hasExpectedError := strings.Contains(
|
||||
errorMsg,
|
||||
"failed to revoke CREATE from PUBLIC on existing public schema",
|
||||
) ||
|
||||
strings.Contains(errorMsg, "permission denied for schema public") ||
|
||||
strings.Contains(errorMsg, "failed to grant")
|
||||
assert.True(
|
||||
t,
|
||||
hasExpectedError,
|
||||
"Error should indicate permission issues with public schema, got: %s",
|
||||
errorMsg,
|
||||
)
|
||||
}
|
||||
assert.Empty(t, username)
|
||||
assert.Empty(t, password)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Validate_WhenLocalhostAndDatabasus_ReturnsError(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
username string
|
||||
database string
|
||||
}{
|
||||
{
|
||||
name: "localhost with databasus db",
|
||||
host: "localhost",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1 with databasus db",
|
||||
host: "127.0.0.1",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
{
|
||||
name: "172.17.0.1 with databasus db",
|
||||
host: "172.17.0.1",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
{
|
||||
name: "host.docker.internal with databasus db",
|
||||
host: "host.docker.internal",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
{
|
||||
name: "LOCALHOST (uppercase) with DATABASUS db",
|
||||
host: "LOCALHOST",
|
||||
username: "POSTGRES",
|
||||
database: "DATABASUS",
|
||||
},
|
||||
{
|
||||
name: "LocalHost (mixed case) with DataBasus db",
|
||||
host: "LocalHost",
|
||||
username: "anyuser",
|
||||
database: "DataBasus",
|
||||
},
|
||||
{
|
||||
name: "localhost with databasus and any username",
|
||||
host: "localhost",
|
||||
username: "myuser",
|
||||
database: "databasus",
|
||||
},
|
||||
{
|
||||
name: "::1 (IPv6 loopback) with databasus db",
|
||||
host: "::1",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
{
|
||||
name: ":: (IPv6 all interfaces) with databasus db",
|
||||
host: "::",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
{
|
||||
name: "::1 (uppercase) with DATABASUS db",
|
||||
host: "::1",
|
||||
username: "POSTGRES",
|
||||
database: "DATABASUS",
|
||||
},
|
||||
{
|
||||
name: "0.0.0.0 (all IPv4 interfaces) with databasus db",
|
||||
host: "0.0.0.0",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
{
|
||||
name: "127.0.0.2 (loopback range) with databasus db",
|
||||
host: "127.0.0.2",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
{
|
||||
name: "127.255.255.255 (end of loopback range) with databasus db",
|
||||
host: "127.255.255.255",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
pgModel := &PostgresqlDatabase{
|
||||
Host: tc.host,
|
||||
Port: 5437,
|
||||
Username: tc.username,
|
||||
Password: "somepassword",
|
||||
Database: &tc.database,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
err := pgModel.Validate()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "backing up Databasus internal database is not allowed")
|
||||
assert.Contains(t, err.Error(), "https://databasus.com/faq#backup-databasus")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Validate_WhenNotLocalhostOrNotDatabasus_ValidatesSuccessfully(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
username string
|
||||
database string
|
||||
}{
|
||||
{
|
||||
name: "different host (remote server) with databasus db",
|
||||
host: "192.168.1.100",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
{
|
||||
name: "different database name on localhost",
|
||||
host: "localhost",
|
||||
username: "postgres",
|
||||
database: "myapp",
|
||||
},
|
||||
{
|
||||
name: "all different",
|
||||
host: "db.example.com",
|
||||
username: "appuser",
|
||||
database: "production",
|
||||
},
|
||||
{
|
||||
name: "localhost with postgres database",
|
||||
host: "localhost",
|
||||
username: "postgres",
|
||||
database: "postgres",
|
||||
},
|
||||
{
|
||||
name: "remote host with databasus db name (allowed)",
|
||||
host: "db.example.com",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
pgModel := &PostgresqlDatabase{
|
||||
Host: tc.host,
|
||||
Port: 5432,
|
||||
Username: tc.username,
|
||||
Password: "somepassword",
|
||||
Database: &tc.database,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
err := pgModel.Validate()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Validate_WhenDatabaseIsNil_ValidatesSuccessfully(t *testing.T) {
|
||||
pgModel := &PostgresqlDatabase{
|
||||
Host: "localhost",
|
||||
Port: 5437,
|
||||
Username: "postgres",
|
||||
Password: "somepassword",
|
||||
Database: nil,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
err := pgModel.Validate()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenDatabaseIsEmpty_ValidatesSuccessfully(t *testing.T) {
|
||||
emptyDb := ""
|
||||
pgModel := &PostgresqlDatabase{
|
||||
Host: "localhost",
|
||||
Port: 5437,
|
||||
Username: "postgres",
|
||||
Password: "somepassword",
|
||||
Database: &emptyDb,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
err := pgModel.Validate()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenRequiredFieldsMissing_ReturnsError(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
model *PostgresqlDatabase
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "missing host",
|
||||
model: &PostgresqlDatabase{
|
||||
Host: "",
|
||||
Port: 5432,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
CpuCount: 1,
|
||||
},
|
||||
expectedError: "host is required",
|
||||
},
|
||||
{
|
||||
name: "missing port",
|
||||
model: &PostgresqlDatabase{
|
||||
Host: "localhost",
|
||||
Port: 0,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
CpuCount: 1,
|
||||
},
|
||||
expectedError: "port is required",
|
||||
},
|
||||
{
|
||||
name: "missing username",
|
||||
model: &PostgresqlDatabase{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "",
|
||||
Password: "pass",
|
||||
CpuCount: 1,
|
||||
},
|
||||
expectedError: "username is required",
|
||||
},
|
||||
{
|
||||
name: "missing password",
|
||||
model: &PostgresqlDatabase{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "user",
|
||||
Password: "",
|
||||
CpuCount: 1,
|
||||
},
|
||||
expectedError: "password is required",
|
||||
},
|
||||
{
|
||||
name: "invalid cpu count",
|
||||
model: &PostgresqlDatabase{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
CpuCount: 0,
|
||||
},
|
||||
expectedError: "cpu count must be greater than 0",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.model.Validate()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tc.expectedError)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type PostgresContainer struct {
|
||||
Host string
|
||||
Port int
|
||||
@@ -714,11 +1319,351 @@ type PostgresContainer struct {
|
||||
DB *sqlx.DB
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_TablesCreatedByDifferentUser_ReadOnlyUserCanRead(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToPostgresContainer(t, env.TestPostgres16Port)
|
||||
defer container.DB.Close()
|
||||
|
||||
// Step 1: Create a second database user who will create tables
|
||||
userCreatorUsername := fmt.Sprintf("user_creator_%s", uuid.New().String()[:8])
|
||||
userCreatorPassword := "creator_password_123"
|
||||
|
||||
_, err := container.DB.Exec(fmt.Sprintf(
|
||||
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
|
||||
userCreatorUsername,
|
||||
userCreatorPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, userCreatorUsername))
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, userCreatorUsername))
|
||||
}()
|
||||
|
||||
// Step 2: Grant the user_creator privileges to connect and create tables
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
|
||||
container.Database,
|
||||
userCreatorUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT USAGE ON SCHEMA public TO "%s"`,
|
||||
userCreatorUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT CREATE ON SCHEMA public TO "%s"`,
|
||||
userCreatorUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Step 2b: Create an initial table by user_creator so they become an object owner
|
||||
// This is important because our fix discovers existing object owners
|
||||
userCreatorDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host,
|
||||
container.Port,
|
||||
userCreatorUsername,
|
||||
userCreatorPassword,
|
||||
container.Database,
|
||||
)
|
||||
userCreatorConn, err := sqlx.Connect("postgres", userCreatorDSN)
|
||||
assert.NoError(t, err)
|
||||
defer userCreatorConn.Close()
|
||||
|
||||
initialTableName := fmt.Sprintf(
|
||||
"public.initial_table_%s",
|
||||
strings.ReplaceAll(uuid.New().String()[:8], "-", ""),
|
||||
)
|
||||
_, err = userCreatorConn.Exec(fmt.Sprintf(`
|
||||
CREATE TABLE %s (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO %s (data) VALUES ('initial_data');
|
||||
`, initialTableName, initialTableName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP TABLE IF EXISTS %s CASCADE`, initialTableName))
|
||||
}()
|
||||
|
||||
// Step 3: NOW create read-only user via Databasus (as admin)
|
||||
// At this point, user_creator already owns objects, so ALTER DEFAULT PRIVILEGES FOR ROLE should apply
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
readonlyUsername, readonlyPassword, err := pgModel.CreateReadOnlyUser(
|
||||
ctx,
|
||||
logger,
|
||||
nil,
|
||||
uuid.New(),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, readonlyUsername)
|
||||
assert.NotEmpty(t, readonlyPassword)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, readonlyUsername))
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, readonlyUsername))
|
||||
}()
|
||||
|
||||
// Step 4: user_creator creates a NEW table AFTER the read-only user was created
|
||||
// This table should automatically grant SELECT to the read-only user via ALTER DEFAULT PRIVILEGES FOR ROLE
|
||||
tableName := fmt.Sprintf(
|
||||
"public.future_table_%s",
|
||||
strings.ReplaceAll(uuid.New().String()[:8], "-", ""),
|
||||
)
|
||||
_, err = userCreatorConn.Exec(fmt.Sprintf(`
|
||||
CREATE TABLE %s (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO %s (data) VALUES ('test_data_1'), ('test_data_2');
|
||||
`, tableName, tableName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP TABLE IF EXISTS %s CASCADE`, tableName))
|
||||
}()
|
||||
|
||||
// Step 5: Connect as read-only user and verify it can SELECT from the new table
|
||||
readonlyDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host,
|
||||
container.Port,
|
||||
readonlyUsername,
|
||||
readonlyPassword,
|
||||
container.Database,
|
||||
)
|
||||
readonlyConn, err := sqlx.Connect("postgres", readonlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readonlyConn.Close()
|
||||
|
||||
var count int
|
||||
err = readonlyConn.Get(&count, fmt.Sprintf("SELECT COUNT(*) FROM %s", tableName))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(
|
||||
t,
|
||||
2,
|
||||
count,
|
||||
"Read-only user should be able to SELECT from table created by different user",
|
||||
)
|
||||
|
||||
// Step 6: Verify read-only user cannot write to the table
|
||||
_, err = readonlyConn.Exec(
|
||||
fmt.Sprintf("INSERT INTO %s (data) VALUES ('should-fail')", tableName),
|
||||
)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
// Step 7: Verify pg_dump operations (LOCK TABLE) work
|
||||
// pg_dump needs to lock tables in ACCESS SHARE MODE for consistent backup
|
||||
tx, err := readonlyConn.Begin()
|
||||
assert.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
_, err = tx.Exec(fmt.Sprintf("LOCK TABLE %s IN ACCESS SHARE MODE", tableName))
|
||||
assert.NoError(t, err, "Read-only user should be able to LOCK TABLE (needed for pg_dump)")
|
||||
|
||||
err = tx.Commit()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_WithIncludeSchemas_OnlyGrantsAccessToSpecifiedSchemas(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToPostgresContainer(t, env.TestPostgres16Port)
|
||||
defer container.DB.Close()
|
||||
|
||||
// Step 1: Create multiple schemas and tables
|
||||
_, err := container.DB.Exec(`
|
||||
DROP SCHEMA IF EXISTS included_schema CASCADE;
|
||||
DROP SCHEMA IF EXISTS excluded_schema CASCADE;
|
||||
CREATE SCHEMA included_schema;
|
||||
CREATE SCHEMA excluded_schema;
|
||||
|
||||
CREATE TABLE public.public_table (id INT, data TEXT);
|
||||
INSERT INTO public.public_table VALUES (1, 'public_data');
|
||||
|
||||
CREATE TABLE included_schema.included_table (id INT, data TEXT);
|
||||
INSERT INTO included_schema.included_table VALUES (2, 'included_data');
|
||||
|
||||
CREATE TABLE excluded_schema.excluded_table (id INT, data TEXT);
|
||||
INSERT INTO excluded_schema.excluded_table VALUES (3, 'excluded_data');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(`DROP SCHEMA IF EXISTS included_schema CASCADE`)
|
||||
_, _ = container.DB.Exec(`DROP SCHEMA IF EXISTS excluded_schema CASCADE`)
|
||||
}()
|
||||
|
||||
// Step 2: Create a second user who owns tables in both included and excluded schemas
|
||||
userCreatorUsername := fmt.Sprintf("user_creator_%s", uuid.New().String()[:8])
|
||||
userCreatorPassword := "creator_password_123"
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
|
||||
userCreatorUsername,
|
||||
userCreatorPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, userCreatorUsername))
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, userCreatorUsername))
|
||||
}()
|
||||
|
||||
// Grant privileges to user_creator
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
|
||||
container.Database,
|
||||
userCreatorUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
for _, schema := range []string{"public", "included_schema", "excluded_schema"} {
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT USAGE, CREATE ON SCHEMA %s TO "%s"`,
|
||||
schema,
|
||||
userCreatorUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// User_creator creates tables in included and excluded schemas
|
||||
userCreatorDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host,
|
||||
container.Port,
|
||||
userCreatorUsername,
|
||||
userCreatorPassword,
|
||||
container.Database,
|
||||
)
|
||||
userCreatorConn, err := sqlx.Connect("postgres", userCreatorDSN)
|
||||
assert.NoError(t, err)
|
||||
defer userCreatorConn.Close()
|
||||
|
||||
_, err = userCreatorConn.Exec(`
|
||||
CREATE TABLE included_schema.user_table (id INT, data TEXT);
|
||||
INSERT INTO included_schema.user_table VALUES (4, 'user_included_data');
|
||||
|
||||
CREATE TABLE excluded_schema.user_excluded_table (id INT, data TEXT);
|
||||
INSERT INTO excluded_schema.user_excluded_table VALUES (5, 'user_excluded_data');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Step 3: Create read-only user with IncludeSchemas = ["public", "included_schema"]
|
||||
pgModel := createPostgresModel(container)
|
||||
pgModel.IncludeSchemas = []string{"public", "included_schema"}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
readonlyUsername, readonlyPassword, err := pgModel.CreateReadOnlyUser(
|
||||
ctx,
|
||||
logger,
|
||||
nil,
|
||||
uuid.New(),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, readonlyUsername)
|
||||
assert.NotEmpty(t, readonlyPassword)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, readonlyUsername))
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, readonlyUsername))
|
||||
}()
|
||||
|
||||
// Step 4: Connect as read-only user
|
||||
readonlyDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host,
|
||||
container.Port,
|
||||
readonlyUsername,
|
||||
readonlyPassword,
|
||||
container.Database,
|
||||
)
|
||||
readonlyConn, err := sqlx.Connect("postgres", readonlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readonlyConn.Close()
|
||||
|
||||
// Step 5: Verify read-only user CAN access included schemas
|
||||
var publicData string
|
||||
err = readonlyConn.Get(&publicData, "SELECT data FROM public.public_table LIMIT 1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "public_data", publicData)
|
||||
|
||||
var includedData string
|
||||
err = readonlyConn.Get(&includedData, "SELECT data FROM included_schema.included_table LIMIT 1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "included_data", includedData)
|
||||
|
||||
var userIncludedData string
|
||||
err = readonlyConn.Get(&userIncludedData, "SELECT data FROM included_schema.user_table LIMIT 1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "user_included_data", userIncludedData)
|
||||
|
||||
// Step 6: Verify read-only user CANNOT access excluded schema
|
||||
var excludedData string
|
||||
err = readonlyConn.Get(&excludedData, "SELECT data FROM excluded_schema.excluded_table LIMIT 1")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
err = readonlyConn.Get(
|
||||
&excludedData,
|
||||
"SELECT data FROM excluded_schema.user_excluded_table LIMIT 1",
|
||||
)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
// Step 7: Verify future tables in included schemas are accessible
|
||||
_, err = userCreatorConn.Exec(`
|
||||
CREATE TABLE included_schema.future_table (id INT, data TEXT);
|
||||
INSERT INTO included_schema.future_table VALUES (6, 'future_data');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var futureData string
|
||||
err = readonlyConn.Get(&futureData, "SELECT data FROM included_schema.future_table LIMIT 1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(
|
||||
t,
|
||||
"future_data",
|
||||
futureData,
|
||||
"Read-only user should access future tables in included schemas via ALTER DEFAULT PRIVILEGES FOR ROLE",
|
||||
)
|
||||
|
||||
// Step 8: Verify future tables in excluded schema are NOT accessible
|
||||
_, err = userCreatorConn.Exec(`
|
||||
CREATE TABLE excluded_schema.future_excluded_table (id INT, data TEXT);
|
||||
INSERT INTO excluded_schema.future_excluded_table VALUES (7, 'future_excluded_data');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var futureExcludedData string
|
||||
err = readonlyConn.Get(
|
||||
&futureExcludedData,
|
||||
"SELECT data FROM excluded_schema.future_excluded_table LIMIT 1",
|
||||
)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(
|
||||
t,
|
||||
err.Error(),
|
||||
"permission denied",
|
||||
"Read-only user should NOT access tables in excluded schemas",
|
||||
)
|
||||
}
|
||||
|
||||
func connectToPostgresContainer(t *testing.T, port string) *PostgresContainer {
|
||||
dbName := "testdb"
|
||||
password := "testpassword"
|
||||
username := "testuser"
|
||||
host := "localhost"
|
||||
host := config.GetEnv().TestLocalhost
|
||||
|
||||
portInt, err := strconv.Atoi(port)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package databases
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
@@ -37,7 +40,22 @@ func GetDatabaseController() *DatabaseController {
|
||||
return databaseController
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService)
|
||||
notifiers.GetNotifierService().SetNotifierDatabaseCounter(databaseService)
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService)
|
||||
notifiers.GetNotifierService().SetNotifierDatabaseCounter(databaseService)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package databases
|
||||
|
||||
import (
|
||||
"context"
|
||||
"databasus-backend/internal/features/databases/databases/mariadb"
|
||||
"databasus-backend/internal/features/databases/databases/mongodb"
|
||||
"databasus-backend/internal/features/databases/databases/mysql"
|
||||
@@ -84,6 +85,25 @@ func (d *Database) TestConnection(
|
||||
return d.getSpecificDatabase().TestConnection(logger, encryptor, d.ID)
|
||||
}
|
||||
|
||||
func (d *Database) IsUserReadOnly(
|
||||
ctx context.Context,
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
) (bool, []string, error) {
|
||||
switch d.Type {
|
||||
case DatabaseTypePostgres:
|
||||
return d.Postgresql.IsUserReadOnly(ctx, logger, encryptor, d.ID)
|
||||
case DatabaseTypeMysql:
|
||||
return d.Mysql.IsUserReadOnly(ctx, logger, encryptor, d.ID)
|
||||
case DatabaseTypeMariadb:
|
||||
return d.Mariadb.IsUserReadOnly(ctx, logger, encryptor, d.ID)
|
||||
case DatabaseTypeMongodb:
|
||||
return d.Mongodb.IsUserReadOnly(ctx, logger, encryptor, d.ID)
|
||||
default:
|
||||
return false, nil, errors.New("read-only check not supported for this database type")
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Database) HideSensitiveData() {
|
||||
d.getSpecificDatabase().HideSensitiveData()
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/databases/databases/mariadb"
|
||||
"databasus-backend/internal/features/databases/databases/mongodb"
|
||||
@@ -86,6 +87,23 @@ func (s *DatabaseService) CreateDatabase(
|
||||
return nil, fmt.Errorf("failed to auto-detect database data: %w", err)
|
||||
}
|
||||
|
||||
if config.GetEnv().IsCloud {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
isReadOnly, permissions, err := database.IsUserReadOnly(ctx, s.logger, s.fieldEncryptor)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to verify user permissions: %w", err)
|
||||
}
|
||||
|
||||
if !isReadOnly {
|
||||
return nil, fmt.Errorf(
|
||||
"in cloud mode, only read-only database users are allowed (user has permissions: %v)",
|
||||
permissions,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if err := database.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt sensitive fields: %w", err)
|
||||
}
|
||||
@@ -153,6 +171,29 @@ func (s *DatabaseService) UpdateDatabase(
|
||||
return fmt.Errorf("failed to auto-detect database data: %w", err)
|
||||
}
|
||||
|
||||
if config.GetEnv().IsCloud {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
isReadOnly, permissions, err := existingDatabase.IsUserReadOnly(
|
||||
ctx,
|
||||
s.logger,
|
||||
s.fieldEncryptor,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to verify user permissions: %w", err)
|
||||
}
|
||||
|
||||
if !isReadOnly {
|
||||
return fmt.Errorf(
|
||||
"in cloud mode, only read-only database users are allowed (user has permissions: %v)",
|
||||
permissions,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
oldName := existingDatabase.Name
|
||||
|
||||
if err := existingDatabase.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
|
||||
return fmt.Errorf("failed to encrypt sensitive fields: %w", err)
|
||||
}
|
||||
@@ -162,11 +203,23 @@ func (s *DatabaseService) UpdateDatabase(
|
||||
return err
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Database updated: %s", existingDatabase.Name),
|
||||
&user.ID,
|
||||
existingDatabase.WorkspaceID,
|
||||
)
|
||||
if oldName != existingDatabase.Name {
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf(
|
||||
"Database updated and renamed from '%s' to '%s'",
|
||||
oldName,
|
||||
existingDatabase.Name,
|
||||
),
|
||||
&user.ID,
|
||||
existingDatabase.WorkspaceID,
|
||||
)
|
||||
} else {
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Database updated: %s", existingDatabase.Name),
|
||||
&user.ID,
|
||||
existingDatabase.WorkspaceID,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -532,9 +585,19 @@ func (s *DatabaseService) TransferDatabaseToWorkspace(
|
||||
return err
|
||||
}
|
||||
|
||||
sourceWorkspace, err := s.workspaceService.GetWorkspaceByID(*sourceWorkspaceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get source workspace: %w", err)
|
||||
}
|
||||
|
||||
targetWorkspace, err := s.workspaceService.GetWorkspaceByID(targetWorkspaceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get target workspace: %w", err)
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Database transferred: %s from workspace %s to workspace %s",
|
||||
database.Name, sourceWorkspaceID, targetWorkspaceID),
|
||||
fmt.Sprintf("Database transferred: %s from workspace '%s' to workspace '%s'",
|
||||
database.Name, sourceWorkspace.Name, targetWorkspace.Name),
|
||||
nil,
|
||||
&targetWorkspaceID,
|
||||
)
|
||||
@@ -649,38 +712,7 @@ func (s *DatabaseService) IsUserReadOnly(
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
switch usingDatabase.Type {
|
||||
case DatabaseTypePostgres:
|
||||
return usingDatabase.Postgresql.IsUserReadOnly(
|
||||
ctx,
|
||||
s.logger,
|
||||
s.fieldEncryptor,
|
||||
usingDatabase.ID,
|
||||
)
|
||||
case DatabaseTypeMysql:
|
||||
return usingDatabase.Mysql.IsUserReadOnly(
|
||||
ctx,
|
||||
s.logger,
|
||||
s.fieldEncryptor,
|
||||
usingDatabase.ID,
|
||||
)
|
||||
case DatabaseTypeMariadb:
|
||||
return usingDatabase.Mariadb.IsUserReadOnly(
|
||||
ctx,
|
||||
s.logger,
|
||||
s.fieldEncryptor,
|
||||
usingDatabase.ID,
|
||||
)
|
||||
case DatabaseTypeMongodb:
|
||||
return usingDatabase.Mongodb.IsUserReadOnly(
|
||||
ctx,
|
||||
s.logger,
|
||||
s.fieldEncryptor,
|
||||
usingDatabase.ID,
|
||||
)
|
||||
default:
|
||||
return false, nil, errors.New("read-only check not supported for this database type")
|
||||
}
|
||||
return usingDatabase.IsUserReadOnly(ctx, s.logger, s.fieldEncryptor)
|
||||
}
|
||||
|
||||
func (s *DatabaseService) CreateReadOnlyUser(
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
"databasus-backend/internal/storage"
|
||||
"databasus-backend/internal/util/tools"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -25,7 +26,7 @@ func GetTestPostgresConfig() *postgresql.PostgresqlDatabase {
|
||||
testDbName := "testdb"
|
||||
return &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
@@ -48,7 +49,7 @@ func GetTestMariadbConfig() *mariadb.MariadbDatabase {
|
||||
testDbName := "testdb"
|
||||
return &mariadb.MariadbDatabase{
|
||||
Version: tools.MariadbVersion1011,
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
@@ -69,13 +70,14 @@ func GetTestMongodbConfig() *mongodb.MongodbDatabase {
|
||||
|
||||
return &mongodb.MongodbDatabase{
|
||||
Version: tools.MongodbVersion7,
|
||||
Host: "localhost",
|
||||
Port: port,
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: &port,
|
||||
Username: "root",
|
||||
Password: "rootpassword",
|
||||
Database: "testdb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
}
|
||||
@@ -104,6 +106,19 @@ func CreateTestDatabase(
|
||||
}
|
||||
|
||||
func RemoveTestDatabase(database *Database) {
|
||||
// Delete backups and backup configs associated with this database
|
||||
// We hardcode SQL here because we cannot call backups feature due to DI inversion
|
||||
// (databases package cannot import backups package as backups already imports databases)
|
||||
db := storage.GetDb()
|
||||
|
||||
if err := db.Exec("DELETE FROM backups WHERE database_id = ?", database.ID).Error; err != nil {
|
||||
panic(fmt.Sprintf("failed to delete backups: %v", err))
|
||||
}
|
||||
|
||||
if err := db.Exec("DELETE FROM backup_configs WHERE database_id = ?", database.ID).Error; err != nil {
|
||||
panic(fmt.Sprintf("failed to delete backup config: %v", err))
|
||||
}
|
||||
|
||||
err := databaseRepository.Delete(database.ID)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
||||
@@ -12,6 +12,15 @@ import (
|
||||
type DiskService struct{}
|
||||
|
||||
func (s *DiskService) GetDiskUsage() (*DiskUsage, error) {
|
||||
if config.GetEnv().IsCloud {
|
||||
return &DiskUsage{
|
||||
Platform: PlatformLinux,
|
||||
TotalSpaceBytes: 100,
|
||||
UsedSpaceBytes: 0,
|
||||
FreeSpaceBytes: 100,
|
||||
}, nil
|
||||
}
|
||||
|
||||
platform := s.detectPlatform()
|
||||
|
||||
var path string
|
||||
|
||||
22
backend/internal/features/email/di.go
Normal file
22
backend/internal/features/email/di.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var env = config.GetEnv()
|
||||
var log = logger.GetLogger()
|
||||
|
||||
var emailSMTPSender = &EmailSMTPSender{
|
||||
log,
|
||||
env.SMTPHost,
|
||||
env.SMTPPort,
|
||||
env.SMTPUser,
|
||||
env.SMTPPassword,
|
||||
env.SMTPHost != "" && env.SMTPPort != 0,
|
||||
}
|
||||
|
||||
func GetEmailSMTPSender() *EmailSMTPSender {
|
||||
return emailSMTPSender
|
||||
}
|
||||
245
backend/internal/features/email/email.go
Normal file
245
backend/internal/features/email/email.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"mime"
|
||||
"net"
|
||||
"net/smtp"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
ImplicitTLSPort = 465
|
||||
DefaultTimeout = 5 * time.Second
|
||||
DefaultHelloName = "localhost"
|
||||
MIMETypeHTML = "text/html"
|
||||
MIMECharsetUTF8 = "UTF-8"
|
||||
)
|
||||
|
||||
type EmailSMTPSender struct {
|
||||
logger *slog.Logger
|
||||
smtpHost string
|
||||
smtpPort int
|
||||
smtpUser string
|
||||
smtpPassword string
|
||||
isConfigured bool
|
||||
}
|
||||
|
||||
func (s *EmailSMTPSender) SendEmail(to, subject, body string) error {
|
||||
if !s.isConfigured {
|
||||
s.logger.Warn("Skipping email send, SMTP not initialized", "to", to, "subject", subject)
|
||||
return nil
|
||||
}
|
||||
|
||||
from := s.smtpUser
|
||||
if from == "" {
|
||||
from = "noreply@" + s.smtpHost
|
||||
}
|
||||
|
||||
emailContent := s.buildEmailContent(to, subject, body, from)
|
||||
isAuthRequired := s.smtpUser != "" && s.smtpPassword != ""
|
||||
|
||||
if s.smtpPort == ImplicitTLSPort {
|
||||
return s.sendImplicitTLS(to, from, emailContent, isAuthRequired)
|
||||
}
|
||||
|
||||
return s.sendStartTLS(to, from, emailContent, isAuthRequired)
|
||||
}
|
||||
|
||||
func (s *EmailSMTPSender) buildEmailContent(to, subject, body, from string) []byte {
|
||||
// Encode Subject header using RFC 2047 to avoid SMTPUTF8 requirement
|
||||
encodedSubject := encodeRFC2047(subject)
|
||||
subjectHeader := fmt.Sprintf("Subject: %s\r\n", encodedSubject)
|
||||
dateHeader := fmt.Sprintf("Date: %s\r\n", time.Now().UTC().Format(time.RFC1123Z))
|
||||
|
||||
mimeHeaders := fmt.Sprintf(
|
||||
"MIME-version: 1.0;\nContent-Type: %s; charset=\"%s\";\n\n",
|
||||
MIMETypeHTML,
|
||||
MIMECharsetUTF8,
|
||||
)
|
||||
|
||||
// Encode From header display name if it contains non-ASCII
|
||||
encodedFrom := encodeRFC2047(from)
|
||||
fromHeader := fmt.Sprintf("From: %s\r\n", encodedFrom)
|
||||
|
||||
toHeader := fmt.Sprintf("To: %s\r\n", to)
|
||||
|
||||
return []byte(fromHeader + toHeader + subjectHeader + dateHeader + mimeHeaders + body)
|
||||
}
|
||||
|
||||
func (s *EmailSMTPSender) sendImplicitTLS(
|
||||
to, from string,
|
||||
emailContent []byte,
|
||||
isAuthRequired bool,
|
||||
) error {
|
||||
createClient := func() (*smtp.Client, func(), error) {
|
||||
return s.createImplicitTLSClient()
|
||||
}
|
||||
|
||||
client, cleanup, err := s.authenticateWithRetry(createClient, isAuthRequired)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
return s.sendEmail(client, to, from, emailContent)
|
||||
}
|
||||
|
||||
func (s *EmailSMTPSender) sendStartTLS(
|
||||
to, from string,
|
||||
emailContent []byte,
|
||||
isAuthRequired bool,
|
||||
) error {
|
||||
createClient := func() (*smtp.Client, func(), error) {
|
||||
return s.createStartTLSClient()
|
||||
}
|
||||
|
||||
client, cleanup, err := s.authenticateWithRetry(createClient, isAuthRequired)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
return s.sendEmail(client, to, from, emailContent)
|
||||
}
|
||||
|
||||
func (s *EmailSMTPSender) createImplicitTLSClient() (*smtp.Client, func(), error) {
|
||||
addr := net.JoinHostPort(s.smtpHost, fmt.Sprintf("%d", s.smtpPort))
|
||||
tlsConfig := &tls.Config{ServerName: s.smtpHost}
|
||||
dialer := &net.Dialer{Timeout: DefaultTimeout}
|
||||
|
||||
conn, err := tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
|
||||
}
|
||||
|
||||
client, err := smtp.NewClient(conn, s.smtpHost)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, nil, fmt.Errorf("failed to create SMTP client: %w", err)
|
||||
}
|
||||
|
||||
return client, func() { _ = client.Quit() }, nil
|
||||
}
|
||||
|
||||
func (s *EmailSMTPSender) createStartTLSClient() (*smtp.Client, func(), error) {
|
||||
addr := net.JoinHostPort(s.smtpHost, fmt.Sprintf("%d", s.smtpPort))
|
||||
dialer := &net.Dialer{Timeout: DefaultTimeout}
|
||||
|
||||
conn, err := dialer.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
|
||||
}
|
||||
|
||||
client, err := smtp.NewClient(conn, s.smtpHost)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, nil, fmt.Errorf("failed to create SMTP client: %w", err)
|
||||
}
|
||||
|
||||
if err := client.Hello(DefaultHelloName); err != nil {
|
||||
_ = client.Quit()
|
||||
_ = conn.Close()
|
||||
return nil, nil, fmt.Errorf("SMTP hello failed: %w", err)
|
||||
}
|
||||
|
||||
if ok, _ := client.Extension("STARTTLS"); ok {
|
||||
if err := client.StartTLS(&tls.Config{ServerName: s.smtpHost}); err != nil {
|
||||
_ = client.Quit()
|
||||
_ = conn.Close()
|
||||
return nil, nil, fmt.Errorf("STARTTLS failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return client, func() { _ = client.Quit() }, nil
|
||||
}
|
||||
|
||||
func (s *EmailSMTPSender) authenticateWithRetry(
|
||||
createClient func() (*smtp.Client, func(), error),
|
||||
isAuthRequired bool,
|
||||
) (*smtp.Client, func(), error) {
|
||||
client, cleanup, err := createClient()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if !isAuthRequired {
|
||||
return client, cleanup, nil
|
||||
}
|
||||
|
||||
// Try PLAIN auth first
|
||||
plainAuth := smtp.PlainAuth("", s.smtpUser, s.smtpPassword, s.smtpHost)
|
||||
if err := client.Auth(plainAuth); err == nil {
|
||||
return client, cleanup, nil
|
||||
}
|
||||
|
||||
// PLAIN auth failed, connection may be closed - recreate and try LOGIN auth
|
||||
cleanup()
|
||||
|
||||
client, cleanup, err = createClient()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
loginAuth := &loginAuth{username: s.smtpUser, password: s.smtpPassword}
|
||||
if err := client.Auth(loginAuth); err != nil {
|
||||
cleanup()
|
||||
return nil, nil, fmt.Errorf("SMTP authentication failed: %w", err)
|
||||
}
|
||||
|
||||
return client, cleanup, nil
|
||||
}
|
||||
|
||||
func (s *EmailSMTPSender) sendEmail(client *smtp.Client, to, from string, content []byte) error {
|
||||
if err := client.Mail(from); err != nil {
|
||||
return fmt.Errorf("failed to set sender: %w", err)
|
||||
}
|
||||
|
||||
if err := client.Rcpt(to); err != nil {
|
||||
return fmt.Errorf("failed to set recipient: %w", err)
|
||||
}
|
||||
|
||||
writer, err := client.Data()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get data writer: %w", err)
|
||||
}
|
||||
|
||||
if _, err = writer.Write(content); err != nil {
|
||||
return fmt.Errorf("failed to write email content: %w", err)
|
||||
}
|
||||
|
||||
if err = writer.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close data writer: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func encodeRFC2047(s string) string {
|
||||
return mime.QEncoding.Encode("UTF-8", s)
|
||||
}
|
||||
|
||||
type loginAuth struct {
|
||||
username string
|
||||
password string
|
||||
}
|
||||
|
||||
func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) {
|
||||
return "LOGIN", []byte{}, nil
|
||||
}
|
||||
|
||||
func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
|
||||
if more {
|
||||
switch string(fromServer) {
|
||||
case "Username:", "User Name\x00":
|
||||
return []byte(a.username), nil
|
||||
case "Password:", "Password\x00":
|
||||
return []byte(a.password), nil
|
||||
default:
|
||||
return []byte(a.username), nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
@@ -1,30 +1,48 @@
|
||||
package healthcheck_attempt
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
|
||||
)
|
||||
|
||||
type HealthcheckAttemptBackgroundService struct {
|
||||
healthcheckConfigService *healthcheck_config.HealthcheckConfigService
|
||||
checkDatabaseHealthUseCase *CheckDatabaseHealthUseCase
|
||||
logger *slog.Logger
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (s *HealthcheckAttemptBackgroundService) Run() {
|
||||
// first healthcheck immediately
|
||||
s.checkDatabases()
|
||||
func (s *HealthcheckAttemptBackgroundService) Run(ctx context.Context) {
|
||||
wasAlreadyRun := s.hasRun.Load()
|
||||
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
if config.IsShouldShutdown() {
|
||||
break
|
||||
}
|
||||
s.runOnce.Do(func() {
|
||||
s.hasRun.Store(true)
|
||||
|
||||
// first healthcheck immediately
|
||||
s.checkDatabases()
|
||||
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.checkDatabases()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -144,6 +144,10 @@ func Test_GetAttemptsByDatabase_PermissionsEnforced(t *testing.T) {
|
||||
)
|
||||
assert.Contains(t, string(testResp.Body), "forbidden")
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -181,6 +185,10 @@ func Test_GetAttemptsByDatabase_FiltersByAfterDate(t *testing.T) {
|
||||
for _, attempt := range response {
|
||||
assert.True(t, attempt.CreatedAt.After(afterDate) || attempt.CreatedAt.Equal(afterDate))
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_GetAttemptsByDatabase_ReturnsEmptyListForNewDatabase(t *testing.T) {
|
||||
@@ -201,6 +209,10 @@ func Test_GetAttemptsByDatabase_ReturnsEmptyListForNewDatabase(t *testing.T) {
|
||||
)
|
||||
|
||||
assert.Equal(t, 0, len(response))
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func createTestDatabaseViaAPI(
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package healthcheck_attempt
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"databasus-backend/internal/features/databases"
|
||||
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
@@ -22,9 +25,11 @@ var checkDatabaseHealthUseCase = &CheckDatabaseHealthUseCase{
|
||||
}
|
||||
|
||||
var healthcheckAttemptBackgroundService = &HealthcheckAttemptBackgroundService{
|
||||
healthcheck_config.GetHealthcheckConfigService(),
|
||||
checkDatabaseHealthUseCase,
|
||||
logger.GetLogger(),
|
||||
healthcheckConfigService: healthcheck_config.GetHealthcheckConfigService(),
|
||||
checkDatabaseHealthUseCase: checkDatabaseHealthUseCase,
|
||||
logger: logger.GetLogger(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
var healthcheckAttemptController = &HealthcheckAttemptController{
|
||||
healthcheckAttemptService,
|
||||
|
||||
@@ -130,6 +130,10 @@ func Test_SaveHealthcheckConfig_PermissionsEnforced(t *testing.T) {
|
||||
)
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -162,6 +166,10 @@ func Test_SaveHealthcheckConfig_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t
|
||||
)
|
||||
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_GetHealthcheckConfig_PermissionsEnforced(t *testing.T) {
|
||||
@@ -268,6 +276,10 @@ func Test_GetHealthcheckConfig_PermissionsEnforced(t *testing.T) {
|
||||
)
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -295,6 +307,10 @@ func Test_GetHealthcheckConfig_ReturnsDefaultConfigForNewDatabase(t *testing.T)
|
||||
assert.Equal(t, 1, response.IntervalMinutes)
|
||||
assert.Equal(t, 3, response.AttemptsBeforeConcideredAsDown)
|
||||
assert.Equal(t, 7, response.StoreAttemptsDays)
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func createTestDatabaseViaAPI(
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package healthcheck_config
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/databases"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
@@ -27,8 +30,23 @@ func GetHealthcheckConfigController() *HealthcheckConfigController {
|
||||
return healthcheckConfigController
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
databases.
|
||||
GetDatabaseService().
|
||||
AddDbCreationListener(healthcheckConfigService)
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
databases.
|
||||
GetDatabaseService().
|
||||
AddDbCreationListener(healthcheckConfigService)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -675,6 +675,10 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
WebhookNotifier: &webhook_notifier.WebhookNotifier{
|
||||
WebhookURL: "https://webhook.example.com/test",
|
||||
WebhookMethod: webhook_notifier.WebhookMethodPOST,
|
||||
Headers: []webhook_notifier.WebhookHeader{
|
||||
{Key: "Authorization", Value: "Bearer my-secret-token"},
|
||||
{Key: "X-Custom-Header", Value: "custom-value"},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
@@ -687,14 +691,40 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
WebhookNotifier: &webhook_notifier.WebhookNotifier{
|
||||
WebhookURL: "https://webhook.example.com/updated",
|
||||
WebhookMethod: webhook_notifier.WebhookMethodGET,
|
||||
Headers: []webhook_notifier.WebhookHeader{
|
||||
{Key: "Authorization", Value: "Bearer updated-token"},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
|
||||
// No sensitive data to verify for webhook
|
||||
assert.NotEmpty(
|
||||
t,
|
||||
notifier.WebhookNotifier.WebhookURL,
|
||||
"WebhookURL should be visible",
|
||||
)
|
||||
// Verify header values are encrypted in DB
|
||||
assert.True(
|
||||
t,
|
||||
isEncrypted(notifier.WebhookNotifier.Headers[0].Value),
|
||||
"Header value should be encrypted in DB",
|
||||
)
|
||||
decrypted := decryptField(
|
||||
t,
|
||||
notifier.ID,
|
||||
notifier.WebhookNotifier.Headers[0].Value,
|
||||
)
|
||||
assert.Equal(t, "Bearer updated-token", decrypted)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
|
||||
// No sensitive data to hide for webhook
|
||||
assert.NotEmpty(
|
||||
t,
|
||||
notifier.WebhookNotifier.WebhookURL,
|
||||
"WebhookURL should be visible",
|
||||
)
|
||||
for _, header := range notifier.WebhookNotifier.Headers {
|
||||
assert.Empty(t, header.Value, "Header value should be hidden")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -905,7 +935,7 @@ func Test_CreateNotifier_AllSensitiveFieldsEncryptedInDB(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Webhook Notifier - WebhookURL encrypted",
|
||||
name: "Webhook Notifier - Header values encrypted, URL not encrypted",
|
||||
createNotifier: func(workspaceID uuid.UUID) *Notifier {
|
||||
return &Notifier{
|
||||
WorkspaceID: workspaceID,
|
||||
@@ -914,17 +944,48 @@ func Test_CreateNotifier_AllSensitiveFieldsEncryptedInDB(t *testing.T) {
|
||||
WebhookNotifier: &webhook_notifier.WebhookNotifier{
|
||||
WebhookURL: "https://webhook.example.com/test456",
|
||||
WebhookMethod: webhook_notifier.WebhookMethodPOST,
|
||||
Headers: []webhook_notifier.WebhookHeader{
|
||||
{Key: "Authorization", Value: "Bearer secret-token-12345"},
|
||||
{Key: "X-API-Key", Value: "api-key-67890"},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
|
||||
assert.True(
|
||||
assert.False(
|
||||
t,
|
||||
isEncrypted(notifier.WebhookNotifier.WebhookURL),
|
||||
"WebhookURL should be encrypted",
|
||||
"WebhookURL should NOT be encrypted",
|
||||
)
|
||||
decrypted := decryptField(t, notifier.ID, notifier.WebhookNotifier.WebhookURL)
|
||||
assert.Equal(t, "https://webhook.example.com/test456", decrypted)
|
||||
assert.Equal(
|
||||
t,
|
||||
"https://webhook.example.com/test456",
|
||||
notifier.WebhookNotifier.WebhookURL,
|
||||
)
|
||||
|
||||
assert.True(
|
||||
t,
|
||||
isEncrypted(notifier.WebhookNotifier.Headers[0].Value),
|
||||
"Header value should be encrypted",
|
||||
)
|
||||
decrypted1 := decryptField(
|
||||
t,
|
||||
notifier.ID,
|
||||
notifier.WebhookNotifier.Headers[0].Value,
|
||||
)
|
||||
assert.Equal(t, "Bearer secret-token-12345", decrypted1)
|
||||
|
||||
assert.True(
|
||||
t,
|
||||
isEncrypted(notifier.WebhookNotifier.Headers[1].Value),
|
||||
"Header value should be encrypted",
|
||||
)
|
||||
decrypted2 := decryptField(
|
||||
t,
|
||||
notifier.ID,
|
||||
notifier.WebhookNotifier.Headers[1].Value,
|
||||
)
|
||||
assert.Equal(t, "api-key-67890", decrypted2)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package notifiers
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
@@ -32,6 +35,22 @@ func GetNotifierService() *NotifierService {
|
||||
func GetNotifierRepository() *NotifierRepository {
|
||||
return notifierRepository
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService)
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user