Compare commits

...

68 Commits

Author SHA1 Message Date
Rostislav Dugin
e798d82fc1 Merge pull request #325 from databasus/develop
FIX (storages): Fix default storage type prefill in playground
2026-02-01 20:12:12 +03:00
Rostislav Dugin
81a01585ee FIX (storages): Fix default storage type prefill in playground 2026-02-01 20:07:12 +03:00
Rostislav Dugin
a8465c1a10 Merge pull request #324 from databasus/develop
FIX (storages): Limit local storage usage in playground
2026-02-01 19:20:34 +03:00
Rostislav Dugin
a9e5db70f6 FIX (storages): Limit local storage usage in playground 2026-02-01 19:18:54 +03:00
Rostislav Dugin
7a47be6ca6 Merge pull request #323 from databasus/develop
Develop
2026-02-01 18:42:30 +03:00
Rostislav Dugin
16be3db0c6 FIX (playground): Pre-select system storage if exists in playground 2026-02-01 18:30:50 +03:00
Rostislav Dugin
744e51d1e1 REFACTOR (email): Refactor commit adding date headers to emails 2026-02-01 16:43:53 +03:00
Rostislav Dugin
b3af75d430 Merge branch 'develop' of https://github.com/databasus/databasus into develop 2026-02-01 16:41:52 +03:00
mcarbs
6f7320abeb FIX (email): Add email date header 2026-02-01 16:41:17 +03:00
Rostislav Dugin
a1655d35a6 FIX (healthcheck): Add cache accessibility to healthcheck 2026-01-30 16:33:39 +03:00
Rostislav Dugin
9b6e801184 Merge pull request #316 from databasus/develop
FEATURE (email): Add sending email about members invitation and passw…
2026-01-28 17:29:58 +03:00
Rostislav Dugin
105777ab6f FEATURE (email): Add sending email about members invitation and password reset 2026-01-28 17:28:36 +03:00
Rostislav Dugin
3a1a88d5cf Merge pull request #315 from databasus/develop
FIX (env): Fix env detection over startup
2026-01-28 11:33:06 +03:00
Rostislav Dugin
699ca16814 FIX (env): Fix env detection over startup 2026-01-28 11:32:19 +03:00
Rostislav Dugin
26f3cf233a Merge pull request #313 from databasus/develop
FIX (backups): Improve cascade deletion of backups on storage removal x3
2026-01-27 17:04:25 +03:00
Rostislav Dugin
3d8372e9f6 FIX (backups): Improve cascade deletion of backups on storage removal x3 2026-01-27 17:03:51 +03:00
Rostislav Dugin
b46f11804d Merge pull request #312 from databasus/develop
FIX (backups): Improve cascade deletion of backups on storage removal x2
2026-01-27 16:38:49 +03:00
Rostislav Dugin
4676361688 FIX (backups): Improve cascade deletion of backups on storage removal x2 2026-01-27 16:38:21 +03:00
Databasus
de3679cadf Merge pull request #310 from databasus/develop
FIX (backups): Improve cascade deletion of backups on storage removal
2026-01-27 16:29:13 +03:00
Rostislav Dugin
8f03a30af2 FIX (backups): Improve cascade deletion of backups on storage removal 2026-01-27 16:28:06 +03:00
Rostislav Dugin
356529c58a Merge pull request #309 from databasus/develop
FIX (tests): Fix database backups cleanup when DI does not allow to d…
2026-01-27 15:39:53 +03:00
Rostislav Dugin
e7eed056f7 FIX (tests): Fix database backups cleanup when DI does not allow to delete backups via listeners 2026-01-27 15:39:04 +03:00
Rostislav Dugin
6084cdc954 Merge pull request #308 from databasus/develop
FIX (tests): Increase cascade deletion timeouts in tests
2026-01-27 15:24:15 +03:00
Rostislav Dugin
c50bcc57b1 FIX (tests): Increase cascade deletion timeouts in tests 2026-01-27 15:23:13 +03:00
Rostislav Dugin
ea76300ed7 Merge pull request #307 from databasus/develop
Develop
2026-01-27 15:07:56 +03:00
Rostislav Dugin
9b413e4076 FIX (tests): Improve cleaning up of backups and workspaces 2026-01-27 15:07:20 +03:00
Rostislav Dugin
f91cb260f2 FEATURE (logs): Add Victora Logs 2026-01-27 15:07:20 +03:00
Rostislav Dugin
8f37a8082f FIX (db): Decrease connections count for DB 2026-01-27 15:07:20 +03:00
Rostislav Dugin
5cf7614772 FIX (playground): Make playground multiple nodes 2026-01-24 14:57:45 +03:00
Rostislav Dugin
ae27f74c2e Merge pull request #304 from databasus/develop
FIX (playground): Fix flacky test with impossible value
2026-01-23 12:38:06 +03:00
Rostislav Dugin
9457516bb9 FIX (playground): Fix flacky test with impossible value 2026-01-23 12:37:10 +03:00
Rostislav Dugin
a36fc5bf8c Merge pull request #303 from databasus/develop
Develop
2026-01-23 12:24:29 +03:00
Rostislav Dugin
03ada5806d FEATURE (pre-commit): Add building step to pre-commit 2026-01-23 12:22:31 +03:00
Rostislav Dugin
a6675390e5 FIX (cors): Allow CORS for healthcheck endpoint 2026-01-23 12:04:29 +03:00
Rostislav Dugin
af2f978876 FEATURE (playground): Add playground 2026-01-23 12:00:56 +03:00
Rostislav Dugin
04e7eba5c5 Merge pull request #300 from databasus/develop
FIX (ci \ cd): Add build step after lint step for frontend to catch b…
2026-01-20 08:40:14 +03:00
Rostislav Dugin
520165541d FIX (ci \ cd): Add build step after lint step for frontend to catch build issues 2026-01-20 08:39:28 +03:00
Rostislav Dugin
5b556bc161 Merge pull request #299 from databasus/develop
Develop
2026-01-20 08:26:57 +03:00
Rostislav Dugin
0952a15ec5 FEATURE (navbar): Update navbar style 2026-01-20 08:25:58 +03:00
Rostislav Dugin
1afb3aa3ff Merge pull request #298 from tim-sas-kramp/main
FIX (theme): Integrate theme support for GitHub button color scheme
2026-01-20 07:25:57 +03:00
tim-sas-kramp
19b92e5f74 FIX (theme): Integrate theme support for GitHub button color scheme 2026-01-19 21:17:24 +00:00
Rostislav Dugin
d4763f26b2 Merge pull request #296 from databasus/develop
Develop
2026-01-19 19:27:03 +03:00
Rostislav Dugin
0e389ba16b FIX (backups): Allow parallel backups for different DBs 2026-01-19 19:26:03 +03:00
Rostislav Dugin
594a3294c6 FEATURE (limits): Add max backup size limit and total backups size limit 2026-01-19 19:26:03 +03:00
Rostislav Dugin
4e4a323cf1 FEATURE (config): Suggest read-only user creation when DB config changed 2026-01-19 19:26:03 +03:00
Rostislav Dugin
7d9ecf697b FIX (backups): Do not allow 2 parallel backups for the same DB 2026-01-19 19:26:03 +03:00
Rostislav Dugin
755c420157 Merge pull request #294 from databasus/develop
FIX (mysql \ mariadb): Add escaping underscoped DB names over heath c…
2026-01-19 12:07:18 +03:00
Rostislav Dugin
ff73627287 FIX (mysql \ mariadb): Add escaping underscoped DB names over heath check 2026-01-19 11:34:37 +03:00
Rostislav Dugin
9c9ab00ace Merge pull request #292 from databasus/develop
FIX (postgresql): Do not throw an error over read-only user creation …
2026-01-18 23:08:55 +03:00
Rostislav Dugin
7366e21a1a FIX (postgresql): Do not throw an error over read-only user creation if there are no public schema in DB 2026-01-18 22:57:47 +03:00
Rostislav Dugin
a327d1aa57 Merge pull request #290 from databasus/develop
FIX (ftp): Add support of nested folders
2026-01-18 18:34:45 +03:00
Rostislav Dugin
f152b16ea3 FIX (ftp): Add support of nested folders 2026-01-18 18:34:13 +03:00
Databasus
85dbe80d3d Merge pull request #288 from databasus/develop
FIX (email): Add following RFC 2047 for emails
2026-01-18 17:59:17 +03:00
Rostislav Dugin
edf4028fd1 FIX (email): Add following RFC 2047 for emails 2026-01-18 17:58:31 +03:00
Databasus
8d85c45a90 Merge pull request #287 from databasus/develop
FIX (tests): Allow to skip external network tests in CI CD
2026-01-18 15:46:49 +03:00
Rostislav Dugin
d9c176d19a FIX (tests): Allow to skip external network tests in CI CD 2026-01-18 15:45:49 +03:00
Databasus
7a6f72a456 Merge pull request #286 from databasus/develop
FIX (ci): Add cleanup to build and push steps
2026-01-18 15:09:13 +03:00
Rostislav Dugin
9a1471b88b FIX (ci): Add cleanup to build and push steps 2026-01-18 15:08:09 +03:00
Databasus
386ea1d708 Merge pull request #285 from databasus/develop
FIX (commit messages): Allow to use backstashes in messages x3
2026-01-18 14:58:10 +03:00
Rostislav Dugin
a4b23936ee FIX (commit messages): Allow to use backstashes in messages x3 2026-01-18 14:57:45 +03:00
Databasus
b36aa9d48b Merge pull request #284 from databasus/develop
FIX (commit messages): Allow to use backstashes in messages x2
2026-01-18 14:49:58 +03:00
Rostislav Dugin
13cb8e5bd2 FIX (commit messages): Allow to use backstashes in messages x2 2026-01-18 14:49:18 +03:00
Databasus
2db4b6e075 Merge pull request #283 from databasus/develop
FIX (commit messages): Allow to use backstashes in messages
2026-01-18 14:38:34 +03:00
Rostislav Dugin
f2b0b2bf1f FIX (commit messages): Allow to use backstashes in messages 2026-01-18 14:38:12 +03:00
Databasus
7142ce295e Merge pull request #282 from databasus/develop
Develop
2026-01-18 14:01:59 +03:00
Rostislav Dugin
04621b9b2d FEATURE (ci \ cd): Adjust CI \ CD to run heavy jobs on self hosted performant runner 2026-01-18 13:55:08 +03:00
Rostislav Dugin
bd329a68cf FEATURE (restores): Do not allow to make 2 parallel restores for single DB 2026-01-17 22:50:35 +03:00
Rostislav Dugin
f957abc9db FEATURE (restores): Add cancellation of restore process 2026-01-17 22:35:47 +03:00
189 changed files with 10417 additions and 1416 deletions

View File

@@ -9,15 +9,26 @@ on:
jobs:
lint-backend:
runs-on: ubuntu-latest
runs-on: self-hosted
container:
image: golang:1.24.9
volumes:
- /runner-cache/go-pkg:/go/pkg/mod
- /runner-cache/go-build:/root/.cache/go-build
- /runner-cache/golangci-lint:/root/.cache/golangci-lint
- /runner-cache/apt-archives:/var/cache/apt/archives
steps:
- name: Check out code
uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.24.9"
- name: Configure Git for container
run: |
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Download Go modules
run: |
cd backend
go mod download
- name: Install golangci-lint
run: |
@@ -70,6 +81,11 @@ jobs:
cd frontend
npm run lint
- name: Build frontend
run: |
cd frontend
npm run build
test-frontend:
runs-on: ubuntu-latest
needs: [lint-frontend]
@@ -93,34 +109,32 @@ jobs:
npm run test
test-backend:
runs-on: ubuntu-latest
runs-on: self-hosted
needs: [lint-backend]
container:
image: golang:1.24.9
options: --privileged -v /var/run/docker.sock:/var/run/docker.sock --add-host=host.docker.internal:host-gateway
volumes:
- /runner-cache/go-pkg:/go/pkg/mod
- /runner-cache/go-build:/root/.cache/go-build
- /runner-cache/apt-archives:/var/cache/apt/archives
steps:
- name: Free up disk space
- name: Install Docker CLI
run: |
echo "Disk space before cleanup:"
df -h
# Remove unnecessary pre-installed software
sudo rm -rf /usr/share/dotnet
sudo rm -rf /usr/local/lib/android
sudo rm -rf /opt/ghc
sudo rm -rf /opt/hostedtoolcache/CodeQL
sudo rm -rf /usr/local/share/boost
sudo rm -rf /usr/share/swift
# Clean apt cache
sudo apt-get clean
# Clean docker images (if any pre-installed)
docker system prune -af --volumes || true
echo "Disk space after cleanup:"
df -h
apt-get update -qq
apt-get install -y -qq docker.io docker-compose netcat-openbsd wget
- name: Check out code
uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.24.9"
- name: Configure Git for container
run: |
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Download Go modules
run: |
cd backend
go mod download
- name: Create .env file for testing
run: |
@@ -132,14 +146,16 @@ jobs:
DEV_DB_PASSWORD=Q1234567
#app
ENV_MODE=development
# db
DATABASE_DSN=host=localhost user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
DATABASE_URL=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disable
# db - using 172.17.0.1 to access host from container
DATABASE_DSN=host=172.17.0.1 user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
DATABASE_URL=postgres://postgres:Q1234567@172.17.0.1:5437/databasus?sslmode=disable
# migrations
GOOSE_DRIVER=postgres
GOOSE_DBSTRING=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disable
GOOSE_DBSTRING=postgres://postgres:Q1234567@172.17.0.1:5437/databasus?sslmode=disable
GOOSE_MIGRATION_DIR=./migrations
# testing
# testing
TEST_LOCALHOST=172.17.0.1
IS_SKIP_EXTERNAL_RESOURCES_TESTS=true
# to get Google Drive env variables: add storage in UI and copy data from added storage here
TEST_GOOGLE_DRIVE_CLIENT_ID=${{ secrets.TEST_GOOGLE_DRIVE_CLIENT_ID }}
TEST_GOOGLE_DRIVE_CLIENT_SECRET=${{ secrets.TEST_GOOGLE_DRIVE_CLIENT_SECRET }}
@@ -197,12 +213,14 @@ jobs:
TEST_MONGODB_60_PORT=27060
TEST_MONGODB_70_PORT=27070
TEST_MONGODB_82_PORT=27082
# Valkey (cache)
VALKEY_HOST=localhost
# Valkey (cache) - using 172.17.0.1
VALKEY_HOST=172.17.0.1
VALKEY_PORT=6379
VALKEY_USERNAME=
VALKEY_PASSWORD=
VALKEY_IS_SSL=false
# Host for test databases (container -> host)
TEST_DB_HOST=172.17.0.1
EOF
- name: Start test containers
@@ -220,25 +238,25 @@ jobs:
timeout 60 bash -c 'until docker exec dev-valkey valkey-cli ping 2>/dev/null | grep -q PONG; do sleep 2; done'
echo "Valkey is ready!"
# Wait for test databases
timeout 60 bash -c 'until nc -z localhost 5000; do sleep 2; done'
timeout 60 bash -c 'until nc -z localhost 5001; do sleep 2; done'
timeout 60 bash -c 'until nc -z localhost 5002; do sleep 2; done'
timeout 60 bash -c 'until nc -z localhost 5003; do sleep 2; done'
timeout 60 bash -c 'until nc -z localhost 5004; do sleep 2; done'
timeout 60 bash -c 'until nc -z localhost 5005; do sleep 2; done'
# Wait for test databases (using 172.17.0.1 from container)
timeout 60 bash -c 'until nc -z 172.17.0.1 5000; do sleep 2; done'
timeout 60 bash -c 'until nc -z 172.17.0.1 5001; do sleep 2; done'
timeout 60 bash -c 'until nc -z 172.17.0.1 5002; do sleep 2; done'
timeout 60 bash -c 'until nc -z 172.17.0.1 5003; do sleep 2; done'
timeout 60 bash -c 'until nc -z 172.17.0.1 5004; do sleep 2; done'
timeout 60 bash -c 'until nc -z 172.17.0.1 5005; do sleep 2; done'
# Wait for MinIO
timeout 60 bash -c 'until nc -z localhost 9000; do sleep 2; done'
timeout 60 bash -c 'until nc -z 172.17.0.1 9000; do sleep 2; done'
# Wait for Azurite
timeout 60 bash -c 'until nc -z localhost 10000; do sleep 2; done'
timeout 60 bash -c 'until nc -z 172.17.0.1 10000; do sleep 2; done'
# Wait for FTP
timeout 60 bash -c 'until nc -z localhost 7007; do sleep 2; done'
timeout 60 bash -c 'until nc -z 172.17.0.1 7007; do sleep 2; done'
# Wait for SFTP
timeout 60 bash -c 'until nc -z localhost 7008; do sleep 2; done'
timeout 60 bash -c 'until nc -z 172.17.0.1 7008; do sleep 2; done'
# Wait for MySQL containers
echo "Waiting for MySQL 5.7..."
@@ -297,63 +315,63 @@ jobs:
mkdir -p databasus-data/backups
mkdir -p databasus-data/temp
- name: Install MySQL dependencies
- name: Install database client dependencies
run: |
sudo apt-get update -qq
sudo apt-get install -y -qq libncurses6
sudo ln -sf /usr/lib/x86_64-linux-gnu/libncurses.so.6 /usr/lib/x86_64-linux-gnu/libncurses.so.5
sudo ln -sf /usr/lib/x86_64-linux-gnu/libtinfo.so.6 /usr/lib/x86_64-linux-gnu/libtinfo.so.5
apt-get update -qq
apt-get install -y -qq libncurses6 libpq5
ln -sf /usr/lib/x86_64-linux-gnu/libncurses.so.6 /usr/lib/x86_64-linux-gnu/libncurses.so.5 || true
ln -sf /usr/lib/x86_64-linux-gnu/libtinfo.so.6 /usr/lib/x86_64-linux-gnu/libtinfo.so.5 || true
- name: Setup PostgreSQL, MySQL and MariaDB client tools from pre-built assets
run: |
cd backend/tools
# Create directory structure
mkdir -p postgresql mysql mariadb mongodb/bin
# Copy PostgreSQL client tools (12-18) from pre-built assets
for version in 12 13 14 15 16 17 18; do
mkdir -p postgresql/postgresql-$version
cp -r ../../assets/tools/x64/postgresql/postgresql-$version/bin postgresql/postgresql-$version/
done
# Copy MySQL client tools (5.7, 8.0, 8.4, 9) from pre-built assets
for version in 5.7 8.0 8.4 9; do
mkdir -p mysql/mysql-$version
cp -r ../../assets/tools/x64/mysql/mysql-$version/bin mysql/mysql-$version/
done
# Copy MariaDB client tools (10.6, 12.1) from pre-built assets
for version in 10.6 12.1; do
mkdir -p mariadb/mariadb-$version
cp -r ../../assets/tools/x64/mariadb/mariadb-$version/bin mariadb/mariadb-$version/
done
# Make all binaries executable
chmod +x postgresql/*/bin/*
chmod +x mysql/*/bin/*
chmod +x mariadb/*/bin/*
echo "Pre-built client tools setup complete"
- name: Install MongoDB Database Tools
run: |
cd backend/tools
# MongoDB Database Tools must be downloaded (not in pre-built assets)
# They are backward compatible - single version supports all servers (4.0-8.0)
MONGODB_TOOLS_URL="https://fastdl.mongodb.org/tools/db/mongodb-database-tools-debian12-x86_64-100.10.0.deb"
echo "Downloading MongoDB Database Tools..."
wget -q "$MONGODB_TOOLS_URL" -O /tmp/mongodb-database-tools.deb
echo "Installing MongoDB Database Tools..."
sudo dpkg -i /tmp/mongodb-database-tools.deb || sudo apt-get install -f -y --no-install-recommends
dpkg -i /tmp/mongodb-database-tools.deb || apt-get install -f -y --no-install-recommends
# Create symlinks to tools directory
ln -sf /usr/bin/mongodump mongodb/bin/mongodump
ln -sf /usr/bin/mongorestore mongodb/bin/mongorestore
rm -f /tmp/mongodb-database-tools.deb
echo "MongoDB Database Tools installed successfully"
@@ -401,10 +419,28 @@ jobs:
if: always()
run: |
cd backend
# Stop and remove containers (keeping images for next run)
docker compose -f docker-compose.yml.example down -v
# Clean up all data directories created by docker-compose
echo "Cleaning up data directories..."
rm -rf pgdata || true
rm -rf valkey-data || true
rm -rf mysqldata || true
rm -rf mariadbdata || true
rm -rf temp/nas || true
rm -rf databasus-data || true
# Also clean root-level databasus-data if exists
cd ..
rm -rf databasus-data || true
echo "Cleanup complete"
determine-version:
runs-on: ubuntu-latest
runs-on: self-hosted
container:
image: node:20
needs: [test-backend, test-frontend]
if: ${{ github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, '[skip-release]') }}
outputs:
@@ -417,10 +453,9 @@ jobs:
with:
fetch-depth: 0
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "20"
- name: Configure Git for container
run: |
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Install semver
run: npm install -g semver
@@ -434,6 +469,7 @@ jobs:
- name: Analyze commits and determine version bump
id: version_bump
shell: bash
run: |
CURRENT_VERSION="${{ steps.current_version.outputs.current_version }}"
LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
@@ -453,7 +489,7 @@ jobs:
HAS_FIX=false
HAS_BREAKING=false
# Analyze each commit
# Analyze each commit - USE PROCESS SUBSTITUTION to avoid subshell variable scope issues
while IFS= read -r commit; do
if [[ "$commit" =~ ^FEATURE ]]; then
HAS_FEATURE=true
@@ -471,7 +507,7 @@ jobs:
HAS_BREAKING=true
echo "Found BREAKING CHANGE: $commit"
fi
done <<< "$COMMITS"
done < <(printf '%s\n' "$COMMITS")
# Determine version bump
if [ "$HAS_BREAKING" = true ]; then
@@ -497,10 +533,15 @@ jobs:
fi
build-only:
runs-on: ubuntu-latest
runs-on: self-hosted
needs: [test-backend, test-frontend]
if: ${{ github.ref == 'refs/heads/main' && contains(github.event.head_commit.message, '[skip-release]') }}
steps:
- name: Clean workspace
run: |
sudo rm -rf "$GITHUB_WORKSPACE"/* || true
sudo rm -rf "$GITHUB_WORKSPACE"/.* || true
- name: Check out code
uses: actions/checkout@v4
@@ -529,12 +570,17 @@ jobs:
databasus/databasus:${{ github.sha }}
build-and-push:
runs-on: ubuntu-latest
runs-on: self-hosted
needs: [determine-version]
if: ${{ needs.determine-version.outputs.should_release == 'true' }}
permissions:
contents: write
steps:
- name: Clean workspace
run: |
sudo rm -rf "$GITHUB_WORKSPACE"/* || true
sudo rm -rf "$GITHUB_WORKSPACE"/.* || true
- name: Check out code
uses: actions/checkout@v4
@@ -564,21 +610,33 @@ jobs:
databasus/databasus:${{ github.sha }}
release:
runs-on: ubuntu-latest
runs-on: self-hosted
container:
image: node:20
needs: [determine-version, build-and-push]
if: ${{ needs.determine-version.outputs.should_release == 'true' }}
permissions:
contents: write
pull-requests: write
steps:
- name: Clean workspace
run: |
rm -rf "$GITHUB_WORKSPACE"/* || true
rm -rf "$GITHUB_WORKSPACE"/.* || true
- name: Check out code
uses: actions/checkout@v4
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Configure Git for container
run: |
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Generate changelog
id: changelog
shell: bash
run: |
NEW_VERSION="${{ needs.determine-version.outputs.new_version }}"
LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
@@ -598,6 +656,7 @@ jobs:
FIXES=""
REFACTORS=""
# USE PROCESS SUBSTITUTION to avoid subshell variable scope issues
while IFS= read -r line; do
if [ -n "$line" ]; then
COMMIT_MSG=$(echo "$line" | cut -d'|' -f1)
@@ -631,7 +690,7 @@ jobs:
fi
fi
fi
done <<< "$COMMITS"
done < <(printf '%s\n' "$COMMITS")
# Build changelog sections
if [ -n "$FEATURES" ]; then
@@ -670,16 +729,33 @@ jobs:
prerelease: false
publish-helm-chart:
runs-on: ubuntu-latest
runs-on: self-hosted
container:
image: alpine:3.19
volumes:
- /runner-cache/apk-cache:/etc/apk/cache
needs: [determine-version, build-and-push]
if: ${{ needs.determine-version.outputs.should_release == 'true' }}
permissions:
contents: read
packages: write
steps:
- name: Clean workspace
run: |
rm -rf "$GITHUB_WORKSPACE"/* || true
rm -rf "$GITHUB_WORKSPACE"/.* || true
- name: Install dependencies
run: |
apk add --no-cache git bash curl
- name: Check out code
uses: actions/checkout@v4
- name: Configure Git for container
run: |
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Set up Helm
uses: azure/setup-helm@v4
with:
@@ -701,4 +777,4 @@ jobs:
- name: Push Helm chart to GHCR
run: |
VERSION="${{ needs.determine-version.outputs.new_version }}"
helm push databasus-${VERSION}.tgz oci://ghcr.io/databasus/charts
helm push databasus-${VERSION}.tgz oci://ghcr.io/databasus/charts

4
.gitignore vendored
View File

@@ -1,3 +1,4 @@
ansible/
postgresus_data/
postgresus-data/
databasus-data/
@@ -9,4 +10,5 @@ node_modules/
/articles
.DS_Store
/scripts
/scripts
.vscode/settings.json

View File

@@ -18,6 +18,13 @@ repos:
files: ^frontend/.*\.(ts|tsx|js|jsx)$
pass_filenames: false
- id: frontend-build
name: Frontend Build
entry: bash -c "cd frontend && npm run build"
language: system
files: ^frontend/.*\.(ts|tsx|js|jsx|json|css)$
pass_filenames: false
# Backend checks
- repo: local
hooks:

225
AGENTS.md
View File

@@ -237,6 +237,66 @@ func (c *ProjectController) extractProjectID(ctx *gin.Context) uuid.UUID {
---
### Boolean Naming
**Always prefix boolean variables with verbs like `is`, `has`, `was`, `should`, `can`, etc.**
This makes the code more readable and clearly indicates that the variable represents a true/false state.
#### Good Examples:
```go
type User struct {
IsActive bool
IsVerified bool
HasAccess bool
WasNotified bool
}
type BackupConfig struct {
IsEnabled bool
ShouldCompress bool
CanRetry bool
}
// Variables
isInProgress := true
wasCompleted := false
hasPermission := checkPermissions()
```
#### Bad Examples:
```go
type User struct {
Active bool // Should be: IsActive
Verified bool // Should be: IsVerified
Access bool // Should be: HasAccess
}
type BackupConfig struct {
Enabled bool // Should be: IsEnabled
Compress bool // Should be: ShouldCompress
Retry bool // Should be: CanRetry
}
// Variables
inProgress := true // Should be: isInProgress
completed := false // Should be: wasCompleted
permission := true // Should be: hasPermission
```
#### Common Boolean Prefixes:
- **is** - current state (IsActive, IsValid, IsEnabled)
- **has** - possession or presence (HasAccess, HasPermission, HasError)
- **was** - past state (WasCompleted, WasNotified, WasDeleted)
- **should** - intention or recommendation (ShouldRetry, ShouldCompress)
- **can** - capability or permission (CanRetry, CanDelete, CanEdit)
- **will** - future state (WillExpire, WillRetry)
---
### Comments
#### Guidelines
@@ -489,6 +549,134 @@ func GetOrderRepository() *repositories.OrderRepository {
}
```
#### SetupDependencies() Pattern
**All `SetupDependencies()` functions must use sync.Once to ensure idempotent execution.**
This pattern allows `SetupDependencies()` to be safely called multiple times (especially in tests) while ensuring the actual setup logic executes only once.
**Implementation Pattern:**
```go
package feature
import (
"sync"
"sync/atomic"
"databasus-backend/internal/util/logger"
)
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
// Initialize dependencies here
someService.SetDependency(otherService)
anotherService.AddListener(listener)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}
```
**Why This Pattern:**
- **Tests can call multiple times**: Test setup often calls `SetupDependencies()` multiple times without issues
- **Thread-safe**: Works correctly with concurrent calls (nanoseconds or seconds apart)
- **Idempotent**: Subsequent calls are safe, only log warning
- **No panics**: Does not break tests or production code on multiple calls
**Key Points:**
1. Check `isSetup.Load()` **before** calling `Do()` to detect previous executions
2. Set `isSetup.Store(true)` **inside** the `Do()` closure after setup completes
3. Log warning if already setup (helps identify unnecessary duplicate calls)
4. All setup logic must be inside the `Do()` closure
---
### Background Services
**All background service `Run()` methods must panic if called multiple times to prevent corrupted states.**
Background services run infinite loops and must never be started twice on the same instance. Multiple calls indicate a serious bug that would cause duplicate goroutines, resource leaks, and data corruption.
**Implementation Pattern:**
```go
package feature
import (
"context"
"fmt"
"sync"
"sync/atomic"
)
type BackgroundService struct {
// ... existing fields ...
runOnce sync.Once
hasRun atomic.Bool
}
func (s *BackgroundService) Run(ctx context.Context) {
wasAlreadyRun := s.hasRun.Load()
s.runOnce.Do(func() {
s.hasRun.Store(true)
// Existing infinite loop logic
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.doWork()
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
}
```
**Why Panic Instead of Warning:**
- **Prevents corruption**: Multiple `Run()` calls would create duplicate goroutines consuming resources
- **Fails fast**: Catches critical bugs immediately in tests and production
- **Clear indication**: Panic clearly indicates a serious programming error
- **Applies everywhere**: Same protection in tests and production
**When This Applies:**
- All background services with infinite loops
- Registry services (BackupNodesRegistry, RestoreNodesRegistry)
- Scheduler services (BackupsScheduler, RestoresScheduler)
- Worker nodes (BackuperNode, RestorerNode)
- Cleanup services (AuditLogBackgroundService, DownloadTokenBackgroundService)
**Key Points:**
1. Check `hasRun.Load()` **before** calling `Do()` to detect previous executions
2. Set `hasRun.Store(true)` **inside** the `Do()` closure before starting work
3. **Always panic** if already run (never just log warning)
4. All run logic must be inside the `Do()` closure
5. This pattern is **thread-safe** for any timing (concurrent or sequential calls)
---
### Migrations
@@ -578,6 +766,43 @@ Use these naming patterns:
- Consolidate similar test patterns across different test files
- Make tests more readable and maintainable for other developers
**Clean Up Test Data:**
- If the feature supports cleanup operations (DELETE endpoints, cleanup methods), use them in tests
- Clean up resources after test execution to avoid test data pollution
- Use `defer` statements or explicit cleanup calls at the end of tests
- Prioritize using API methods for cleanup (not direct database deletion)
- Examples:
- CRUD features: delete created records via DELETE endpoint
- File uploads: remove uploaded files
- Background jobs: stop schedulers or cancel running tasks
- Skip cleanup only when:
- Tests run in isolated transactions that auto-rollback
- Cleanup endpoint doesn't exist yet
- Test explicitly validates failure scenarios where cleanup isn't possible
**Example:**
```go
func Test_BackupLifecycle_CreateAndDelete(t *testing.T) {
router := createTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test", owner)
// Create backup config
config := createBackupConfig(t, router, workspace.ID, owner.Token)
// Cleanup at end of test
defer deleteBackupConfig(t, router, workspace.ID, config.ID, owner.Token)
// Test operations...
triggerBackup(t, router, workspace.ID, config.ID, owner.Token)
// Verify backup was created
backups := getBackups(t, router, workspace.ID, owner.Token)
assert.NotEmpty(t, backups)
}
```
#### Testing Utilities Structure
**Create `testing.go` or `testing/testing.go` files with common utilities:**

View File

@@ -251,6 +251,27 @@ fi
# PostgreSQL 17 binary paths
PG_BIN="/usr/lib/postgresql/17/bin"
# Generate runtime configuration for frontend
echo "Generating runtime configuration..."
# Detect if email is configured (both SMTP_HOST and DATABASUS_URL must be set)
if [ -n "\${SMTP_HOST:-}" ] && [ -n "\${DATABASUS_URL:-}" ]; then
IS_EMAIL_CONFIGURED="true"
else
IS_EMAIL_CONFIGURED="false"
fi
cat > /app/ui/build/runtime-config.js <<JSEOF
// Runtime configuration injected at container startup
// This file is generated dynamically and should not be edited manually
window.__RUNTIME_CONFIG__ = {
IS_CLOUD: '\${IS_CLOUD:-false}',
GITHUB_CLIENT_ID: '\${GITHUB_CLIENT_ID:-}',
GOOGLE_CLIENT_ID: '\${GOOGLE_CLIENT_ID:-}',
IS_EMAIL_CONFIGURED: '\$IS_EMAIL_CONFIGURED'
};
JSEOF
# Ensure proper ownership of data directory
echo "Setting up data directory permissions..."
mkdir -p /databasus-data/pgdata
@@ -372,6 +393,32 @@ SQL
# Start the main application
echo "Starting Databasus application..."
# Check and warn about external database/Valkey usage
if [ -n "\${DANGEROUS_EXTERNAL_DATABASE_DSN:-}" ]; then
echo ""
echo "=========================================="
echo "WARNING: Using external database"
echo "=========================================="
echo "DANGEROUS_EXTERNAL_DATABASE_DSN is set."
echo "Application will connect to external PostgreSQL instead of internal instance."
echo "Internal PostgreSQL is still running in the background."
echo "=========================================="
echo ""
fi
if [ -n "\${DANGEROUS_VALKEY_HOST:-}" ]; then
echo ""
echo "=========================================="
echo "WARNING: Using external Valkey"
echo "=========================================="
echo "DANGEROUS_VALKEY_HOST is set."
echo "Application will connect to external Valkey instead of internal instance."
echo "Internal Valkey is still running in the background."
echo "=========================================="
echo ""
fi
exec ./main
EOF

View File

@@ -6,6 +6,11 @@ DEV_DB_PASSWORD=Q1234567
ENV_MODE=development
# logging
SHOW_DB_INSTALLATION_VERIFICATION_LOGS=true
VICTORIA_LOGS_URL=http://localhost:9428
VICTORIA_LOGS_PASSWORD=devpassword
# tests
TEST_LOCALHOST=localhost
IS_SKIP_EXTERNAL_RESOURCES_TESTS=false
# db
DATABASE_DSN=host=dev-db user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
DATABASE_URL=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable

3
backend/.gitignore vendored
View File

@@ -18,4 +18,5 @@ pgdata-for-restore/
temp/
cmd.exe
temp/
valkey-data/
valkey-data/
victoria-logs-data/

View File

@@ -185,6 +185,9 @@ func startServerWithGracefulShutdown(log *slog.Logger, app *gin.Engine) {
<-quit
log.Info("Shutdown signal received")
// Gracefully shutdown VictoriaLogs writer
logger.ShutdownVictoriaLogs(5 * time.Second)
// The context is used to inform the server it has 10 seconds to finish
// the request it is currently handling
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -272,6 +275,10 @@ func runBackgroundTasks(log *slog.Logger) {
backuping.GetBackupsScheduler().Run(ctx)
})
go runWithPanicLogging(log, "backup cleaner background service", func() {
backuping.GetBackupCleaner().Run(ctx)
})
go runWithPanicLogging(log, "restore background service", func() {
restoring.GetRestoresScheduler().Run(ctx)
})

View File

@@ -34,6 +34,20 @@ services:
retries: 5
start_period: 20s
# VictoriaLogs for external logging
victoria-logs:
image: victoriametrics/victoria-logs:latest
container_name: victoria-logs
ports:
- "9428:9428"
command:
- -storageDataPath=/victoria-logs-data
- -retentionPeriod=7d
- -httpAuth.password=devpassword
volumes:
- ./victoria-logs-data:/victoria-logs-data
restart: unless-stopped
# Test MinIO container
test-minio:
image: minio/minio:latest

View File

@@ -22,14 +22,26 @@ const (
type EnvVariables struct {
IsTesting bool
DatabaseDsn string `env:"DATABASE_DSN" required:"true"`
EnvMode env_utils.EnvMode `env:"ENV_MODE" required:"true"`
PostgresesInstallDir string `env:"POSTGRES_INSTALL_DIR"`
MysqlInstallDir string `env:"MYSQL_INSTALL_DIR"`
MariadbInstallDir string `env:"MARIADB_INSTALL_DIR"`
MongodbInstallDir string `env:"MONGODB_INSTALL_DIR"`
// Internal database
DatabaseDsn string `env:"DATABASE_DSN" required:"true"`
// Internal Valkey
ValkeyHost string `env:"VALKEY_HOST" required:"true"`
ValkeyPort string `env:"VALKEY_PORT" required:"true"`
ValkeyUsername string `env:"VALKEY_USERNAME" required:"true"`
ValkeyPassword string `env:"VALKEY_PASSWORD" required:"true"`
ValkeyIsSsl bool `env:"VALKEY_IS_SSL" required:"true"`
IsCloud bool `env:"IS_CLOUD"`
TestLocalhost string `env:"TEST_LOCALHOST"`
ShowDbInstallationVerificationLogs bool `env:"SHOW_DB_INSTALLATION_VERIFICATION_LOGS"`
IsSkipExternalResourcesTests bool `env:"IS_SKIP_EXTERNAL_RESOURCES_TESTS"`
IsManyNodesMode bool `env:"IS_MANY_NODES_MODE"`
IsPrimaryNode bool `env:"IS_PRIMARY_NODE"`
@@ -86,13 +98,6 @@ type EnvVariables struct {
TestMongodb70Port string `env:"TEST_MONGODB_70_PORT"`
TestMongodb82Port string `env:"TEST_MONGODB_82_PORT"`
// Valkey
ValkeyHost string `env:"VALKEY_HOST" required:"true"`
ValkeyPort string `env:"VALKEY_PORT" required:"true"`
ValkeyUsername string `env:"VALKEY_USERNAME"`
ValkeyPassword string `env:"VALKEY_PASSWORD"`
ValkeyIsSsl bool `env:"VALKEY_IS_SSL" required:"true"`
// oauth
GitHubClientID string `env:"GITHUB_CLIENT_ID"`
GitHubClientSecret string `env:"GITHUB_CLIENT_SECRET"`
@@ -109,6 +114,15 @@ type EnvVariables struct {
TestSupabaseUsername string `env:"TEST_SUPABASE_USERNAME"`
TestSupabasePassword string `env:"TEST_SUPABASE_PASSWORD"`
TestSupabaseDatabase string `env:"TEST_SUPABASE_DATABASE"`
// SMTP configuration (optional)
SMTPHost string `env:"SMTP_HOST"`
SMTPPort int `env:"SMTP_PORT"`
SMTPUser string `env:"SMTP_USER"`
SMTPPassword string `env:"SMTP_PASSWORD"`
// Application URL (optional) - used for email links
DatabasusURL string `env:"DATABASUS_URL"`
}
var (
@@ -174,6 +188,16 @@ func loadEnvVariables() {
env.ShowDbInstallationVerificationLogs = true
}
// Set default value for IsSkipExternalTests if not defined
if os.Getenv("IS_SKIP_EXTERNAL_RESOURCES_TESTS") == "" {
env.IsSkipExternalResourcesTests = false
}
// Set default value for IsCloud if not defined
if os.Getenv("IS_CLOUD") == "" {
env.IsCloud = false
}
for _, arg := range os.Args {
if strings.Contains(arg, "test") {
env.IsTesting = true
@@ -181,6 +205,14 @@ func loadEnvVariables() {
}
}
// Check for external database override
if externalDsn := os.Getenv("DANGEROUS_EXTERNAL_DATABASE_DSN"); externalDsn != "" {
log.Warn(
"Using DANGEROUS_EXTERNAL_DATABASE_DSN - connecting to external database instead of internal PostgreSQL",
)
env.DatabaseDsn = externalDsn
}
if env.DatabaseDsn == "" {
log.Error("DATABASE_DSN is empty")
os.Exit(1)
@@ -237,6 +269,10 @@ func loadEnvVariables() {
env.IsProcessingNode = true
}
if env.TestLocalhost == "" {
env.TestLocalhost = "localhost"
}
// Valkey
if env.ValkeyHost == "" {
log.Error("VALKEY_HOST is empty")
@@ -247,6 +283,27 @@ func loadEnvVariables() {
os.Exit(1)
}
// Check for external Valkey override
if externalValkeyHost := os.Getenv("DANGEROUS_VALKEY_HOST"); externalValkeyHost != "" {
log.Warn(
"Using DANGEROUS_VALKEY_* variables - connecting to external Valkey instead of internal instance",
)
env.ValkeyHost = externalValkeyHost
if externalValkeyPort := os.Getenv("DANGEROUS_VALKEY_PORT"); externalValkeyPort != "" {
env.ValkeyPort = externalValkeyPort
}
if externalValkeyUsername := os.Getenv("DANGEROUS_VALKEY_USERNAME"); externalValkeyUsername != "" {
env.ValkeyUsername = externalValkeyUsername
}
if externalValkeyPassword := os.Getenv("DANGEROUS_VALKEY_PASSWORD"); externalValkeyPassword != "" {
env.ValkeyPassword = externalValkeyPassword
}
if externalValkeyIsSsl := os.Getenv("DANGEROUS_VALKEY_IS_SSL"); externalValkeyIsSsl != "" {
env.ValkeyIsSsl = externalValkeyIsSsl == "true"
}
}
// Store the data and temp folders one level below the root
// (projectRoot/databasus-data -> /databasus-data)
env.DataFolder = filepath.Join(filepath.Dir(backendRoot), "databasus-data", "backups")

View File

@@ -2,34 +2,50 @@ package audit_logs
import (
"context"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
)
type AuditLogBackgroundService struct {
auditLogService *AuditLogService
logger *slog.Logger
runOnce sync.Once
hasRun atomic.Bool
}
func (s *AuditLogBackgroundService) Run(ctx context.Context) {
s.logger.Info("Starting audit log cleanup background service")
wasAlreadyRun := s.hasRun.Load()
if ctx.Err() != nil {
return
}
s.runOnce.Do(func() {
s.hasRun.Store(true)
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
s.logger.Info("Starting audit log cleanup background service")
for {
select {
case <-ctx.Done():
if ctx.Err() != nil {
return
case <-ticker.C:
if err := s.cleanOldAuditLogs(); err != nil {
s.logger.Error("Failed to clean old audit logs", "error", err)
}
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.cleanOldAuditLogs(); err != nil {
s.logger.Error("Failed to clean old audit logs", "error", err)
}
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
}

View File

@@ -1,6 +1,9 @@
package audit_logs
import (
"sync"
"sync/atomic"
users_services "databasus-backend/internal/features/users/services"
"databasus-backend/internal/util/logger"
)
@@ -14,8 +17,10 @@ var auditLogController = &AuditLogController{
auditLogService,
}
var auditLogBackgroundService = &AuditLogBackgroundService{
auditLogService,
logger.GetLogger(),
auditLogService: auditLogService,
logger: logger.GetLogger(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
func GetAuditLogService() *AuditLogService {
@@ -30,8 +35,23 @@ func GetAuditLogBackgroundService() *AuditLogBackgroundService {
return auditLogBackgroundService
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
users_services.GetUserService().SetAuditLogWriter(auditLogService)
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
users_services.GetUserService().SetAuditLogWriter(auditLogService)
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}

View File

@@ -2,6 +2,17 @@ package backuping
import (
"context"
"errors"
"fmt"
"log/slog"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"databasus-backend/internal/config"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
@@ -10,14 +21,6 @@ import (
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
workspaces_services "databasus-backend/internal/features/workspaces/services"
util_encryption "databasus-backend/internal/util/encryption"
"errors"
"fmt"
"log/slog"
"slices"
"strings"
"time"
"github.com/google/uuid"
)
const (
@@ -40,66 +43,81 @@ type BackuperNode struct {
nodeID uuid.UUID
lastHeartbeat time.Time
runOnce sync.Once
hasRun atomic.Bool
}
func (n *BackuperNode) Run(ctx context.Context) {
n.lastHeartbeat = time.Now().UTC()
wasAlreadyRun := n.hasRun.Load()
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
n.runOnce.Do(func() {
n.hasRun.Store(true)
backupNode := BackupNode{
ID: n.nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: time.Now().UTC(),
}
n.lastHeartbeat = time.Now().UTC()
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
n.logger.Error("Failed to register node in registry", "error", err)
panic(err)
}
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
n.MakeBackup(backupID, isCallNotifier)
if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil {
n.logger.Error(
"Failed to publish backup completion",
"error",
err,
"backupID",
backupID,
)
backupNode := BackupNode{
ID: n.nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: time.Now().UTC(),
}
}
err := n.backupNodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID, backupHandler)
if err != nil {
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
panic(err)
}
defer func() {
if err := n.backupNodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
n.logger.Error("Failed to register node in registry", "error", err)
panic(err)
}
}()
ticker := time.NewTicker(heartbeatTickerInterval)
defer ticker.Stop()
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
go func() {
n.MakeBackup(backupID, isCallNotifier)
if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil {
n.logger.Error(
"Failed to publish backup completion",
"error",
err,
"backupID",
backupID,
)
}
}()
}
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
for {
select {
case <-ctx.Done():
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
if err := n.backupNodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
n.logger.Error("Failed to unregister node from registry", "error", err)
err := n.backupNodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID, backupHandler)
if err != nil {
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
panic(err)
}
defer func() {
if err := n.backupNodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
}
}()
return
case <-ticker.C:
n.sendHeartbeat(&backupNode)
ticker := time.NewTicker(heartbeatTickerInterval)
defer ticker.Stop()
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
for {
select {
case <-ctx.Done():
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
if err := n.backupNodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
n.logger.Error("Failed to unregister node from registry", "error", err)
}
return
case <-ticker.C:
n.sendHeartbeat(&backupNode)
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", n))
}
}
@@ -141,21 +159,41 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
start := time.Now().UTC()
ctx, cancel := context.WithCancel(context.Background())
n.backupCancelManager.RegisterTask(backup.ID, cancel)
defer n.backupCancelManager.UnregisterTask(backup.ID)
backupProgressListener := func(
completedMBs float64,
) {
backup.BackupSizeMb = completedMBs
backup.BackupDurationMs = time.Since(start).Milliseconds()
// Check size limit (0 = unlimited)
if backupConfig.MaxBackupSizeMB > 0 &&
completedMBs > float64(backupConfig.MaxBackupSizeMB) {
errMsg := fmt.Sprintf(
"backup size (%.2f MB) exceeded maximum allowed size (%d MB)",
completedMBs,
backupConfig.MaxBackupSizeMB,
)
backup.Status = backups_core.BackupStatusFailed
backup.IsSkipRetry = true
backup.FailMessage = &errMsg
if err := n.backupRepository.Save(backup); err != nil {
n.logger.Error("Failed to save backup with size exceeded error", "error", err)
}
cancel() // Cancel the backup context
return
}
if err := n.backupRepository.Save(backup); err != nil {
n.logger.Error("Failed to update backup progress", "error", err)
}
}
ctx, cancel := context.WithCancel(context.Background())
n.backupCancelManager.RegisterTask(backup.ID, cancel)
defer n.backupCancelManager.UnregisterTask(backup.ID)
backupMetadata, err := n.createBackupUseCase.Execute(
ctx,
backup.ID,
@@ -165,6 +203,29 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
backupProgressListener,
)
if err != nil {
// Check if backup was already marked as failed by progress listener (e.g., size limit exceeded)
// If so, skip error handling to avoid overwriting the status
currentBackup, fetchErr := n.backupRepository.FindByID(backup.ID)
if fetchErr == nil && currentBackup.Status == backups_core.BackupStatusFailed {
n.logger.Warn(
"Backup already marked as failed by progress listener, skipping error handling",
"backupId",
backup.ID,
"failMessage",
*currentBackup.FailMessage,
)
// Still call notification for size limit failures
n.SendBackupNotification(
backupConfig,
currentBackup,
backups_config.NotificationBackupFailed,
currentBackup.FailMessage,
)
return
}
errMsg := err.Error()
// Log detailed error information for debugging

View File

@@ -1,13 +1,10 @@
package backuping
import (
"context"
"errors"
"strings"
"testing"
"time"
common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -18,7 +15,6 @@ import (
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
cache_utils "databasus-backend/internal/util/cache"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
@@ -158,35 +154,120 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
})
}
type CreateFailedBackupUsecase struct {
}
func Test_BackupSizeLimits(t *testing.T) {
cache_utils.ClearAllCache()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
func (uc *CreateFailedBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
backupProgressListener(10)
return nil, errors.New("backup failed")
}
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
type CreateSuccessBackupUsecase struct{}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
func (uc *CreateSuccessBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
backupProgressListener(10)
return &common.BackupMetadata{
EncryptionSalt: nil,
EncryptionIV: nil,
Encryption: backups_config.BackupEncryptionNone,
}, nil
t.Run("UnlimitedSize_MaxBackupSizeMBIsZero_BackupCompletes", func(t *testing.T) {
// Enable backups with unlimited size (0)
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backupConfig.MaxBackupSizeMB = 0 // unlimited
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backuperNode := CreateTestBackuperNode()
backuperNode.createBackupUseCase = &CreateLargeBackupUsecase{}
// Create a backup record
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backuperNode.MakeBackup(backup.ID, false)
// Verify backup completed successfully even with large size
updatedBackup, err := backupRepository.FindByID(backup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
assert.Equal(t, float64(10000), updatedBackup.BackupSizeMb)
assert.Nil(t, updatedBackup.FailMessage)
})
t.Run("SizeExceeded_BackupFailedWithIsSkipRetry", func(t *testing.T) {
// Enable backups with 5 MB limit
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backupConfig.MaxBackupSizeMB = 5
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backuperNode := CreateTestBackuperNode()
backuperNode.createBackupUseCase = &CreateProgressiveBackupUsecase{}
// Create a backup record
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backuperNode.MakeBackup(backup.ID, false)
// Verify backup was marked as failed with IsSkipRetry=true
updatedBackup, err := backupRepository.FindByID(backup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusFailed, updatedBackup.Status)
assert.True(t, updatedBackup.IsSkipRetry)
assert.NotNil(t, updatedBackup.FailMessage)
assert.Contains(t, *updatedBackup.FailMessage, "exceeded maximum allowed size")
assert.Contains(t, *updatedBackup.FailMessage, "10.00 MB")
assert.Contains(t, *updatedBackup.FailMessage, "5 MB")
assert.Greater(t, updatedBackup.BackupSizeMb, float64(5))
})
t.Run("SizeWithinLimit_BackupCompletes", func(t *testing.T) {
// Enable backups with 100 MB limit
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backupConfig.MaxBackupSizeMB = 100
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backuperNode := CreateTestBackuperNode()
backuperNode.createBackupUseCase = &CreateMediumBackupUsecase{}
// Create a backup record
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backuperNode.MakeBackup(backup.ID, false)
// Verify backup completed successfully
updatedBackup, err := backupRepository.FindByID(backup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
assert.Equal(t, float64(50), updatedBackup.BackupSizeMb)
assert.Nil(t, updatedBackup.FailMessage)
})
}

View File

@@ -0,0 +1,242 @@
package backuping
import (
"context"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/storages"
util_encryption "databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/period"
)
const (
cleanerTickerInterval = 1 * time.Minute
)
type BackupCleaner struct {
backupRepository *backups_core.BackupRepository
storageService *storages.StorageService
backupConfigService *backups_config.BackupConfigService
fieldEncryptor util_encryption.FieldEncryptor
logger *slog.Logger
backupRemoveListeners []backups_core.BackupRemoveListener
runOnce sync.Once
hasRun atomic.Bool
}
func (c *BackupCleaner) Run(ctx context.Context) {
wasAlreadyRun := c.hasRun.Load()
c.runOnce.Do(func() {
c.hasRun.Store(true)
if ctx.Err() != nil {
return
}
ticker := time.NewTicker(cleanerTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := c.cleanOldBackups(); err != nil {
c.logger.Error("Failed to clean old backups", "error", err)
}
if err := c.cleanExceededBackups(); err != nil {
c.logger.Error("Failed to clean exceeded backups", "error", err)
}
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", c))
}
}
func (c *BackupCleaner) DeleteBackup(backup *backups_core.Backup) error {
for _, listener := range c.backupRemoveListeners {
if err := listener.OnBeforeBackupRemove(backup); err != nil {
return err
}
}
storage, err := c.storageService.GetStorageByID(backup.StorageID)
if err != nil {
return err
}
err = storage.DeleteFile(c.fieldEncryptor, backup.ID)
if err != nil {
// we do not return error here, because sometimes clean up performed
// before unavailable storage removal or change - therefore we should
// proceed even in case of error. It's possible that some S3 or
// storage is not available yet, it should not block us
c.logger.Error("Failed to delete backup file", "error", err)
}
return c.backupRepository.DeleteByID(backup.ID)
}
func (c *BackupCleaner) AddBackupRemoveListener(listener backups_core.BackupRemoveListener) {
c.backupRemoveListeners = append(c.backupRemoveListeners, listener)
}
func (c *BackupCleaner) cleanOldBackups() error {
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
backupStorePeriod := backupConfig.StorePeriod
if backupStorePeriod == period.PeriodForever {
continue
}
storeDuration := backupStorePeriod.ToDuration()
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
oldBackups, err := c.backupRepository.FindBackupsBeforeDate(
backupConfig.DatabaseID,
dateBeforeBackupsShouldBeDeleted,
)
if err != nil {
c.logger.Error(
"Failed to find old backups for database",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
continue
}
for _, backup := range oldBackups {
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
continue
}
c.logger.Info(
"Deleted old backup",
"backupId",
backup.ID,
"databaseId",
backupConfig.DatabaseID,
)
}
}
return nil
}
func (c *BackupCleaner) cleanExceededBackups() error {
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
if backupConfig.MaxBackupsTotalSizeMB <= 0 {
continue
}
if err := c.cleanExceededBackupsForDatabase(
backupConfig.DatabaseID,
backupConfig.MaxBackupsTotalSizeMB,
); err != nil {
c.logger.Error(
"Failed to clean exceeded backups for database",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
continue
}
}
return nil
}
func (c *BackupCleaner) cleanExceededBackupsForDatabase(
databaseID uuid.UUID,
limitperDbMB int64,
) error {
for {
backupsTotalSizeMB, err := c.backupRepository.GetTotalSizeByDatabase(databaseID)
if err != nil {
return err
}
if backupsTotalSizeMB <= float64(limitperDbMB) {
break
}
oldestBackups, err := c.backupRepository.FindOldestByDatabaseExcludingInProgress(
databaseID,
1,
)
if err != nil {
return err
}
if len(oldestBackups) == 0 {
c.logger.Warn(
"No backups to delete but still over limit",
"databaseId",
databaseID,
"totalSizeMB",
backupsTotalSizeMB,
"limitMB",
limitperDbMB,
)
break
}
backup := oldestBackups[0]
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error(
"Failed to delete exceeded backup",
"backupId",
backup.ID,
"databaseId",
databaseID,
"error",
err,
)
return err
}
c.logger.Info(
"Deleted exceeded backup",
"backupId",
backup.ID,
"databaseId",
databaseID,
"backupSizeMB",
backup.BackupSizeMb,
"totalSizeMB",
backupsTotalSizeMB,
"limitMB",
limitperDbMB,
)
}
return nil
}

View File

@@ -0,0 +1,595 @@
package backuping
import (
"testing"
"time"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/storage"
"databasus-backend/internal/util/period"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_CleanOldBackups_DeletesBackupsOlderThanStorePeriod(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Create backup interval
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Create backups with different ages
now := time.Now().UTC()
oldBackup1 := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
CreatedAt: now.Add(-10 * 24 * time.Hour), // 10 days old
}
oldBackup2 := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
CreatedAt: now.Add(-8 * 24 * time.Hour), // 8 days old
}
recentBackup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
CreatedAt: now.Add(-3 * 24 * time.Hour), // 3 days old
}
err = backupRepository.Save(oldBackup1)
assert.NoError(t, err)
err = backupRepository.Save(oldBackup2)
assert.NoError(t, err)
err = backupRepository.Save(recentBackup)
assert.NoError(t, err)
// Run cleanup
cleaner := GetBackupCleaner()
err = cleaner.cleanOldBackups()
assert.NoError(t, err)
// Verify old backups deleted, recent backup remains
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Equal(t, 1, len(remainingBackups))
assert.Equal(t, recentBackup.ID, remainingBackups[0].ID)
}
func Test_CleanOldBackups_SkipsDatabaseWithForeverStorePeriod(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Create backup interval
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodForever,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Create very old backup
oldBackup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
CreatedAt: time.Now().UTC().Add(-365 * 24 * time.Hour), // 1 year old
}
err = backupRepository.Save(oldBackup)
assert.NoError(t, err)
// Run cleanup
cleaner := GetBackupCleaner()
err = cleaner.cleanOldBackups()
assert.NoError(t, err)
// Verify backup still exists
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Equal(t, 1, len(remainingBackups))
assert.Equal(t, oldBackup.ID, remainingBackups[0].ID)
}
func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Create backup interval
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 100, // 100 MB limit
BackupIntervalID: interval.ID,
BackupInterval: interval,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Create 3 backups totaling 50MB (under limit)
for i := 0; i < 3; i++ {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 16.67,
CreatedAt: time.Now().UTC().Add(-time.Duration(i) * time.Hour),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
}
// Run cleanup
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
assert.NoError(t, err)
// Verify all backups remain
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Equal(t, 3, len(remainingBackups))
}
func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Create backup interval
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 30, // 30 MB limit
BackupIntervalID: interval.ID,
BackupInterval: interval,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Create 5 backups of 10MB each (total 50MB, over 30MB limit)
now := time.Now().UTC()
var backupIDs []uuid.UUID
for i := 0; i < 5; i++ {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
CreatedAt: now.Add(-time.Duration(4-i) * time.Hour), // Oldest first
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backupIDs = append(backupIDs, backup.ID)
}
// Run cleanup
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
assert.NoError(t, err)
// Verify 2 oldest backups deleted, 3 newest remain
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Equal(t, 3, len(remainingBackups))
// Check that the newest 3 backups remain
remainingIDs := make(map[uuid.UUID]bool)
for _, backup := range remainingBackups {
remainingIDs[backup.ID] = true
}
assert.False(t, remainingIDs[backupIDs[0]]) // Oldest deleted
assert.False(t, remainingIDs[backupIDs[1]]) // 2nd oldest deleted
assert.True(t, remainingIDs[backupIDs[2]]) // 3rd remains
assert.True(t, remainingIDs[backupIDs[3]]) // 4th remains
assert.True(t, remainingIDs[backupIDs[4]]) // Newest remains
}
func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Create backup interval
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 50, // 50 MB limit
BackupIntervalID: interval.ID,
BackupInterval: interval,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
now := time.Now().UTC()
// Create 3 completed backups of 30MB each
completedBackups := make([]*backups_core.Backup, 3)
for i := 0; i < 3; i++ {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 30,
CreatedAt: now.Add(-time.Duration(3-i) * time.Hour),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
completedBackups[i] = backup
}
// Create 1 in-progress backup (should be excluded from size calculation and deletion)
inProgressBackup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
BackupSizeMb: 10,
CreatedAt: now,
}
err = backupRepository.Save(inProgressBackup)
assert.NoError(t, err)
// Run cleanup
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
assert.NoError(t, err)
// Verify: only completed backups deleted, in-progress remains
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
// Should have in-progress + 1 completed (total 40MB completed + 10MB in-progress)
assert.GreaterOrEqual(t, len(remainingBackups), 2)
// Verify in-progress backup still exists
var inProgressFound bool
for _, backup := range remainingBackups {
if backup.ID == inProgressBackup.ID {
inProgressFound = true
assert.Equal(t, backups_core.BackupStatusInProgress, backup.Status)
}
}
assert.True(t, inProgressFound, "In-progress backup should not be deleted")
}
func Test_CleanExceededBackups_WithZeroLimit_SkipsDatabase(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Create backup interval
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 0, // No size limit
BackupIntervalID: interval.ID,
BackupInterval: interval,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Create large backups
for i := 0; i < 10; i++ {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 100,
CreatedAt: time.Now().UTC().Add(-time.Duration(i) * time.Hour),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
}
// Run cleanup
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
assert.NoError(t, err)
// Verify all backups remain
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Equal(t, 10, len(remainingBackups))
}
func Test_GetTotalSizeByDatabase_CalculatesCorrectly(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Create completed backups
completedBackup1 := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10.5,
CreatedAt: time.Now().UTC(),
}
completedBackup2 := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 20.3,
CreatedAt: time.Now().UTC(),
}
// Create failed backup (should be included)
failedBackup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusFailed,
BackupSizeMb: 5.2,
CreatedAt: time.Now().UTC(),
}
// Create in-progress backup (should be excluded)
inProgressBackup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
BackupSizeMb: 100,
CreatedAt: time.Now().UTC(),
}
err := backupRepository.Save(completedBackup1)
assert.NoError(t, err)
err = backupRepository.Save(completedBackup2)
assert.NoError(t, err)
err = backupRepository.Save(failedBackup)
assert.NoError(t, err)
err = backupRepository.Save(inProgressBackup)
assert.NoError(t, err)
// Calculate total size
totalSize, err := backupRepository.GetTotalSizeByDatabase(database.ID)
assert.NoError(t, err)
// Should be 10.5 + 20.3 + 5.2 = 36.0 (excluding in-progress 100)
assert.InDelta(t, 36.0, totalSize, 0.1)
}
// Mock listener for testing
type mockBackupRemoveListener struct {
onBeforeBackupRemove func(*backups_core.Backup) error
}
func (m *mockBackupRemoveListener) OnBeforeBackupRemove(backup *backups_core.Backup) error {
if m.onBeforeBackupRemove != nil {
return m.onBeforeBackupRemove(backup)
}
return nil
}
// Test_DeleteBackup_WhenStorageDeleteFails_BackupStillRemovedFromDatabase verifies resilience
// when storage becomes unavailable. Even if storage.DeleteFile fails (e.g., storage is offline,
// credentials changed, or storage was deleted), the backup record should still be removed from
// the database. This prevents orphaned backup records when storage is no longer accessible.
func Test_DeleteBackup_WhenStorageDeleteFails_BackupStillRemovedFromDatabase(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
testStorage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, testStorage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(testStorage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: testStorage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
CreatedAt: time.Now().UTC(),
}
err := backupRepository.Save(backup)
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.DeleteBackup(backup)
assert.NoError(t, err, "DeleteBackup should succeed even when storage file doesn't exist")
deletedBackup, err := backupRepository.FindByID(backup.ID)
assert.Error(t, err, "Backup should not exist in database")
assert.Nil(t, deletedBackup)
}
func createTestInterval() *intervals.Interval {
timeOfDay := "04:00"
interval := &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
err := storage.GetDb().Create(interval).Error
if err != nil {
panic(err)
}
return interval
}

View File

@@ -1,6 +1,12 @@
package backuping
import (
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/usecases"
backups_config "databasus-backend/internal/features/backups/config"
@@ -12,21 +18,31 @@ import (
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
"time"
"github.com/google/uuid"
)
var backupRepository = &backups_core.BackupRepository{}
var taskCancelManager = tasks_cancellation.GetTaskCancelManager()
var backupCleaner = &BackupCleaner{
backupRepository: backupRepository,
storageService: storages.GetStorageService(),
backupConfigService: backups_config.GetBackupConfigService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
logger: logger.GetLogger(),
backupRemoveListeners: []backups_core.BackupRemoveListener{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
var backupNodesRegistry = &BackupNodesRegistry{
cache_utils.GetValkeyClient(),
logger.GetLogger(),
cache_utils.DefaultCacheTimeout,
cache_utils.NewPubSubManager(),
cache_utils.NewPubSubManager(),
client: cache_utils.GetValkeyClient(),
logger: logger.GetLogger(),
timeout: cache_utils.DefaultCacheTimeout,
pubsubBackups: cache_utils.NewPubSubManager(),
pubsubCompletions: cache_utils.NewPubSubManager(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
func getNodeID() uuid.UUID {
@@ -34,31 +50,34 @@ func getNodeID() uuid.UUID {
}
var backuperNode = &BackuperNode{
databases.GetDatabaseService(),
encryption.GetFieldEncryptor(),
workspaces_services.GetWorkspaceService(),
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
notifiers.GetNotifierService(),
taskCancelManager,
backupNodesRegistry,
logger.GetLogger(),
usecases.GetCreateBackupUsecase(),
getNodeID(),
time.Time{},
databaseService: databases.GetDatabaseService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
workspaceService: workspaces_services.GetWorkspaceService(),
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
notificationSender: notifiers.GetNotifierService(),
backupCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
logger: logger.GetLogger(),
createBackupUseCase: usecases.GetCreateBackupUsecase(),
nodeID: getNodeID(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
var backupsScheduler = &BackupsScheduler{
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
taskCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
lastBackupTime: time.Now().UTC(),
logger: logger.GetLogger(),
backupToNodeRelations: make(map[uuid.UUID]BackupToNodeRelation),
backuperNode: backuperNode,
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
func GetBackupsScheduler() *BackupsScheduler {
@@ -72,3 +91,7 @@ func GetBackuperNode() *BackuperNode {
func GetBackupNodesRegistry() *BackupNodesRegistry {
return backupNodesRegistry
}
func GetBackupCleaner() *BackupCleaner {
return backupCleaner
}

View File

@@ -1,8 +1,18 @@
package backuping
import (
"databasus-backend/internal/features/notifiers"
"context"
"errors"
"sync/atomic"
"time"
common "databasus-backend/internal/features/backups/backups/common"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
"github.com/google/uuid"
"github.com/stretchr/testify/mock"
)
@@ -17,3 +27,168 @@ func (m *MockNotificationSender) SendNotification(
) {
m.Called(notifier, title, message)
}
type CreateFailedBackupUsecase struct{}
func (uc *CreateFailedBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
backupProgressListener(10)
return nil, errors.New("backup failed")
}
type CreateSuccessBackupUsecase struct{}
func (uc *CreateSuccessBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
backupProgressListener(10)
return &common.BackupMetadata{
EncryptionSalt: nil,
EncryptionIV: nil,
Encryption: backups_config.BackupEncryptionNone,
}, nil
}
// CreateLargeBackupUsecase simulates a large backup (10000 MB)
type CreateLargeBackupUsecase struct{}
func (uc *CreateLargeBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
backupProgressListener(10000)
return &common.BackupMetadata{
EncryptionSalt: nil,
EncryptionIV: nil,
Encryption: backups_config.BackupEncryptionNone,
}, nil
}
// CreateProgressiveBackupUsecase simulates progressive size updates that exceed limit
type CreateProgressiveBackupUsecase struct{}
func (uc *CreateProgressiveBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
// Simulate progressive backup that grows beyond limit
backupProgressListener(1)
if ctx.Err() != nil {
return nil, ctx.Err()
}
backupProgressListener(3)
if ctx.Err() != nil {
return nil, ctx.Err()
}
backupProgressListener(5)
if ctx.Err() != nil {
return nil, ctx.Err()
}
backupProgressListener(10) // This exceeds the 5 MB limit
if ctx.Err() != nil {
return nil, ctx.Err()
}
// Should not reach here due to cancellation
return &common.BackupMetadata{
EncryptionSalt: nil,
EncryptionIV: nil,
Encryption: backups_config.BackupEncryptionNone,
}, nil
}
// CreateMediumBackupUsecase simulates a 50 MB backup
type CreateMediumBackupUsecase struct{}
func (uc *CreateMediumBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
backupProgressListener(50)
return &common.BackupMetadata{
EncryptionSalt: nil,
EncryptionIV: nil,
Encryption: backups_config.BackupEncryptionNone,
}, nil
}
// MockTrackingBackupUsecase tracks backup use case calls for testing parallel execution
type MockTrackingBackupUsecase struct {
callCount atomic.Int32
calledBackupIDs chan uuid.UUID
}
func NewMockTrackingBackupUsecase() *MockTrackingBackupUsecase {
return &MockTrackingBackupUsecase{
calledBackupIDs: make(chan uuid.UUID, 10),
}
}
func (m *MockTrackingBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
m.callCount.Add(1)
// Send backup ID to channel (non-blocking)
select {
case m.calledBackupIDs <- backupID:
default:
}
// Simulate backup work
time.Sleep(100 * time.Millisecond)
backupProgressListener(10)
return &common.BackupMetadata{
EncryptionSalt: nil,
EncryptionIV: nil,
Encryption: backups_config.BackupEncryptionNone,
}, nil
}
func (m *MockTrackingBackupUsecase) GetCallCount() int32 {
return m.callCount.Load()
}
func (m *MockTrackingBackupUsecase) GetCalledBackupIDs() []uuid.UUID {
ids := []uuid.UUID{}
for {
select {
case id := <-m.calledBackupIDs:
ids = append(ids, id)
default:
return ids
}
}
}

View File

@@ -6,6 +6,8 @@ import (
"fmt"
"log/slog"
"strings"
"sync"
"sync/atomic"
"time"
cache_utils "databasus-backend/internal/util/cache"
@@ -47,24 +49,37 @@ type BackupNodesRegistry struct {
timeout time.Duration
pubsubBackups *cache_utils.PubSubManager
pubsubCompletions *cache_utils.PubSubManager
runOnce sync.Once
hasRun atomic.Bool
}
func (r *BackupNodesRegistry) Run(ctx context.Context) {
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
}
wasAlreadyRun := r.hasRun.Load()
ticker := time.NewTicker(cleanupTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes", "error", err)
r.runOnce.Do(func() {
r.hasRun.Store(true)
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
}
ticker := time.NewTicker(cleanupTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes", "error", err)
}
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", r))
}
}

View File

@@ -4,6 +4,8 @@ import (
"context"
"encoding/json"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
@@ -594,11 +596,13 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) {
func createTestRegistry() *BackupNodesRegistry {
return &BackupNodesRegistry{
cache_utils.GetValkeyClient(),
logger.GetLogger(),
cache_utils.DefaultCacheTimeout,
cache_utils.NewPubSubManager(),
cache_utils.NewPubSubManager(),
client: cache_utils.GetValkeyClient(),
logger: logger.GetLogger(),
timeout: cache_utils.DefaultCacheTimeout,
pubsubBackups: cache_utils.NewPubSubManager(),
pubsubCompletions: cache_utils.NewPubSubManager(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}

View File

@@ -2,18 +2,18 @@ package backuping
import (
"context"
"databasus-backend/internal/config"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/storages"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/period"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"databasus-backend/internal/config"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
)
const (
@@ -25,7 +25,6 @@ const (
type BackupsScheduler struct {
backupRepository *backups_core.BackupRepository
backupConfigService *backups_config.BackupConfigService
storageService *storages.StorageService
taskCancelManager *task_cancellation.TaskCancelManager
backupNodesRegistry *BackupNodesRegistry
@@ -34,59 +33,68 @@ type BackupsScheduler struct {
backupToNodeRelations map[uuid.UUID]BackupToNodeRelation
backuperNode *BackuperNode
runOnce sync.Once
hasRun atomic.Bool
}
func (s *BackupsScheduler) Run(ctx context.Context) {
s.lastBackupTime = time.Now().UTC()
wasAlreadyRun := s.hasRun.Load()
if config.GetEnv().IsManyNodesMode {
// wait other nodes to start
time.Sleep(schedulerStartupDelay)
}
s.runOnce.Do(func() {
s.hasRun.Store(true)
if err := s.failBackupsInProgress(); err != nil {
s.logger.Error("Failed to fail backups in progress", "error", err)
panic(err)
}
s.lastBackupTime = time.Now().UTC()
err := s.backupNodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted)
if err != nil {
s.logger.Error("Failed to subscribe to backup completions", "error", err)
panic(err)
}
defer func() {
if err := s.backupNodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
if config.GetEnv().IsManyNodesMode {
// wait other nodes to start
time.Sleep(schedulerStartupDelay)
}
}()
if ctx.Err() != nil {
return
}
if err := s.failBackupsInProgress(); err != nil {
s.logger.Error("Failed to fail backups in progress", "error", err)
panic(err)
}
ticker := time.NewTicker(schedulerTickerInterval)
defer ticker.Stop()
err := s.backupNodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted)
if err != nil {
s.logger.Error("Failed to subscribe to backup completions", "error", err)
panic(err)
}
for {
select {
case <-ctx.Done():
defer func() {
if err := s.backupNodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
}
}()
if ctx.Err() != nil {
return
case <-ticker.C:
if err := s.cleanOldBackups(); err != nil {
s.logger.Error("Failed to clean old backups", "error", err)
}
if err := s.checkDeadNodesAndFailBackups(); err != nil {
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
}
if err := s.runPendingBackups(); err != nil {
s.logger.Error("Failed to run pending backups", "error", err)
}
s.lastBackupTime = time.Now().UTC()
}
ticker := time.NewTicker(schedulerTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.checkDeadNodesAndFailBackups(); err != nil {
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
}
if err := s.runPendingBackups(); err != nil {
s.logger.Error("Failed to run pending backups", "error", err)
}
s.lastBackupTime = time.Now().UTC()
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
}
@@ -95,49 +103,6 @@ func (s *BackupsScheduler) IsSchedulerRunning() bool {
return s.lastBackupTime.After(time.Now().UTC().Add(-schedulerHealthcheckThreshold))
}
func (s *BackupsScheduler) failBackupsInProgress() error {
backupsInProgress, err := s.backupRepository.FindByStatus(backups_core.BackupStatusInProgress)
if err != nil {
return err
}
for _, backup := range backupsInProgress {
if err := s.taskCancelManager.CancelTask(backup.ID); err != nil {
s.logger.Error(
"Failed to cancel backup via task cancel manager",
"backupId",
backup.ID,
"error",
err,
)
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(backup.DatabaseID)
if err != nil {
s.logger.Error("Failed to get backup config by database ID", "error", err)
continue
}
failMessage := "Backup failed due to application restart"
backup.FailMessage = &failMessage
backup.Status = backups_core.BackupStatusFailed
backup.BackupSizeMb = 0
s.backuperNode.SendBackupNotification(
backupConfig,
backup,
backups_config.NotificationBackupFailed,
&failMessage,
)
if err := s.backupRepository.Save(backup); err != nil {
return err
}
}
return nil
}
func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool) {
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(databaseID)
if err != nil {
@@ -150,6 +115,33 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
return
}
// Check for existing in-progress backups
inProgressBackups, err := s.backupRepository.FindByDatabaseIdAndStatus(
databaseID,
backups_core.BackupStatusInProgress,
)
if err != nil {
s.logger.Error(
"Failed to check for in-progress backups",
"databaseId",
databaseID,
"error",
err,
)
return
}
if len(inProgressBackups) > 0 {
s.logger.Warn(
"Backup already in progress for database, skipping new backup",
"databaseId",
databaseID,
"existingBackupId",
inProgressBackups[0].ID,
)
return
}
leastBusyNodeID, err := s.calculateLeastBusyNode()
if err != nil {
s.logger.Error(
@@ -162,6 +154,7 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
return
}
fmt.Println("make backup")
backup := &backups_core.Backup{
DatabaseID: backupConfig.DatabaseID,
StorageID: *backupConfig.StorageID,
@@ -250,6 +243,10 @@ func (s *BackupsScheduler) GetRemainedBackupTryCount(lastBackup *backups_core.Ba
return 0
}
if lastBackup.IsSkipRetry {
return 0
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(lastBackup.DatabaseID)
if err != nil {
s.logger.Error("Failed to get backup config by database ID", "error", err)
@@ -282,74 +279,6 @@ func (s *BackupsScheduler) GetRemainedBackupTryCount(lastBackup *backups_core.Ba
return maxFailedTriesCount - len(lastFailedBackups)
}
func (s *BackupsScheduler) cleanOldBackups() error {
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
backupStorePeriod := backupConfig.StorePeriod
if backupStorePeriod == period.PeriodForever {
continue
}
storeDuration := backupStorePeriod.ToDuration()
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
oldBackups, err := s.backupRepository.FindBackupsBeforeDate(
backupConfig.DatabaseID,
dateBeforeBackupsShouldBeDeleted,
)
if err != nil {
s.logger.Error(
"Failed to find old backups for database",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
continue
}
for _, backup := range oldBackups {
storage, err := s.storageService.GetStorageByID(backup.StorageID)
if err != nil {
s.logger.Error(
"Failed to get storage by ID",
"storageId",
backup.StorageID,
"error",
err,
)
continue
}
encryptor := encryption.GetFieldEncryptor()
err = storage.DeleteFile(encryptor, backup.ID)
if err != nil {
s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
}
if err := s.backupRepository.DeleteByID(backup.ID); err != nil {
s.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
continue
}
s.logger.Info(
"Deleted old backup",
"backupId",
backup.ID,
"databaseId",
backupConfig.DatabaseID,
)
}
}
return nil
}
func (s *BackupsScheduler) runPendingBackups() error {
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
@@ -398,6 +327,49 @@ func (s *BackupsScheduler) runPendingBackups() error {
return nil
}
func (s *BackupsScheduler) failBackupsInProgress() error {
backupsInProgress, err := s.backupRepository.FindByStatus(backups_core.BackupStatusInProgress)
if err != nil {
return err
}
for _, backup := range backupsInProgress {
if err := s.taskCancelManager.CancelTask(backup.ID); err != nil {
s.logger.Error(
"Failed to cancel backup via task cancel manager",
"backupId",
backup.ID,
"error",
err,
)
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(backup.DatabaseID)
if err != nil {
s.logger.Error("Failed to get backup config by database ID", "error", err)
continue
}
failMessage := "Backup failed due to application restart"
backup.FailMessage = &failMessage
backup.Status = backups_core.BackupStatusFailed
backup.BackupSizeMb = 0
s.backuperNode.SendBackupNotification(
backupConfig,
backup,
backups_config.NotificationBackupFailed,
&failMessage,
)
if err := s.backupRepository.Save(backup); err != nil {
return err
}
}
return nil
}
func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
nodes, err := s.backupNodesRegistry.GetAvailableNodes()
if err != nil {

View File

@@ -835,7 +835,8 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T
cache_utils.ClearAllCache()
// Start scheduler so it can handle task completions
schedulerCancel := StartSchedulerForTest(t)
scheduler := CreateTestScheduler()
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
backuperNode := CreateTestBackuperNode()
@@ -891,7 +892,7 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T
t.Logf("Initial active tasks: %d", initialActiveTasks)
// Start backup
GetBackupsScheduler().StartBackup(database.ID, false)
scheduler.StartBackup(database.ID, false)
// Wait for backup to complete
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
@@ -930,7 +931,8 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
cache_utils.ClearAllCache()
// Start scheduler so it can handle task completions
schedulerCancel := StartSchedulerForTest(t)
scheduler := CreateTestScheduler()
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
backuperNode := CreateTestBackuperNode()
@@ -993,7 +995,7 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
t.Logf("Initial active tasks: %d", initialActiveTasks)
// Start backup
GetBackupsScheduler().StartBackup(database.ID, false)
scheduler.StartBackup(database.ID, false)
// Wait for backup to fail
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
@@ -1031,3 +1033,289 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
time.Sleep(200 * time.Millisecond)
}
func Test_StartBackup_WhenBackupAlreadyInProgress_SkipsNewBackup(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Create an in-progress backup manually
inProgressBackup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
BackupSizeMb: 0,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(inProgressBackup)
assert.NoError(t, err)
// Try to start a new backup - should be skipped
GetBackupsScheduler().StartBackup(database.ID, false)
time.Sleep(200 * time.Millisecond)
// Verify only 1 backup exists (the original in-progress one)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status)
assert.Equal(t, inProgressBackup.ID, backups[0].ID)
time.Sleep(200 * time.Millisecond)
}
func Test_RunPendingBackups_WhenLastBackupFailedWithIsSkipRetry_SkipsBackupEvenWithRetriesEnabled(
t *testing.T,
) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups with retries enabled and high retry count
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
backupConfig.IsRetryIfFailed = true
backupConfig.MaxFailedTriesCount = 5
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Create a failed backup with IsSkipRetry set to true
failMessage := "backup failed due to size limit exceeded"
backupRepository.Save(&backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusFailed,
FailMessage: &failMessage,
IsSkipRetry: true,
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
})
// Verify GetRemainedBackupTryCount returns 0 even though retries are enabled
lastBackup, err := backupRepository.FindLastByDatabaseID(database.ID)
assert.NoError(t, err)
assert.NotNil(t, lastBackup)
remainedTries := GetBackupsScheduler().GetRemainedBackupTryCount(lastBackup)
assert.Equal(t, 0, remainedTries, "Should return 0 tries when IsSkipRetry is true")
// Run the scheduler
GetBackupsScheduler().runPendingBackups()
time.Sleep(100 * time.Millisecond)
// Verify no new backup was created (still only 1 backup exists)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1, "No retry should be attempted when IsSkipRetry is true")
time.Sleep(200 * time.Millisecond)
}
func Test_StartBackup_When2BackupsStartedForDifferentDatabases_BothUseCasesAreCalled(t *testing.T) {
cache_utils.ClearAllCache()
// Create mock tracking use case
mockUseCase := NewMockTrackingBackupUsecase()
// Create BackuperNode with mock use case
backuperNode := CreateTestBackuperNodeWithUseCase(mockUseCase)
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
// Create scheduler
scheduler := CreateTestScheduler()
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
// Setup test data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
// Create 2 separate databases
database1 := databases.CreateTestDatabase(workspace.ID, storage, notifier)
database2 := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// Cleanup backups for database1
backups1, _ := backupRepository.FindByDatabaseID(database1.ID)
for _, backup := range backups1 {
backupRepository.DeleteByID(backup.ID)
}
// Cleanup backups for database2
backups2, _ := backupRepository.FindByDatabaseID(database2.ID)
for _, backup := range backups2 {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database1)
databases.RemoveTestDatabase(database2)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups for database1
backupConfig1, err := backups_config.GetBackupConfigService().
GetBackupConfigByDbId(database1.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig1.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig1.IsBackupsEnabled = true
backupConfig1.StorePeriod = period.PeriodWeek
backupConfig1.Storage = storage
backupConfig1.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig1)
assert.NoError(t, err)
// Enable backups for database2
backupConfig2, err := backups_config.GetBackupConfigService().
GetBackupConfigByDbId(database2.ID)
assert.NoError(t, err)
backupConfig2.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig2.IsBackupsEnabled = true
backupConfig2.StorePeriod = period.PeriodWeek
backupConfig2.Storage = storage
backupConfig2.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig2)
assert.NoError(t, err)
// Start 2 backups simultaneously
t.Log("Starting backup for database1")
scheduler.StartBackup(database1.ID, false)
t.Log("Starting backup for database2")
scheduler.StartBackup(database2.ID, false)
// Wait up to 10 seconds for both backups to complete
t.Log("Waiting for both backups to complete...")
success := assert.Eventually(t, func() bool {
callCount := mockUseCase.GetCallCount()
t.Logf("Current call count: %d/2", callCount)
return callCount == 2
}, 10*time.Second, 200*time.Millisecond, "Both use cases should be called within 10 seconds")
if !success {
t.Logf("Test failed: Only %d out of 2 use cases were called", mockUseCase.GetCallCount())
}
// Verify both backup IDs were received
calledBackupIDs := mockUseCase.GetCalledBackupIDs()
t.Logf("Called backup IDs: %v", calledBackupIDs)
assert.Len(t, calledBackupIDs, 2, "Both backup IDs should be tracked")
// Verify both backups exist in repository and are completed
backups1, err := backupRepository.FindByDatabaseID(database1.ID)
assert.NoError(t, err)
assert.Len(t, backups1, 1, "Database1 should have 1 backup")
if len(backups1) > 0 {
t.Logf("Database1 backup status: %s", backups1[0].Status)
}
backups2, err := backupRepository.FindByDatabaseID(database2.ID)
assert.NoError(t, err)
assert.Len(t, backups2, 1, "Database2 should have 1 backup")
if len(backups2) > 0 {
t.Logf("Database2 backup status: %s", backups2[0].Status)
}
// Verify both backups completed successfully
if len(backups1) > 0 {
assert.Equal(t, backups_core.BackupStatusCompleted, backups1[0].Status,
"Database1 backup should be completed")
}
if len(backups2) > 0 {
assert.Equal(t, backups_core.BackupStatusCompleted, backups2[0].Status,
"Database2 backup should be completed")
}
time.Sleep(200 * time.Millisecond)
}

View File

@@ -3,6 +3,8 @@ package backuping
import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
@@ -35,19 +37,56 @@ func CreateTestRouter() *gin.Engine {
func CreateTestBackuperNode() *BackuperNode {
return &BackuperNode{
databases.GetDatabaseService(),
encryption.GetFieldEncryptor(),
workspaces_services.GetWorkspaceService(),
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
notifiers.GetNotifierService(),
taskCancelManager,
backupNodesRegistry,
logger.GetLogger(),
usecases.GetCreateBackupUsecase(),
uuid.New(),
time.Time{},
databaseService: databases.GetDatabaseService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
workspaceService: workspaces_services.GetWorkspaceService(),
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
notificationSender: notifiers.GetNotifierService(),
backupCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
logger: logger.GetLogger(),
createBackupUseCase: usecases.GetCreateBackupUsecase(),
nodeID: uuid.New(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}
func CreateTestBackuperNodeWithUseCase(useCase backups_core.CreateBackupUsecase) *BackuperNode {
return &BackuperNode{
databaseService: databases.GetDatabaseService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
workspaceService: workspaces_services.GetWorkspaceService(),
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
notificationSender: notifiers.GetNotifierService(),
backupCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
logger: logger.GetLogger(),
createBackupUseCase: useCase,
nodeID: uuid.New(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}
func CreateTestScheduler() *BackupsScheduler {
return &BackupsScheduler{
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
taskCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
lastBackupTime: time.Now().UTC(),
logger: logger.GetLogger(),
backupToNodeRelations: make(map[uuid.UUID]BackupToNodeRelation),
backuperNode: CreateTestBackuperNode(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}
@@ -141,13 +180,13 @@ func StartBackuperNodeForTest(t *testing.T, backuperNode *BackuperNode) context.
// StartSchedulerForTest starts the BackupsScheduler in a goroutine for testing.
// The scheduler subscribes to task completions and manages backup lifecycle.
// Returns a context cancel function that should be deferred to stop the scheduler.
func StartSchedulerForTest(t *testing.T) context.CancelFunc {
func StartSchedulerForTest(t *testing.T, scheduler *BackupsScheduler) context.CancelFunc {
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
GetBackupsScheduler().Run(ctx)
scheduler.Run(ctx)
close(done)
}()

View File

@@ -80,7 +80,7 @@ func Test_GetBackups_PermissionsEnforced(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, _ := createTestDatabaseWithBackups(workspace, owner, router)
database, _, storage := createTestDatabaseWithBackups(workspace, owner, router)
var testUserToken string
if tt.isGlobalAdmin {
@@ -122,6 +122,12 @@ func Test_GetBackups_PermissionsEnforced(t *testing.T) {
} else {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -218,6 +224,10 @@ func Test_CreateBackup_PermissionsEnforced(t *testing.T) {
} else {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
// Cleanup
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -261,6 +271,10 @@ func Test_CreateBackup_AuditLogWritten(t *testing.T) {
}
}
assert.True(t, found, "Audit log for backup creation not found")
// Cleanup
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_DeleteBackup_PermissionsEnforced(t *testing.T) {
@@ -314,7 +328,7 @@ func Test_DeleteBackup_PermissionsEnforced(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
var testUserToken string
if tt.isGlobalAdmin {
@@ -358,6 +372,12 @@ func Test_DeleteBackup_PermissionsEnforced(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, 0, len(response.Backups))
}
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -367,7 +387,7 @@ func Test_DeleteBackup_AuditLogWritten(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
test_utils.MakeDeleteRequest(
t,
@@ -398,6 +418,12 @@ func Test_DeleteBackup_AuditLogWritten(t *testing.T) {
}
}
assert.True(t, found, "Audit log for backup deletion not found")
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_GenerateDownloadToken_PermissionsEnforced(t *testing.T) {
@@ -444,7 +470,7 @@ func Test_GenerateDownloadToken_PermissionsEnforced(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
var testUserToken string
if tt.isGlobalAdmin {
@@ -488,6 +514,12 @@ func Test_GenerateDownloadToken_PermissionsEnforced(t *testing.T) {
} else {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -497,7 +529,7 @@ func Test_DownloadBackup_WithValidToken_Success(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
// Generate download token
var tokenResponse backups_download.GenerateDownloadTokenResponse
@@ -524,6 +556,12 @@ func Test_DownloadBackup_WithValidToken_Success(t *testing.T) {
contentDisposition := testResp.Headers.Get("Content-Disposition")
assert.Contains(t, contentDisposition, "attachment")
assert.Contains(t, contentDisposition, tokenResponse.Filename)
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_DownloadBackup_WithoutToken_Unauthorized(t *testing.T) {
@@ -531,7 +569,7 @@ func Test_DownloadBackup_WithoutToken_Unauthorized(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
// Try to download without token
testResp := test_utils.MakeGetRequest(
@@ -543,6 +581,12 @@ func Test_DownloadBackup_WithoutToken_Unauthorized(t *testing.T) {
)
assert.Contains(t, string(testResp.Body), "download token is required")
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_DownloadBackup_WithInvalidToken_Unauthorized(t *testing.T) {
@@ -550,7 +594,7 @@ func Test_DownloadBackup_WithInvalidToken_Unauthorized(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
// Try to download with invalid token
testResp := test_utils.MakeGetRequest(
@@ -562,6 +606,12 @@ func Test_DownloadBackup_WithInvalidToken_Unauthorized(t *testing.T) {
)
assert.Contains(t, string(testResp.Body), "invalid or expired download token")
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_DownloadBackup_WithExpiredToken_Unauthorized(t *testing.T) {
@@ -569,7 +619,7 @@ func Test_DownloadBackup_WithExpiredToken_Unauthorized(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
// Get user for token generation
userService := users_services.GetUserService()
@@ -611,6 +661,12 @@ func Test_DownloadBackup_WithExpiredToken_Unauthorized(t *testing.T) {
}
}
assert.False(t, found, "Audit log should NOT be created for failed download with expired token")
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_DownloadBackup_TokenUsedOnce_CannotReuseToken(t *testing.T) {
@@ -618,7 +674,7 @@ func Test_DownloadBackup_TokenUsedOnce_CannotReuseToken(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
// Generate download token
var tokenResponse backups_download.GenerateDownloadTokenResponse
@@ -651,6 +707,12 @@ func Test_DownloadBackup_TokenUsedOnce_CannotReuseToken(t *testing.T) {
)
assert.Contains(t, string(testResp.Body), "invalid or expired download token")
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_DownloadBackup_WithDifferentBackupToken_Unauthorized(t *testing.T) {
@@ -705,6 +767,13 @@ func Test_DownloadBackup_WithDifferentBackupToken_Unauthorized(t *testing.T) {
)
assert.Contains(t, string(testResp.Body), "invalid or expired download token")
// Cleanup
databases.RemoveTestDatabase(database1)
databases.RemoveTestDatabase(database2)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_DownloadBackup_AuditLogWritten(t *testing.T) {
@@ -712,7 +781,7 @@ func Test_DownloadBackup_AuditLogWritten(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
// Generate download token
var tokenResponse backups_download.GenerateDownloadTokenResponse
@@ -756,6 +825,12 @@ func Test_DownloadBackup_AuditLogWritten(t *testing.T) {
}
}
assert.True(t, found, "Audit log for backup download not found")
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_DownloadBackup_ProperFilenameForPostgreSQL(t *testing.T) {
@@ -856,6 +931,12 @@ func Test_DownloadBackup_ProperFilenameForPostgreSQL(t *testing.T) {
contentDisposition,
"Filename should contain timestamp",
)
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -948,6 +1029,12 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
}
}
assert.True(t, foundCancelLog, "Cancel audit log should be created")
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_ConcurrentDownloadPrevention(t *testing.T) {
@@ -955,7 +1042,7 @@ func Test_ConcurrentDownloadPrevention(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
var token1Response backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
@@ -1003,6 +1090,12 @@ func Test_ConcurrentDownloadPrevention(t *testing.T) {
if !service.IsDownloadInProgress(owner.UserID) {
t.Log("Warning: First download completed before we could test concurrency")
<-downloadComplete
// Cleanup before early return
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
return
}
@@ -1049,6 +1142,12 @@ func Test_ConcurrentDownloadPrevention(t *testing.T) {
t.Log(
"Successfully prevented concurrent downloads and allowed subsequent downloads after completion",
)
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_GenerateDownloadToken_BlockedWhenDownloadInProgress(t *testing.T) {
@@ -1056,7 +1155,7 @@ func Test_GenerateDownloadToken_BlockedWhenDownloadInProgress(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
var token1Response backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
@@ -1092,6 +1191,12 @@ func Test_GenerateDownloadToken_BlockedWhenDownloadInProgress(t *testing.T) {
if !service.IsDownloadInProgress(owner.UserID) {
t.Log("Warning: First download completed before we could test token generation blocking")
<-downloadComplete
// Cleanup before early return
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
return
}
@@ -1131,6 +1236,12 @@ func Test_GenerateDownloadToken_BlockedWhenDownloadInProgress(t *testing.T) {
t.Log(
"Successfully blocked token generation during download and allowed generation after completion",
)
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func createTestRouter() *gin.Engine {
@@ -1156,7 +1267,7 @@ func createTestDatabase(
Type: databases.DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",
@@ -1222,7 +1333,7 @@ func createTestDatabaseWithBackups(
workspace *workspaces_models.Workspace,
owner *users_dto.SignInResponseDTO,
router *gin.Engine,
) (*databases.Database, *backups_core.Backup) {
) (*databases.Database, *backups_core.Backup, *storages.Storage) {
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
@@ -1242,7 +1353,7 @@ func createTestDatabaseWithBackups(
backup := createTestBackup(database, owner)
return database, backup
return database, backup, storage
}
func createTestBackup(
@@ -1320,7 +1431,7 @@ func Test_BandwidthThrottling_SingleDownload_Uses75Percent(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
bandwidthManager := backups_download.GetBandwidthManager()
initialCount := bandwidthManager.GetActiveDownloadCount()
@@ -1370,6 +1481,12 @@ func Test_BandwidthThrottling_SingleDownload_Uses75Percent(t *testing.T) {
time.Sleep(50 * time.Millisecond)
finalCount := bandwidthManager.GetActiveDownloadCount()
assert.Equal(t, initialCount, finalCount, "Download should be unregistered after completion")
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_BandwidthThrottling_MultipleDownloads_ShareBandwidth(t *testing.T) {
@@ -1489,6 +1606,12 @@ func Test_BandwidthThrottling_MultipleDownloads_ShareBandwidth(t *testing.T) {
time.Sleep(100 * time.Millisecond)
finalCount := bandwidthManager.GetActiveDownloadCount()
assert.Equal(t, initialCount, finalCount, "All downloads should be unregistered")
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_BandwidthThrottling_DynamicAdjustment(t *testing.T) {
@@ -1577,4 +1700,10 @@ func Test_BandwidthThrottling_DynamicAdjustment(t *testing.T) {
time.Sleep(100 * time.Millisecond)
finalCount := bandwidthManager.GetActiveDownloadCount()
assert.Equal(t, initialCount, finalCount, "All downloads completed and unregistered")
// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}

View File

@@ -15,6 +15,7 @@ type Backup struct {
Status BackupStatus `json:"status" gorm:"column:status;not null"`
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`
IsSkipRetry bool `json:"isSkipRetry" gorm:"column:is_skip_retry;type:boolean;not null"`
BackupSizeMb float64 `json:"backupSizeMb" gorm:"column:backup_size_mb;default:0"`

View File

@@ -212,3 +212,36 @@ func (r *BackupRepository) CountByDatabaseID(databaseID uuid.UUID) (int64, error
return count, nil
}
func (r *BackupRepository) GetTotalSizeByDatabase(databaseID uuid.UUID) (float64, error) {
var totalSize float64
if err := storage.
GetDb().
Model(&Backup{}).
Select("COALESCE(SUM(backup_size_mb), 0)").
Where("database_id = ? AND status != ?", databaseID, BackupStatusInProgress).
Scan(&totalSize).Error; err != nil {
return 0, err
}
return totalSize, nil
}
func (r *BackupRepository) FindOldestByDatabaseExcludingInProgress(
databaseID uuid.UUID,
limit int,
) ([]*Backup, error) {
var backups []*Backup
if err := storage.
GetDb().
Where("database_id = ? AND status != ?", databaseID, BackupStatusInProgress).
Order("created_at ASC").
Limit(limit).
Find(&backups).Error; err != nil {
return nil, err
}
return backups, nil
}

View File

@@ -1,6 +1,9 @@
package backups
import (
"sync"
"sync/atomic"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/backuping"
backups_core "databasus-backend/internal/features/backups/backups/core"
@@ -22,22 +25,23 @@ var backupRepository = &backups_core.BackupRepository{}
var taskCancelManager = task_cancellation.GetTaskCancelManager()
var backupService = &BackupService{
databaseService: databases.GetDatabaseService(),
storageService: storages.GetStorageService(),
backupRepository: backupRepository,
notifierService: notifiers.GetNotifierService(),
notificationSender: notifiers.GetNotifierService(),
backupConfigService: backups_config.GetBackupConfigService(),
secretKeyService: encryption_secrets.GetSecretKeyService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
createBackupUseCase: usecases.GetCreateBackupUsecase(),
logger: logger.GetLogger(),
backupRemoveListeners: []backups_core.BackupRemoveListener{},
workspaceService: workspaces_services.GetWorkspaceService(),
auditLogService: audit_logs.GetAuditLogService(),
taskCancelManager: taskCancelManager,
downloadTokenService: backups_download.GetDownloadTokenService(),
backupSchedulerService: backuping.GetBackupsScheduler(),
databases.GetDatabaseService(),
storages.GetStorageService(),
backupRepository,
notifiers.GetNotifierService(),
notifiers.GetNotifierService(),
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
usecases.GetCreateBackupUsecase(),
logger.GetLogger(),
[]backups_core.BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
taskCancelManager,
backups_download.GetDownloadTokenService(),
backuping.GetBackupsScheduler(),
backuping.GetBackupCleaner(),
}
var backupController = &BackupController{
@@ -52,11 +56,26 @@ func GetBackupController() *BackupController {
return backupController
}
func SetupDependencies() {
backups_config.
GetBackupConfigService().
SetDatabaseStorageChangeListener(backupService)
var (
setupOnce sync.Once
isSetup atomic.Bool
)
databases.GetDatabaseService().AddDbRemoveListener(backupService)
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
func SetupDependencies() {
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
backups_config.
GetBackupConfigService().
SetDatabaseStorageChangeListener(backupService)
databases.GetDatabaseService().AddDbRemoveListener(backupService)
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}

View File

@@ -2,33 +2,49 @@ package backups_download
import (
"context"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
)
type DownloadTokenBackgroundService struct {
downloadTokenService *DownloadTokenService
logger *slog.Logger
runOnce sync.Once
hasRun atomic.Bool
}
func (s *DownloadTokenBackgroundService) Run(ctx context.Context) {
s.logger.Info("Starting download token cleanup background service")
wasAlreadyRun := s.hasRun.Load()
if ctx.Err() != nil {
return
}
s.runOnce.Do(func() {
s.hasRun.Store(true)
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
s.logger.Info("Starting download token cleanup background service")
for {
select {
case <-ctx.Done():
if ctx.Err() != nil {
return
case <-ticker.C:
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
s.logger.Error("Failed to clean expired download tokens", "error", err)
}
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
s.logger.Error("Failed to clean expired download tokens", "error", err)
}
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
}

View File

@@ -1,6 +1,9 @@
package backups_download
import (
"sync"
"sync/atomic"
"databasus-backend/internal/config"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/logger"
@@ -30,8 +33,10 @@ func init() {
}
downloadTokenBackgroundService = &DownloadTokenBackgroundService{
downloadTokenService,
logger.GetLogger(),
downloadTokenService: downloadTokenService,
logger: logger.GetLogger(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}

View File

@@ -46,6 +46,7 @@ type BackupService struct {
taskCancelManager *task_cancellation.TaskCancelManager
downloadTokenService *backups_download.DownloadTokenService
backupSchedulerService *backuping.BackupsScheduler
backupCleaner *backuping.BackupCleaner
}
func (s *BackupService) AddBackupRemoveListener(listener backups_core.BackupRemoveListener) {
@@ -189,7 +190,7 @@ func (s *BackupService) DeleteBackup(
database.WorkspaceID,
)
return s.deleteBackup(backup)
return s.backupCleaner.DeleteBackup(backup)
}
func (s *BackupService) GetBackup(backupID uuid.UUID) (*backups_core.Backup, error) {
@@ -292,29 +293,6 @@ func (s *BackupService) GetBackupFile(
return reader, backup, database, nil
}
func (s *BackupService) deleteBackup(backup *backups_core.Backup) error {
for _, listener := range s.backupRemoveListeners {
if err := listener.OnBeforeBackupRemove(backup); err != nil {
return err
}
}
storage, err := s.storageService.GetStorageByID(backup.StorageID)
if err != nil {
return err
}
err = storage.DeleteFile(s.fieldEncryptor, backup.ID)
if err != nil {
// we do not return error here, because sometimes clean up performed
// before unavailable storage removal or change - therefore we should
// proceed even in case of error
s.logger.Error("Failed to delete backup file", "error", err)
}
return s.backupRepository.DeleteByID(backup.ID)
}
func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
dbBackupsInProgress, err := s.backupRepository.FindByDatabaseIdAndStatus(
databaseID,
@@ -336,7 +314,7 @@ func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
}
for _, dbBackup := range dbBackups {
err := s.deleteBackup(dbBackup)
err := s.backupCleaner.DeleteBackup(dbBackup)
if err != nil {
return err
}

View File

@@ -16,6 +16,7 @@ type BackupConfigController struct {
func (c *BackupConfigController) RegisterRoutes(router *gin.RouterGroup) {
router.POST("/backup-configs/save", c.SaveBackupConfig)
router.GET("/backup-configs/database/:id/plan", c.GetDatabasePlan)
router.GET("/backup-configs/database/:id", c.GetBackupConfigByDbID)
router.GET("/backup-configs/storage/:id/is-using", c.IsStorageUsing)
router.GET("/backup-configs/storage/:id/databases-count", c.CountDatabasesForStorage)
@@ -92,6 +93,39 @@ func (c *BackupConfigController) GetBackupConfigByDbID(ctx *gin.Context) {
ctx.JSON(http.StatusOK, backupConfig)
}
// GetDatabasePlan
// @Summary Get database plan by database ID
// @Description Get the plan limits for a specific database (max backup size, max total size, max storage period)
// @Tags backup-configs
// @Produce json
// @Param id path string true "Database ID"
// @Success 200 {object} plans.DatabasePlan
// @Failure 400 {object} map[string]string "Invalid database ID"
// @Failure 401 {object} map[string]string "User not authenticated"
// @Failure 404 {object} map[string]string "Database not found or access denied"
// @Router /backup-configs/database/{id}/plan [get]
func (c *BackupConfigController) GetDatabasePlan(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
id, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
return
}
plan, err := c.backupConfigService.GetDatabasePlan(user, id)
if err != nil {
ctx.JSON(http.StatusNotFound, gin.H{"error": "database plan not found"})
return
}
ctx.JSON(http.StatusOK, plan)
}
// IsStorageUsing
// @Summary Check if storage is being used
// @Description Check if a storage is currently being used by any backup configuration

View File

@@ -6,6 +6,7 @@ import (
"net/http"
"strconv"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -16,11 +17,14 @@ import (
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/features/storages"
local_storage "databasus-backend/internal/features/storages/models/local"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/storage"
"databasus-backend/internal/util/period"
test_utils "databasus-backend/internal/util/testing"
"databasus-backend/internal/util/tools"
@@ -89,6 +93,11 @@ func Test_SaveBackupConfig_PermissionsEnforced(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
@@ -152,6 +161,11 @@ func Test_SaveBackupConfig_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *test
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
timeOfDay := "04:00"
@@ -242,6 +256,11 @@ func Test_GetBackupConfigByDbID_PermissionsEnforced(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
@@ -290,6 +309,11 @@ func Test_GetBackupConfigByDbID_ReturnsDefaultConfigForNewDatabase(t *testing.T)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
var response BackupConfig
test_utils.MakeGetRequestAndUnmarshal(
t,
@@ -300,14 +324,214 @@ func Test_GetBackupConfigByDbID_ReturnsDefaultConfigForNewDatabase(t *testing.T)
&response,
)
var plan plans.DatabasePlan
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
"Bearer "+owner.Token,
http.StatusOK,
&plan,
)
assert.Equal(t, database.ID, response.DatabaseID)
assert.False(t, response.IsBackupsEnabled)
assert.Equal(t, period.PeriodWeek, response.StorePeriod)
assert.Equal(t, plan.MaxStoragePeriod, response.StorePeriod)
assert.Equal(t, plan.MaxBackupSizeMB, response.MaxBackupSizeMB)
assert.Equal(t, plan.MaxBackupsTotalSizeMB, response.MaxBackupsTotalSizeMB)
assert.True(t, response.IsRetryIfFailed)
assert.Equal(t, 3, response.MaxFailedTriesCount)
assert.NotNil(t, response.BackupInterval)
}
func Test_GetDatabasePlan_ForNewDatabase_PlanAlwaysReturned(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
var response plans.DatabasePlan
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
"Bearer "+owner.Token,
http.StatusOK,
&response,
)
assert.Equal(t, database.ID, response.DatabaseID)
assert.NotNil(t, response.MaxBackupSizeMB)
assert.NotNil(t, response.MaxBackupsTotalSizeMB)
assert.NotEmpty(t, response.MaxStoragePeriod)
}
func Test_SaveBackupConfig_WhenPlanLimitsAreAdjusted_ValidationEnforced(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Get plan via API (triggers auto-creation)
var plan plans.DatabasePlan
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
"Bearer "+owner.Token,
http.StatusOK,
&plan,
)
assert.Equal(t, database.ID, plan.DatabaseID)
// Adjust plan limits directly in database to fixed restrictive values
err := storage.GetDb().Model(&plans.DatabasePlan{}).
Where("database_id = ?", database.ID).
Updates(map[string]any{
"max_backup_size_mb": 100,
"max_backups_total_size_mb": 1000,
"max_storage_period": period.PeriodMonth,
}).Error
assert.NoError(t, err)
// Test 1: Try to save backup config with exceeded backup size limit
timeOfDay := "04:00"
backupConfigExceededSize := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
MaxBackupSizeMB: 200, // Exceeds limit of 100
MaxBackupsTotalSizeMB: 800,
}
respExceededSize := test_utils.MakePostRequest(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner.Token,
backupConfigExceededSize,
http.StatusBadRequest,
)
assert.Contains(t, string(respExceededSize.Body), "max backup size exceeds plan limit")
// Test 2: Try to save backup config with exceeded total size limit
backupConfigExceededTotal := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
MaxBackupSizeMB: 50,
MaxBackupsTotalSizeMB: 2000, // Exceeds limit of 1000
}
respExceededTotal := test_utils.MakePostRequest(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner.Token,
backupConfigExceededTotal,
http.StatusBadRequest,
)
assert.Contains(t, string(respExceededTotal.Body), "max total backups size exceeds plan limit")
// Test 3: Try to save backup config with exceeded storage period limit
backupConfigExceededPeriod := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodYear, // Exceeds limit of Month
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
MaxBackupSizeMB: 80,
MaxBackupsTotalSizeMB: 800,
}
respExceededPeriod := test_utils.MakePostRequest(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner.Token,
backupConfigExceededPeriod,
http.StatusBadRequest,
)
assert.Contains(t, string(respExceededPeriod.Body), "storage period exceeds plan limit")
// Test 4: Save backup config within all limits - should succeed
backupConfigValid := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek, // Within Month limit
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
MaxBackupSizeMB: 80, // Within 100 limit
MaxBackupsTotalSizeMB: 800, // Within 1000 limit
}
var responseValid BackupConfig
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner.Token,
backupConfigValid,
http.StatusOK,
&responseValid,
)
assert.Equal(t, database.ID, responseValid.DatabaseID)
assert.Equal(t, int64(80), responseValid.MaxBackupSizeMB)
assert.Equal(t, int64(800), responseValid.MaxBackupsTotalSizeMB)
assert.Equal(t, period.PeriodWeek, responseValid.StorePeriod)
}
func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
@@ -340,6 +564,10 @@ func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
)
storage := createTestStorage(workspace.ID)
defer func() {
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
var testUserToken string
if tt.isStorageOwner {
testUserToken = storageOwner.Token
@@ -372,10 +600,6 @@ func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
)
assert.Contains(t, string(testResp.Body), "error")
}
// Cleanup
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -387,6 +611,11 @@ func Test_SaveBackupConfig_WithEncryptionNone_ConfigSaved(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
@@ -426,6 +655,11 @@ func Test_SaveBackupConfig_WithEncryptionEncrypted_ConfigSaved(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
@@ -536,6 +770,15 @@ func Test_TransferDatabase_PermissionsEnforced(t *testing.T) {
targetStorage := createTestStorage(targetWorkspace.ID)
defer func() {
// Cleanup in correct order to avoid foreign key violations
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond) // Wait for cascade delete of backup_config
storages.RemoveTestStorage(targetStorage.ID)
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
@@ -628,6 +871,12 @@ func Test_TransferDatabase_NonMemberInSourceWorkspace_CannotTransfer(t *testing.
router,
)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
request := TransferDatabaseRequest{
TargetWorkspaceID: targetWorkspace.ID,
}
@@ -668,6 +917,12 @@ func Test_TransferDatabase_NonMemberInTargetWorkspace_CannotTransfer(t *testing.
router,
)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
request := TransferDatabaseRequest{
TargetWorkspaceID: targetWorkspace.ID,
}
@@ -695,6 +950,13 @@ func Test_TransferDatabase_ToNewStorage_DatabaseTransferd(t *testing.T) {
sourceStorage := createTestStorage(sourceWorkspace.ID)
targetStorage := createTestStorage(targetWorkspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
time.Sleep(200 * time.Millisecond) // Wait for cascading deletes
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
timeOfDay := "04:00"
backupConfigRequest := BackupConfig{
DatabaseID: database.ID,
@@ -774,6 +1036,13 @@ func Test_TransferDatabase_WithExistingStorage_DatabaseAndStorageTransferd(t *te
database := createTestDatabaseViaAPI("Test Database", sourceWorkspace.ID, owner.Token, router)
storage := createTestStorage(sourceWorkspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
time.Sleep(200 * time.Millisecond) // Wait for cascading deletes
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
timeOfDay := "04:00"
backupConfigRequest := BackupConfig{
DatabaseID: database.ID,
@@ -863,6 +1132,14 @@ func Test_TransferDatabase_StorageHasOtherDBs_CannotTransfer(t *testing.T) {
)
storage := createTestStorage(sourceWorkspace.ID)
defer func() {
databases.RemoveTestDatabase(database1)
databases.RemoveTestDatabase(database2)
time.Sleep(200 * time.Millisecond) // Wait for cascading deletes
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
timeOfDay := "04:00"
backupConfigRequest1 := BackupConfig{
DatabaseID: database1.ID,
@@ -945,6 +1222,14 @@ func Test_TransferDatabase_WithNotifiers_NotifiersTransferred(t *testing.T) {
targetStorage := createTestStorage(targetWorkspace.ID)
notifier := notifiers.CreateTestNotifier(sourceWorkspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
time.Sleep(200 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
database.Notifiers = []notifiers.Notifier{*notifier}
var updatedDatabase databases.Database
test_utils.MakePostRequestAndUnmarshal(
@@ -1048,6 +1333,15 @@ func Test_TransferDatabase_NotifierHasOtherDBs_NotifierSkipped(t *testing.T) {
targetStorage := createTestStorage(targetWorkspace.ID)
sharedNotifier := notifiers.CreateTestNotifier(sourceWorkspace.ID)
defer func() {
databases.RemoveTestDatabase(database1)
databases.RemoveTestDatabase(database2)
time.Sleep(200 * time.Millisecond)
notifiers.RemoveTestNotifier(sharedNotifier)
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
database1.Notifiers = []notifiers.Notifier{*sharedNotifier}
test_utils.MakePostRequest(
t,
@@ -1160,6 +1454,16 @@ func Test_TransferDatabase_WithMultipleNotifiers_OnlyExclusiveOnesTransferred(t
exclusiveNotifier := notifiers.CreateTestNotifier(sourceWorkspace.ID)
sharedNotifier := notifiers.CreateTestNotifier(sourceWorkspace.ID)
defer func() {
databases.RemoveTestDatabase(database1)
databases.RemoveTestDatabase(database2)
time.Sleep(200 * time.Millisecond)
notifiers.RemoveTestNotifier(exclusiveNotifier)
notifiers.RemoveTestNotifier(sharedNotifier)
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
database1.Notifiers = []notifiers.Notifier{*exclusiveNotifier, *sharedNotifier}
test_utils.MakePostRequest(
t,
@@ -1271,6 +1575,14 @@ func Test_TransferDatabase_WithTargetNotifiers_NotifiersAssigned(t *testing.T) {
targetStorage := createTestStorage(targetWorkspace.ID)
targetNotifier := notifiers.CreateTestNotifier(targetWorkspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
time.Sleep(200 * time.Millisecond)
notifiers.RemoveTestNotifier(targetNotifier)
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
timeOfDay := "04:00"
backupConfigRequest := BackupConfig{
DatabaseID: database.ID,
@@ -1342,6 +1654,15 @@ func Test_TransferDatabase_TargetNotifierFromDifferentWorkspace_ReturnsBadReques
targetStorage := createTestStorage(targetWorkspace.ID)
wrongNotifier := notifiers.CreateTestNotifier(otherWorkspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
time.Sleep(200 * time.Millisecond)
notifiers.RemoveTestNotifier(wrongNotifier)
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
workspaces_testing.RemoveTestWorkspace(otherWorkspace, router)
}()
timeOfDay := "04:00"
backupConfigRequest := BackupConfig{
DatabaseID: database.ID,
@@ -1399,6 +1720,14 @@ func Test_TransferDatabase_TargetStorageFromDifferentWorkspace_ReturnsBadRequest
sourceStorage := createTestStorage(sourceWorkspace.ID)
wrongStorage := createTestStorage(otherWorkspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
time.Sleep(200 * time.Millisecond)
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
workspaces_testing.RemoveTestWorkspace(otherWorkspace, router)
}()
timeOfDay := "04:00"
backupConfigRequest := BackupConfig{
DatabaseID: database.ID,
@@ -1443,6 +1772,115 @@ func Test_TransferDatabase_TargetStorageFromDifferentWorkspace_ReturnsBadRequest
assert.Contains(t, string(testResp.Body), "target storage does not belong to target workspace")
}
func Test_SaveBackupConfig_WithSystemStorage_CanBeUsedByAnyDatabase(t *testing.T) {
router := createTestRouterWithStorageForTransfer()
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
workspaceA := workspaces_testing.CreateTestWorkspace("Workspace A", owner1, router)
workspaceB := workspaces_testing.CreateTestWorkspace("Workspace B", owner2, router)
databaseA := createTestDatabaseViaAPI("Database A", workspaceA.ID, owner1.Token, router)
// Test 1: Regular storage from workspace B cannot be used by database in workspace A
regularStorageB := createTestStorage(workspaceB.ID)
timeOfDay := "04:00"
backupConfigWithRegularStorage := BackupConfig{
DatabaseID: databaseA.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
StorageID: &regularStorageB.ID,
Storage: regularStorageB,
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
}
respRegular := test_utils.MakePostRequest(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner1.Token,
backupConfigWithRegularStorage,
http.StatusBadRequest,
)
assert.Contains(t, string(respRegular.Body), "storage does not belong to the same workspace")
// Test 2: System storage from workspace B CAN be used by database in workspace A
systemStorageB := &storages.Storage{
WorkspaceID: workspaceB.ID,
Type: storages.StorageTypeLocal,
Name: "Test System Storage " + uuid.New().String(),
IsSystem: true,
LocalStorage: &local_storage.LocalStorage{},
}
var savedSystemStorage storages.Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+admin.Token,
*systemStorageB,
http.StatusOK,
&savedSystemStorage,
)
assert.True(t, savedSystemStorage.IsSystem)
backupConfigWithSystemStorage := BackupConfig{
DatabaseID: databaseA.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
StorageID: &savedSystemStorage.ID,
Storage: &savedSystemStorage,
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
}
var savedConfig BackupConfig
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner1.Token,
backupConfigWithSystemStorage,
http.StatusOK,
&savedConfig,
)
assert.Equal(t, databaseA.ID, savedConfig.DatabaseID)
assert.NotNil(t, savedConfig.StorageID)
assert.Equal(t, savedSystemStorage.ID, *savedConfig.StorageID)
assert.True(t, savedConfig.IsBackupsEnabled)
// Cleanup: database first (cascades to backup_config), then storages, then workspaces
databases.RemoveTestDatabase(databaseA)
storages.RemoveTestStorage(regularStorageB.ID)
storages.RemoveTestStorage(savedSystemStorage.ID)
workspaces_testing.RemoveTestWorkspace(workspaceA, router)
workspaces_testing.RemoveTestWorkspace(workspaceB, router)
}
func createTestDatabaseViaAPI(
name string,
workspaceID uuid.UUID,
@@ -1462,7 +1900,7 @@ func createTestDatabaseViaAPI(
Type: databases.DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",

View File

@@ -1,10 +1,15 @@
package backups_config
import (
"sync"
"sync/atomic"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/features/storages"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/logger"
)
var backupConfigRepository = &BackupConfigRepository{}
@@ -14,6 +19,7 @@ var backupConfigService = &BackupConfigService{
storages.GetStorageService(),
notifiers.GetNotifierService(),
workspaces_services.GetWorkspaceService(),
plans.GetDatabasePlanService(),
nil,
}
var backupConfigController = &BackupConfigController{
@@ -28,6 +34,21 @@ func GetBackupConfigService() *BackupConfigService {
return backupConfigService
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService)
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}

View File

@@ -1,7 +1,9 @@
package backups_config
import (
"databasus-backend/internal/config"
"databasus-backend/internal/features/intervals"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/features/storages"
"databasus-backend/internal/util/period"
"errors"
@@ -31,6 +33,11 @@ type BackupConfig struct {
MaxFailedTriesCount int `json:"maxFailedTriesCount" gorm:"column:max_failed_tries_count;type:int;not null"`
Encryption BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
// MaxBackupSizeMB limits individual backup size. 0 = unlimited.
MaxBackupSizeMB int64 `json:"maxBackupSizeMb" gorm:"column:max_backup_size_mb;type:int;not null"`
// MaxBackupsTotalSizeMB limits total size of all backups. 0 = unlimited.
MaxBackupsTotalSizeMB int64 `json:"maxBackupsTotalSizeMb" gorm:"column:max_backups_total_size_mb;type:int;not null"`
}
func (h *BackupConfig) TableName() string {
@@ -70,7 +77,7 @@ func (b *BackupConfig) AfterFind(tx *gorm.DB) error {
return nil
}
func (b *BackupConfig) Validate() error {
func (b *BackupConfig) Validate(plan *plans.DatabasePlan) error {
// Backup interval is required either as ID or as object
if b.BackupIntervalID == uuid.Nil && b.BackupInterval == nil {
return errors.New("backup interval is required")
@@ -89,20 +96,59 @@ func (b *BackupConfig) Validate() error {
return errors.New("encryption must be NONE or ENCRYPTED")
}
if config.GetEnv().IsCloud {
if b.Encryption != BackupEncryptionEncrypted {
return errors.New("encryption is mandatory for cloud storage")
}
}
if b.MaxBackupSizeMB < 0 {
return errors.New("max backup size must be non-negative")
}
if b.MaxBackupsTotalSizeMB < 0 {
return errors.New("max backups total size must be non-negative")
}
// Validate against plan limits
// Check storage period limit
if plan.MaxStoragePeriod != period.PeriodForever {
if b.StorePeriod.CompareTo(plan.MaxStoragePeriod) > 0 {
return errors.New("storage period exceeds plan limit")
}
}
// Check max backup size limit (0 in plan means unlimited)
if plan.MaxBackupSizeMB > 0 {
if b.MaxBackupSizeMB == 0 || b.MaxBackupSizeMB > plan.MaxBackupSizeMB {
return errors.New("max backup size exceeds plan limit")
}
}
// Check max total backups size limit (0 in plan means unlimited)
if plan.MaxBackupsTotalSizeMB > 0 {
if b.MaxBackupsTotalSizeMB == 0 ||
b.MaxBackupsTotalSizeMB > plan.MaxBackupsTotalSizeMB {
return errors.New("max total backups size exceeds plan limit")
}
}
return nil
}
func (b *BackupConfig) Copy(newDatabaseID uuid.UUID) *BackupConfig {
return &BackupConfig{
DatabaseID: newDatabaseID,
IsBackupsEnabled: b.IsBackupsEnabled,
StorePeriod: b.StorePeriod,
BackupIntervalID: uuid.Nil,
BackupInterval: b.BackupInterval.Copy(),
StorageID: b.StorageID,
SendNotificationsOn: b.SendNotificationsOn,
IsRetryIfFailed: b.IsRetryIfFailed,
MaxFailedTriesCount: b.MaxFailedTriesCount,
Encryption: b.Encryption,
DatabaseID: newDatabaseID,
IsBackupsEnabled: b.IsBackupsEnabled,
StorePeriod: b.StorePeriod,
BackupIntervalID: uuid.Nil,
BackupInterval: b.BackupInterval.Copy(),
StorageID: b.StorageID,
SendNotificationsOn: b.SendNotificationsOn,
IsRetryIfFailed: b.IsRetryIfFailed,
MaxFailedTriesCount: b.MaxFailedTriesCount,
Encryption: b.Encryption,
MaxBackupSizeMB: b.MaxBackupSizeMB,
MaxBackupsTotalSizeMB: b.MaxBackupsTotalSizeMB,
}
}

View File

@@ -0,0 +1,391 @@
package backups_config
import (
"testing"
"databasus-backend/internal/features/intervals"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/util/period"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_Validate_WhenStoragePeriodIsWeekAndPlanAllowsMonth_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.StorePeriod = period.PeriodWeek
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodMonth
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenStoragePeriodIsYearAndPlanAllowsMonth_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.StorePeriod = period.PeriodYear
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodMonth
err := config.Validate(plan)
assert.EqualError(t, err, "storage period exceeds plan limit")
}
func Test_Validate_WhenStoragePeriodIsForeverAndPlanAllowsForever_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.StorePeriod = period.PeriodForever
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodForever
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenStoragePeriodIsForeverAndPlanAllowsYear_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.StorePeriod = period.PeriodForever
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodYear
err := config.Validate(plan)
assert.EqualError(t, err, "storage period exceeds plan limit")
}
func Test_Validate_WhenStoragePeriodEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.StorePeriod = period.PeriodMonth
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodMonth
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenBackupSize100MBAndPlanAllows500MB_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 100
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 500
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenBackupSize500MBAndPlanAllows100MB_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 500
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 100
err := config.Validate(plan)
assert.EqualError(t, err, "max backup size exceeds plan limit")
}
func Test_Validate_WhenBackupSizeIsUnlimitedAndPlanAllowsUnlimited_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 0
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 0
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenBackupSizeIsUnlimitedAndPlanHas500MBLimit_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 0
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 500
err := config.Validate(plan)
assert.EqualError(t, err, "max backup size exceeds plan limit")
}
func Test_Validate_WhenBackupSizeEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 500
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 500
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenTotalSize1GBAndPlanAllows5GB_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 1000
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 5000
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenTotalSize5GBAndPlanAllows1GB_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 5000
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 1000
err := config.Validate(plan)
assert.EqualError(t, err, "max total backups size exceeds plan limit")
}
func Test_Validate_WhenTotalSizeIsUnlimitedAndPlanAllowsUnlimited_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 0
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 0
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenTotalSizeIsUnlimitedAndPlanHas1GBLimit_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 0
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 1000
err := config.Validate(plan)
assert.EqualError(t, err, "max total backups size exceeds plan limit")
}
func Test_Validate_WhenTotalSizeEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 5000
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 5000
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenAllLimitsAreUnlimitedInPlan_AnyConfigurationPasses(t *testing.T) {
config := createValidBackupConfig()
config.StorePeriod = period.PeriodForever
config.MaxBackupSizeMB = 0
config.MaxBackupsTotalSizeMB = 0
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenMultipleLimitsExceeded_ValidationFailsWithFirstError(t *testing.T) {
config := createValidBackupConfig()
config.StorePeriod = period.PeriodYear
config.MaxBackupSizeMB = 500
config.MaxBackupsTotalSizeMB = 5000
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodMonth
plan.MaxBackupSizeMB = 100
plan.MaxBackupsTotalSizeMB = 1000
err := config.Validate(plan)
assert.Error(t, err)
assert.EqualError(t, err, "storage period exceeds plan limit")
}
func Test_Validate_WhenConfigHasInvalidIntervalButPlanIsValid_ValidationFailsOnInterval(
t *testing.T,
) {
config := createValidBackupConfig()
config.BackupIntervalID = uuid.Nil
config.BackupInterval = nil
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "backup interval is required")
}
func Test_Validate_WhenIntervalIsMissing_ValidationFailsRegardlessOfPlan(t *testing.T) {
config := createValidBackupConfig()
config.BackupIntervalID = uuid.Nil
config.BackupInterval = nil
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "backup interval is required")
}
func Test_Validate_WhenRetryEnabledButMaxTriesIsZero_ValidationFailsRegardlessOfPlan(t *testing.T) {
config := createValidBackupConfig()
config.IsRetryIfFailed = true
config.MaxFailedTriesCount = 0
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "max failed tries count must be greater than 0")
}
func Test_Validate_WhenEncryptionIsInvalid_ValidationFailsRegardlessOfPlan(t *testing.T) {
config := createValidBackupConfig()
config.Encryption = "INVALID"
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "encryption must be NONE or ENCRYPTED")
}
func Test_Validate_WhenStoragePeriodIsEmpty_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.StorePeriod = ""
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "store period is required")
}
func Test_Validate_WhenMaxBackupSizeIsNegative_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = -100
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "max backup size must be non-negative")
}
func Test_Validate_WhenMaxTotalSizeIsNegative_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = -1000
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "max backups total size must be non-negative")
}
func Test_Validate_WhenPlanLimitsAreAtBoundary_ValidationWorks(t *testing.T) {
tests := []struct {
name string
configPeriod period.Period
planPeriod period.Period
configSize int64
planSize int64
configTotal int64
planTotal int64
shouldSucceed bool
}{
{
name: "all values just under limit",
configPeriod: period.PeriodWeek,
planPeriod: period.PeriodMonth,
configSize: 99,
planSize: 100,
configTotal: 999,
planTotal: 1000,
shouldSucceed: true,
},
{
name: "all values equal to limit",
configPeriod: period.PeriodMonth,
planPeriod: period.PeriodMonth,
configSize: 100,
planSize: 100,
configTotal: 1000,
planTotal: 1000,
shouldSucceed: true,
},
{
name: "period just over limit",
configPeriod: period.Period3Month,
planPeriod: period.PeriodMonth,
configSize: 100,
planSize: 100,
configTotal: 1000,
planTotal: 1000,
shouldSucceed: false,
},
{
name: "size just over limit",
configPeriod: period.PeriodMonth,
planPeriod: period.PeriodMonth,
configSize: 101,
planSize: 100,
configTotal: 1000,
planTotal: 1000,
shouldSucceed: false,
},
{
name: "total size just over limit",
configPeriod: period.PeriodMonth,
planPeriod: period.PeriodMonth,
configSize: 100,
planSize: 100,
configTotal: 1001,
planTotal: 1000,
shouldSucceed: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := createValidBackupConfig()
config.StorePeriod = tt.configPeriod
config.MaxBackupSizeMB = tt.configSize
config.MaxBackupsTotalSizeMB = tt.configTotal
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = tt.planPeriod
plan.MaxBackupSizeMB = tt.planSize
plan.MaxBackupsTotalSizeMB = tt.planTotal
err := config.Validate(plan)
if tt.shouldSucceed {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
})
}
}
func createValidBackupConfig() *BackupConfig {
intervalID := uuid.New()
return &BackupConfig{
DatabaseID: uuid.New(),
IsBackupsEnabled: true,
StorePeriod: period.PeriodMonth,
BackupIntervalID: intervalID,
BackupInterval: &intervals.Interval{ID: intervalID},
SendNotificationsOn: []BackupNotificationType{},
IsRetryIfFailed: false,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
MaxBackupSizeMB: 100,
MaxBackupsTotalSizeMB: 1000,
}
}
func createUnlimitedPlan() *plans.DatabasePlan {
return &plans.DatabasePlan{
DatabaseID: uuid.New(),
MaxBackupSizeMB: 0,
MaxBackupsTotalSizeMB: 0,
MaxStoragePeriod: period.PeriodForever,
}
}

View File

@@ -26,6 +26,12 @@ func Test_AttachNotifierFromSameWorkspace_SuccessfullyAttached(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
notifier := notifiers.CreateTestNotifier(workspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
database.Notifiers = []notifiers.Notifier{*notifier}
var response databases.Database
@@ -55,6 +61,13 @@ func Test_AttachNotifierFromDifferentWorkspace_ReturnsForbidden(t *testing.T) {
workspace2 := workspaces_testing.CreateTestWorkspace("Workspace 2", owner2, router)
notifier := notifiers.CreateTestNotifier(workspace2.ID)
defer func() {
databases.RemoveTestDatabase(database)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace1, router)
workspaces_testing.RemoveTestWorkspace(workspace2, router)
}()
database.Notifiers = []notifiers.Notifier{*notifier}
testResp := test_utils.MakePostRequest(
@@ -77,6 +90,12 @@ func Test_DeleteNotifierWithAttachedDatabases_CannotDelete(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
notifier := notifiers.CreateTestNotifier(workspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
database.Notifiers = []notifiers.Notifier{*notifier}
var response databases.Database
@@ -114,6 +133,13 @@ func Test_TransferNotifierWithAttachedDatabase_CannotTransfer(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
notifier := notifiers.CreateTestNotifier(workspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
database.Notifiers = []notifiers.Notifier{*notifier}
var response databases.Database

View File

@@ -6,10 +6,10 @@ import (
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/features/storages"
users_models "databasus-backend/internal/features/users/models"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/period"
"github.com/google/uuid"
)
@@ -20,6 +20,7 @@ type BackupConfigService struct {
storageService *storages.StorageService
notifierService *notifiers.NotifierService
workspaceService *workspaces_services.WorkspaceService
databasePlanService *plans.DatabasePlanService
dbStorageChangeListener BackupConfigStorageChangeListener
}
@@ -45,7 +46,12 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
user *users_models.User,
backupConfig *BackupConfig,
) (*BackupConfig, error) {
if err := backupConfig.Validate(); err != nil {
plan, err := s.databasePlanService.GetDatabasePlan(backupConfig.DatabaseID)
if err != nil {
return nil, err
}
if err := backupConfig.Validate(plan); err != nil {
return nil, err
}
@@ -71,7 +77,7 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
if err != nil {
return nil, err
}
if storage.WorkspaceID != *database.WorkspaceID {
if storage.WorkspaceID != *database.WorkspaceID && !storage.IsSystem {
return nil, errors.New("storage does not belong to the same workspace as the database")
}
}
@@ -82,7 +88,12 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
func (s *BackupConfigService) SaveBackupConfig(
backupConfig *BackupConfig,
) (*BackupConfig, error) {
if err := backupConfig.Validate(); err != nil {
plan, err := s.databasePlanService.GetDatabasePlan(backupConfig.DatabaseID)
if err != nil {
return nil, err
}
if err := backupConfig.Validate(plan); err != nil {
return nil, err
}
@@ -120,6 +131,18 @@ func (s *BackupConfigService) GetBackupConfigByDbIdWithAuth(
return s.GetBackupConfigByDbId(databaseID)
}
func (s *BackupConfigService) GetDatabasePlan(
user *users_models.User,
databaseID uuid.UUID,
) (*plans.DatabasePlan, error) {
_, err := s.databaseService.GetDatabase(user, databaseID)
if err != nil {
return nil, err
}
return s.databasePlanService.GetDatabasePlan(databaseID)
}
func (s *BackupConfigService) GetBackupConfigByDbId(
databaseID uuid.UUID,
) (*BackupConfig, error) {
@@ -194,12 +217,19 @@ func (s *BackupConfigService) CreateDisabledBackupConfig(databaseID uuid.UUID) e
func (s *BackupConfigService) initializeDefaultConfig(
databaseID uuid.UUID,
) error {
plan, err := s.databasePlanService.GetDatabasePlan(databaseID)
if err != nil {
return err
}
timeOfDay := "04:00"
_, err := s.backupConfigRepository.Save(&BackupConfig{
DatabaseID: databaseID,
IsBackupsEnabled: false,
StorePeriod: period.PeriodWeek,
_, err = s.backupConfigRepository.Save(&BackupConfig{
DatabaseID: databaseID,
IsBackupsEnabled: false,
StorePeriod: plan.MaxStoragePeriod,
MaxBackupSizeMB: plan.MaxBackupSizeMB,
MaxBackupsTotalSizeMB: plan.MaxBackupsTotalSizeMB,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,

View File

@@ -27,6 +27,12 @@ func Test_AttachStorageFromSameWorkspace_SuccessfullyAttached(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
@@ -72,6 +78,13 @@ func Test_AttachStorageFromDifferentWorkspace_ReturnsForbidden(t *testing.T) {
workspace2 := workspaces_testing.CreateTestWorkspace("Workspace 2", owner2, router)
storage := createTestStorage(workspace2.ID)
defer func() {
databases.RemoveTestDatabase(database)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace1, router)
workspaces_testing.RemoveTestWorkspace(workspace2, router)
}()
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
@@ -110,6 +123,12 @@ func Test_DeleteStorageWithAttachedDatabases_CannotDelete(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
@@ -163,6 +182,13 @@ func Test_TransferStorageWithAttachedDatabase_CannotTransfer(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,

View File

@@ -25,80 +25,6 @@ import (
"databasus-backend/internal/util/tools"
)
func createTestRouter() *gin.Engine {
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
GetDatabaseController(),
)
return router
}
func getTestPostgresConfig() *postgresql.PostgresqlDatabase {
env := config.GetEnv()
port, err := strconv.Atoi(env.TestPostgres16Port)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
}
testDbName := "testdb"
return &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
}
}
func getTestMariadbConfig() *mariadb.MariadbDatabase {
env := config.GetEnv()
portStr := env.TestMariadb1011Port
if portStr == "" {
portStr = "33111"
}
port, err := strconv.Atoi(portStr)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_MARIADB_1011_PORT: %v", err))
}
testDbName := "testdb"
return &mariadb.MariadbDatabase{
Version: tools.MariadbVersion1011,
Host: "localhost",
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
}
}
func getTestMongodbConfig() *mongodb.MongodbDatabase {
env := config.GetEnv()
portStr := env.TestMongodb70Port
if portStr == "" {
portStr = "27070"
}
port, err := strconv.Atoi(portStr)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_MONGODB_70_PORT: %v", err))
}
return &mongodb.MongodbDatabase{
Version: tools.MongodbVersion7,
Host: "localhost",
Port: port,
Username: "root",
Password: "rootpassword",
Database: "testdb",
AuthDatabase: "admin",
IsHttps: false,
CpuCount: 1,
}
}
func Test_CreateDatabase_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
@@ -142,6 +68,7 @@ func Test_CreateDatabase_PermissionsEnforced(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
var testUserToken string
if tt.isGlobalAdmin {
@@ -180,6 +107,7 @@ func Test_CreateDatabase_PermissionsEnforced(t *testing.T) {
)
if tt.expectSuccess {
defer RemoveTestDatabase(&response)
assert.Equal(t, "Test Database", response.Name)
assert.NotEqual(t, uuid.Nil, response.ID)
} else {
@@ -193,6 +121,7 @@ func Test_CreateDatabase_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testin
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -258,8 +187,10 @@ func Test_UpdateDatabase_PermissionsEnforced(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(database)
var testUserToken string
if tt.isGlobalAdmin {
@@ -305,8 +236,10 @@ func Test_UpdateDatabase_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testin
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(database)
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
database.Name = "Hacked Name"
@@ -366,6 +299,7 @@ func Test_DeleteDatabase_PermissionsEnforced(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
@@ -396,6 +330,7 @@ func Test_DeleteDatabase_PermissionsEnforced(t *testing.T) {
)
if !tt.expectSuccess {
defer RemoveTestDatabase(database)
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
})
@@ -439,8 +374,10 @@ func Test_GetDatabase_PermissionsEnforced(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(database)
var testUser string
if tt.isGlobalAdmin {
@@ -517,9 +454,12 @@ func Test_GetDatabasesByWorkspace_PermissionsEnforced(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
createTestDatabaseViaAPI("Database 1", workspace.ID, owner.Token, router)
createTestDatabaseViaAPI("Database 2", workspace.ID, owner.Token, router)
db1 := createTestDatabaseViaAPI("Database 1", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(db1)
db2 := createTestDatabaseViaAPI("Database 2", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(db2)
var testUser string
if tt.isGlobalAdmin {
@@ -561,10 +501,14 @@ func Test_GetDatabasesByWorkspace_WhenMultipleDatabasesExist_ReturnsCorrectCount
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
createTestDatabaseViaAPI("Database 1", workspace.ID, owner.Token, router)
createTestDatabaseViaAPI("Database 2", workspace.ID, owner.Token, router)
createTestDatabaseViaAPI("Database 3", workspace.ID, owner.Token, router)
db1 := createTestDatabaseViaAPI("Database 1", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(db1)
db2 := createTestDatabaseViaAPI("Database 2", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(db2)
db3 := createTestDatabaseViaAPI("Database 3", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(db3)
var response []Database
test_utils.MakeGetRequestAndUnmarshal(
@@ -583,14 +527,19 @@ func Test_GetDatabasesByWorkspace_EnsuresCrossWorkspaceIsolation(t *testing.T) {
router := createTestRouter()
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace1 := workspaces_testing.CreateTestWorkspace("Workspace 1", owner1, router)
defer workspaces_testing.RemoveTestWorkspace(workspace1, router)
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace2 := workspaces_testing.CreateTestWorkspace("Workspace 2", owner2, router)
defer workspaces_testing.RemoveTestWorkspace(workspace2, router)
createTestDatabaseViaAPI("Workspace1 DB1", workspace1.ID, owner1.Token, router)
createTestDatabaseViaAPI("Workspace1 DB2", workspace1.ID, owner1.Token, router)
workspace1Db1 := createTestDatabaseViaAPI("Workspace1 DB1", workspace1.ID, owner1.Token, router)
defer RemoveTestDatabase(workspace1Db1)
workspace1Db2 := createTestDatabaseViaAPI("Workspace1 DB2", workspace1.ID, owner1.Token, router)
defer RemoveTestDatabase(workspace1Db2)
createTestDatabaseViaAPI("Workspace2 DB1", workspace2.ID, owner2.Token, router)
workspace2Db1 := createTestDatabaseViaAPI("Workspace2 DB1", workspace2.ID, owner2.Token, router)
defer RemoveTestDatabase(workspace2Db1)
var workspace1Dbs []Database
test_utils.MakeGetRequestAndUnmarshal(
@@ -667,8 +616,10 @@ func Test_CopyDatabase_PermissionsEnforced(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(database)
var testUserToken string
if tt.isGlobalAdmin {
@@ -700,6 +651,7 @@ func Test_CopyDatabase_PermissionsEnforced(t *testing.T) {
)
if tt.expectSuccess {
defer RemoveTestDatabase(&response)
assert.NotEqual(t, database.ID, response.ID)
assert.Contains(t, response.Name, "(Copy)")
} else {
@@ -713,8 +665,10 @@ func Test_CopyDatabase_CopyStaysInSameWorkspace(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(database)
var response Database
test_utils.MakePostRequestAndUnmarshal(
@@ -727,139 +681,14 @@ func Test_CopyDatabase_CopyStaysInSameWorkspace(t *testing.T) {
&response,
)
defer RemoveTestDatabase(&response)
assert.NotEqual(t, database.ID, response.ID)
assert.Equal(t, "Test Database (Copy)", response.Name)
assert.Equal(t, workspace.ID, *response.WorkspaceID)
assert.Equal(t, database.Type, response.Type)
}
func Test_TestConnection_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
isMember bool
isGlobalAdmin bool
expectAccessGranted bool
expectedStatusCodeOnErr int
}{
{
name: "workspace member can test connection",
isMember: true,
isGlobalAdmin: false,
expectAccessGranted: true,
expectedStatusCodeOnErr: http.StatusBadRequest,
},
{
name: "non-member cannot test connection",
isMember: false,
isGlobalAdmin: false,
expectAccessGranted: false,
expectedStatusCodeOnErr: http.StatusBadRequest,
},
{
name: "global admin can test connection",
isMember: false,
isGlobalAdmin: true,
expectAccessGranted: true,
expectedStatusCodeOnErr: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
var testUser string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUser = admin.Token
} else if tt.isMember {
testUser = owner.Token
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testUser = nonMember.Token
}
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/"+database.ID.String()+"/test-connection",
"Bearer "+testUser,
nil,
)
body := w.Body.String()
if tt.expectAccessGranted {
assert.True(
t,
w.Code == http.StatusOK ||
(w.Code == http.StatusBadRequest && strings.Contains(body, "connect")),
"Expected 200 OK or 400 with connection error, got %d: %s",
w.Code,
body,
)
} else {
assert.Equal(t, tt.expectedStatusCodeOnErr, w.Code)
assert.Contains(t, body, "insufficient permissions")
}
})
}
}
func createTestDatabaseViaAPI(
name string,
workspaceID uuid.UUID,
token string,
router *gin.Engine,
) *Database {
env := config.GetEnv()
port, err := strconv.Atoi(env.TestPostgres16Port)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
}
testDbName := "testdb"
request := Database{
Name: name,
WorkspaceID: &workspaceID,
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
},
}
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/create",
"Bearer "+token,
request,
)
if w.Code != http.StatusCreated {
panic(
fmt.Sprintf("Failed to create database. Status: %d, Body: %s", w.Code, w.Body.String()),
)
}
var database Database
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
panic(err)
}
return &database
}
func Test_CreateDatabase_PasswordIsEncryptedInDB(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -1141,3 +970,206 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
})
}
}
func Test_TestConnection_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
isMember bool
isGlobalAdmin bool
expectAccessGranted bool
expectedStatusCodeOnErr int
}{
{
name: "workspace member can test connection",
isMember: true,
isGlobalAdmin: false,
expectAccessGranted: true,
expectedStatusCodeOnErr: http.StatusBadRequest,
},
{
name: "non-member cannot test connection",
isMember: false,
isGlobalAdmin: false,
expectAccessGranted: false,
expectedStatusCodeOnErr: http.StatusBadRequest,
},
{
name: "global admin can test connection",
isMember: false,
isGlobalAdmin: true,
expectAccessGranted: true,
expectedStatusCodeOnErr: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(database)
var testUser string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUser = admin.Token
} else if tt.isMember {
testUser = owner.Token
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testUser = nonMember.Token
}
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/"+database.ID.String()+"/test-connection",
"Bearer "+testUser,
nil,
)
body := w.Body.String()
if tt.expectAccessGranted {
assert.True(
t,
w.Code == http.StatusOK ||
(w.Code == http.StatusBadRequest && strings.Contains(body, "connect")),
"Expected 200 OK or 400 with connection error, got %d: %s",
w.Code,
body,
)
} else {
assert.Equal(t, tt.expectedStatusCodeOnErr, w.Code)
assert.Contains(t, body, "insufficient permissions")
}
})
}
}
func createTestDatabaseViaAPI(
name string,
workspaceID uuid.UUID,
token string,
router *gin.Engine,
) *Database {
env := config.GetEnv()
port, err := strconv.Atoi(env.TestPostgres16Port)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
}
testDbName := "testdb"
request := Database{
Name: name,
WorkspaceID: &workspaceID,
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
},
}
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/create",
"Bearer "+token,
request,
)
if w.Code != http.StatusCreated {
panic(
fmt.Sprintf("Failed to create database. Status: %d, Body: %s", w.Code, w.Body.String()),
)
}
var database Database
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
panic(err)
}
return &database
}
func createTestRouter() *gin.Engine {
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
GetDatabaseController(),
)
return router
}
func getTestPostgresConfig() *postgresql.PostgresqlDatabase {
env := config.GetEnv()
port, err := strconv.Atoi(env.TestPostgres16Port)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
}
testDbName := "testdb"
return &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
}
}
func getTestMariadbConfig() *mariadb.MariadbDatabase {
env := config.GetEnv()
portStr := env.TestMariadb1011Port
if portStr == "" {
portStr = "33111"
}
port, err := strconv.Atoi(portStr)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_MARIADB_1011_PORT: %v", err))
}
testDbName := "testdb"
return &mariadb.MariadbDatabase{
Version: tools.MariadbVersion1011,
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
}
}
func getTestMongodbConfig() *mongodb.MongodbDatabase {
env := config.GetEnv()
portStr := env.TestMongodb70Port
if portStr == "" {
portStr = "27070"
}
port, err := strconv.Atoi(portStr)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_MONGODB_70_PORT: %v", err))
}
return &mongodb.MongodbDatabase{
Version: tools.MongodbVersion7,
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "root",
Password: "rootpassword",
Database: "testdb",
AuthDatabase: "admin",
IsHttps: false,
CpuCount: 1,
}
}

View File

@@ -515,9 +515,13 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
hasProcess := false
hasAllPrivileges := false
// Escape underscores to match MariaDB's grant output format
// MariaDB escapes _ as \_ in SHOW GRANTS output
// Pattern matches either literal _ or escaped \_
escapedDbName := strings.ReplaceAll(regexp.QuoteMeta(database), "_", `(_|\\_)`)
dbPatternStr := fmt.Sprintf(
`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
regexp.QuoteMeta(database),
escapedDbName,
)
dbPattern := regexp.MustCompile(dbPatternStr)
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)

View File

@@ -694,6 +694,115 @@ func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
assert.NoError(t, err)
}
func Test_TestConnection_DatabaseWithUnderscoresAndAllPrivileges_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MariadbVersion
port string
}{
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMariadbContainer(t, tc.port, tc.version)
defer container.DB.Close()
underscoreDbName := "test_all_db"
_, err := container.DB.Exec(
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
)
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
)
}()
underscoreDSN := fmt.Sprintf(
"%s:%s@tcp(%s:%d)/%s?parseTime=true",
container.Username,
container.Password,
container.Host,
container.Port,
underscoreDbName,
)
underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
assert.NoError(t, err)
defer underscoreDB.Close()
_, err = underscoreDB.Exec(`
CREATE TABLE all_priv_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)
`)
assert.NoError(t, err)
_, err = underscoreDB.Exec(`INSERT INTO all_priv_test (data) VALUES ('test1')`)
assert.NoError(t, err)
allPrivUsername := fmt.Sprintf("allpriv%s", uuid.New().String()[:8])
allPrivPassword := "allprivpass123"
_, err = underscoreDB.Exec(fmt.Sprintf(
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
allPrivUsername,
allPrivPassword,
))
assert.NoError(t, err)
_, err = underscoreDB.Exec(fmt.Sprintf(
"GRANT ALL PRIVILEGES ON `%s`.* TO '%s'@'%%'",
underscoreDbName,
allPrivUsername,
))
assert.NoError(t, err)
_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
assert.NoError(t, err)
defer dropUserSafe(underscoreDB, allPrivUsername)
mariadbModel := &MariadbDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Username: allPrivUsername,
Password: allPrivPassword,
Database: &underscoreDbName,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mariadbModel.TestConnection(logger, nil, uuid.New())
assert.NoError(t, err)
assert.NotEmpty(t, mariadbModel.Privileges)
assert.Contains(t, mariadbModel.Privileges, "SELECT")
assert.Contains(t, mariadbModel.Privileges, "SHOW VIEW")
})
}
}
type MariadbContainer struct {
Host string
Port int
@@ -714,7 +823,7 @@ func connectToMariadbContainer(
}
dbName := "testdb"
host := "127.0.0.1"
host := config.GetEnv().TestLocalhost
username := "root"
password := "rootpassword"

View File

@@ -9,6 +9,7 @@ import (
"strconv"
"strings"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
@@ -397,7 +398,7 @@ func connectToMongodbContainer(
}
dbName := "testdb"
host := "127.0.0.1"
host := config.GetEnv().TestLocalhost
username := "root"
password := "rootpassword"
authDatabase := "admin"
@@ -406,11 +407,18 @@ func connectToMongodbContainer(
assert.NoError(t, err)
uri := fmt.Sprintf(
"mongodb://%s:%s@%s:%d/%s?authSource=%s",
username, password, host, portInt, dbName, authDatabase,
"mongodb://%s:%s@%s:%d/%s?authSource=%s&serverSelectionTimeoutMS=5000&connectTimeoutMS=5000",
username,
password,
host,
portInt,
dbName,
authDatabase,
)
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
clientOptions := options.Client().ApplyURI(uri)
client, err := mongo.Connect(ctx, clientOptions)
if err != nil {

View File

@@ -489,9 +489,13 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
hasProcess := false
hasAllPrivileges := false
// Escape underscores to match MySQL's grant output format
// MySQL escapes _ as \_ in SHOW GRANTS output
// Pattern matches either literal _ or escaped \_
escapedDbName := strings.ReplaceAll(regexp.QuoteMeta(database), "_", `(_|\\_)`)
dbPatternStr := fmt.Sprintf(
`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
regexp.QuoteMeta(database),
escapedDbName,
)
dbPattern := regexp.MustCompile(dbPatternStr)
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)

View File

@@ -674,6 +674,112 @@ func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
assert.NoError(t, err)
}
func Test_TestConnection_DatabaseWithUnderscoresAndAllPrivileges_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MysqlVersion
port string
}{
{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
{"MySQL 9", tools.MysqlVersion9, env.TestMysql90Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMysqlContainer(t, tc.port, tc.version)
defer container.DB.Close()
underscoreDbName := "test_all_db"
_, err := container.DB.Exec(
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
)
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
)
}()
underscoreDSN := fmt.Sprintf(
"%s:%s@tcp(%s:%d)/%s?parseTime=true",
container.Username,
container.Password,
container.Host,
container.Port,
underscoreDbName,
)
underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
assert.NoError(t, err)
defer underscoreDB.Close()
_, err = underscoreDB.Exec(`
CREATE TABLE all_priv_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)
`)
assert.NoError(t, err)
_, err = underscoreDB.Exec(`INSERT INTO all_priv_test (data) VALUES ('test1')`)
assert.NoError(t, err)
allPrivUsername := fmt.Sprintf("allpriv_%s", uuid.New().String()[:8])
allPrivPassword := "allprivpass123"
_, err = underscoreDB.Exec(fmt.Sprintf(
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
allPrivUsername,
allPrivPassword,
))
assert.NoError(t, err)
_, err = underscoreDB.Exec(fmt.Sprintf(
"GRANT ALL PRIVILEGES ON `%s`.* TO '%s'@'%%'",
underscoreDbName,
allPrivUsername,
))
assert.NoError(t, err)
_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(
fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", allPrivUsername),
)
}()
mysqlModel := &MysqlDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Username: allPrivUsername,
Password: allPrivPassword,
Database: &underscoreDbName,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mysqlModel.TestConnection(logger, nil, uuid.New())
assert.NoError(t, err)
assert.NotEmpty(t, mysqlModel.Privileges)
assert.Contains(t, mysqlModel.Privileges, "SELECT")
assert.Contains(t, mysqlModel.Privileges, "SHOW VIEW")
})
}
}
type MysqlContainer struct {
Host string
Port int
@@ -694,7 +800,7 @@ func connectToMysqlContainer(
}
dbName := "testdb"
host := "127.0.0.1"
host := config.GetEnv().TestLocalhost
username := "root"
password := "rootpassword"

View File

@@ -13,6 +13,7 @@ import (
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"gorm.io/gorm"
)
@@ -394,10 +395,13 @@ func (p *PostgresqlDatabase) IsUserReadOnly(
//
// This method performs the following operations atomically in a single transaction:
// 1. Creates a PostgreSQL user with a UUID-based password
// 2. Grants CONNECT privilege on the database
// 3. Grants USAGE on all non-system schemas
// 4. Grants SELECT on all existing tables and sequences
// 5. Sets default privileges for future tables and sequences
// 2. Revokes CREATE privilege on public schema from PUBLIC role
// 3. Grants CONNECT privilege on the database
// 4. Discovers all user-created schemas
// 5. Grants USAGE on all non-system schemas
// 6. Grants SELECT on all existing tables and sequences
// 7. Sets default privileges for future tables and sequences
// 8. Verifies user creation before committing
//
// Security features:
// - Username format: "databasus-{8-char-uuid}" for uniqueness
@@ -487,33 +491,56 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
return "", "", fmt.Errorf("failed to create user: %w", err)
}
// Step 1.5: Revoke CREATE privilege from PUBLIC role on public schema
// Step 2: Check if public schema exists and revoke CREATE privilege if it does
// This is necessary because all PostgreSQL users inherit CREATE privilege on the
// public schema through the PUBLIC role. This is a one-time operation that affects
// the entire database, making it more secure by default.
// Note: This only affects the public schema; other schemas are unaffected.
_, err = tx.Exec(ctx, `REVOKE CREATE ON SCHEMA public FROM PUBLIC`)
if err != nil {
logger.Error("Failed to revoke CREATE on public from PUBLIC", "error", err)
if !strings.Contains(err.Error(), "schema \"public\" does not exist") &&
!strings.Contains(err.Error(), "permission denied") {
return "", "", fmt.Errorf("failed to revoke CREATE from PUBLIC: %w", err)
}
}
// Now revoke from the specific user as well (belt and suspenders)
_, err = tx.Exec(ctx, fmt.Sprintf(`REVOKE CREATE ON SCHEMA public FROM "%s"`, baseUsername))
if err != nil {
logger.Error(
"Failed to revoke CREATE on public schema from user",
"error",
err,
"username",
baseUsername,
var publicSchemaExists bool
err = tx.QueryRow(ctx, `
SELECT EXISTS(
SELECT 1 FROM information_schema.schemata
WHERE schema_name = 'public'
)
`).Scan(&publicSchemaExists)
if err != nil {
return "", "", fmt.Errorf("failed to check if public schema exists: %w", err)
}
// Step 2: Grant database connection privilege and revoke TEMP
if publicSchemaExists {
// Revoke CREATE from PUBLIC role (affects all users)
_, err = tx.Exec(ctx, `REVOKE CREATE ON SCHEMA public FROM PUBLIC`)
if err != nil {
if strings.Contains(err.Error(), "permission denied") {
logger.Warn(
"Failed to revoke CREATE on public from PUBLIC (permission denied)",
"error",
err,
)
} else {
return "", "", fmt.Errorf("failed to revoke CREATE from PUBLIC on existing public schema: %w", err)
}
}
// Now revoke from the specific user as well (belt and suspenders)
_, err = tx.Exec(
ctx,
fmt.Sprintf(`REVOKE CREATE ON SCHEMA public FROM "%s"`, baseUsername),
)
if err != nil {
logger.Warn(
"Failed to revoke CREATE on public schema from user",
"error",
err,
"username",
baseUsername,
)
}
} else {
logger.Info("Public schema does not exist, skipping CREATE privilege revocation")
}
// Step 3: Grant database connection privilege and revoke TEMP
_, err = tx.Exec(
ctx,
fmt.Sprintf(`GRANT CONNECT ON DATABASE "%s" TO "%s"`, *p.Database, baseUsername),
@@ -537,7 +564,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
logger.Warn("Failed to revoke TEMP privilege", "error", err, "username", baseUsername)
}
// Step 3: Discover all user-created schemas
// Step 4: Discover all user-created schemas
rows, err := tx.Query(ctx, `
SELECT schema_name
FROM information_schema.schemata
@@ -562,7 +589,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
return "", "", fmt.Errorf("error iterating schemas: %w", err)
}
// Step 4: Grant USAGE on each schema and explicitly prevent CREATE
// Step 5: Grant USAGE on each schema and explicitly prevent CREATE
for _, schema := range schemas {
// Revoke CREATE specifically (handles inheritance from PUBLIC role)
_, err = tx.Exec(
@@ -591,7 +618,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
}
}
// Step 5: Grant SELECT on ALL existing tables and sequences
// Step 6: Grant SELECT on ALL existing tables and sequences
grantSelectSQL := fmt.Sprintf(`
DO $$
DECLARE
@@ -613,7 +640,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
return "", "", fmt.Errorf("failed to grant select on tables: %w", err)
}
// Step 6: Set default privileges for FUTURE tables and sequences
// Step 7: Set default privileges for FUTURE tables and sequences
defaultPrivilegesSQL := fmt.Sprintf(`
DO $$
DECLARE
@@ -635,7 +662,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
return "", "", fmt.Errorf("failed to set default privileges: %w", err)
}
// Step 7: Verify user creation before committing
// Step 8: Verify user creation before committing
var verifyUsername string
err = tx.QueryRow(ctx, fmt.Sprintf(`SELECT rolname FROM pg_roles WHERE rolname = '%s'`, baseUsername)).
Scan(&verifyUsername)
@@ -851,7 +878,15 @@ func checkBackupPermissions(
}
if err != nil {
return fmt.Errorf("cannot check SELECT privileges: %w", err)
// If the user doesn't have USAGE on the schema, has_table_privilege will fail
// with "permission denied for schema". This means they definitely don't have
// SELECT privileges, so treat this as missing permissions rather than an error.
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) && pgErr.Code == "42501" { // insufficient_privilege
selectableTableCount = 0
} else {
return fmt.Errorf("cannot check SELECT privileges: %w", err)
}
}
if selectableTableCount == 0 {
missingPrivileges = append(missingPrivileges, "SELECT on tables")

View File

@@ -599,6 +599,10 @@ func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) {
}
func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
if config.GetEnv().IsSkipExternalResourcesTests {
t.Skip("Skipping Supabase test: IS_SKIP_EXTERNAL_RESOURCES_TESTS is true")
}
env := config.GetEnv()
if env.TestSupabaseHost == "" {
@@ -705,6 +709,344 @@ func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
assert.Contains(t, err.Error(), "permission denied")
}
func Test_CreateReadOnlyUser_WithPublicSchema_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version string
port string
}{
{"PostgreSQL 12", "12", env.TestPostgres12Port},
{"PostgreSQL 13", "13", env.TestPostgres13Port},
{"PostgreSQL 14", "14", env.TestPostgres14Port},
{"PostgreSQL 15", "15", env.TestPostgres15Port},
{"PostgreSQL 16", "16", env.TestPostgres16Port},
{"PostgreSQL 17", "17", env.TestPostgres17Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToPostgresContainer(t, tc.port)
defer container.DB.Close()
_, err := container.DB.Exec(`
DROP TABLE IF EXISTS public_schema_test CASCADE;
CREATE TABLE public_schema_test (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO public_schema_test (data) VALUES ('test1'), ('test2');
`)
assert.NoError(t, err)
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.NotEmpty(t, username)
assert.NotEmpty(t, password)
assert.True(t, strings.HasPrefix(username, "databasus-"))
readOnlyModel := &PostgresqlDatabase{
Version: pgModel.Version,
Host: pgModel.Host,
Port: pgModel.Port,
Username: username,
Password: password,
Database: pgModel.Database,
IsHttps: false,
}
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
ctx,
logger,
nil,
uuid.New(),
)
assert.NoError(t, err)
assert.True(t, isReadOnly, "User should be read-only")
assert.Empty(t, privileges, "Read-only user should have no write privileges")
readOnlyDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
username,
password,
container.Database,
)
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var count int
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM public_schema_test")
assert.NoError(t, err)
assert.Equal(t, 2, count)
_, err = readOnlyConn.Exec(
"INSERT INTO public_schema_test (data) VALUES ('should-fail')",
)
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec("CREATE TABLE public.hack_table (id INT)")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
}
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
assert.NoError(t, err)
})
}
}
func Test_CreateReadOnlyUser_WithoutPublicSchema_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version string
port string
}{
{"PostgreSQL 12", "12", env.TestPostgres12Port},
{"PostgreSQL 13", "13", env.TestPostgres13Port},
{"PostgreSQL 14", "14", env.TestPostgres14Port},
{"PostgreSQL 15", "15", env.TestPostgres15Port},
{"PostgreSQL 16", "16", env.TestPostgres16Port},
{"PostgreSQL 17", "17", env.TestPostgres17Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToPostgresContainer(t, tc.port)
defer container.DB.Close()
_, err := container.DB.Exec(`
DROP SCHEMA IF EXISTS public CASCADE;
DROP SCHEMA IF EXISTS app_schema CASCADE;
DROP SCHEMA IF EXISTS data_schema CASCADE;
CREATE SCHEMA app_schema;
CREATE SCHEMA data_schema;
CREATE TABLE app_schema.users (
id SERIAL PRIMARY KEY,
username TEXT NOT NULL
);
CREATE TABLE data_schema.records (
id SERIAL PRIMARY KEY,
info TEXT NOT NULL
);
INSERT INTO app_schema.users (username) VALUES ('user1'), ('user2');
INSERT INTO data_schema.records (info) VALUES ('record1'), ('record2');
`)
assert.NoError(t, err)
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err, "CreateReadOnlyUser should succeed without public schema")
assert.NotEmpty(t, username)
assert.NotEmpty(t, password)
assert.True(t, strings.HasPrefix(username, "databasus-"))
readOnlyModel := &PostgresqlDatabase{
Version: pgModel.Version,
Host: pgModel.Host,
Port: pgModel.Port,
Username: username,
Password: password,
Database: pgModel.Database,
IsHttps: false,
}
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
ctx,
logger,
nil,
uuid.New(),
)
assert.NoError(t, err)
assert.True(t, isReadOnly, "User should be read-only")
assert.Empty(t, privileges, "Read-only user should have no write privileges")
readOnlyDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
username,
password,
container.Database,
)
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var userCount int
err = readOnlyConn.Get(&userCount, "SELECT COUNT(*) FROM app_schema.users")
assert.NoError(t, err)
assert.Equal(t, 2, userCount)
var recordCount int
err = readOnlyConn.Get(&recordCount, "SELECT COUNT(*) FROM data_schema.records")
assert.NoError(t, err)
assert.Equal(t, 2, recordCount)
_, err = readOnlyConn.Exec(
"INSERT INTO app_schema.users (username) VALUES ('should-fail')",
)
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec("CREATE TABLE app_schema.hack_table (id INT)")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec("CREATE TABLE data_schema.hack_table (id INT)")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
}
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
assert.NoError(t, err)
_, err = container.DB.Exec(`
DROP SCHEMA IF EXISTS app_schema CASCADE;
DROP SCHEMA IF EXISTS data_schema CASCADE;
CREATE SCHEMA IF NOT EXISTS public;
`)
assert.NoError(t, err)
})
}
}
func Test_CreateReadOnlyUser_PublicSchemaExistsButNoPermissions_ReturnsError(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version string
port string
}{
{"PostgreSQL 12", "12", env.TestPostgres12Port},
{"PostgreSQL 13", "13", env.TestPostgres13Port},
{"PostgreSQL 14", "14", env.TestPostgres14Port},
{"PostgreSQL 15", "15", env.TestPostgres15Port},
{"PostgreSQL 16", "16", env.TestPostgres16Port},
{"PostgreSQL 17", "17", env.TestPostgres17Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToPostgresContainer(t, tc.port)
defer container.DB.Close()
limitedAdminUsername := fmt.Sprintf("limited_admin_%s", uuid.New().String()[:8])
limitedAdminPassword := "limited_password_123"
_, err := container.DB.Exec(`
CREATE SCHEMA IF NOT EXISTS public;
DROP TABLE IF EXISTS public.permission_test_table CASCADE;
CREATE TABLE public.permission_test_table (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO public.permission_test_table (data) VALUES ('test1');
`)
assert.NoError(t, err)
_, err = container.DB.Exec(`GRANT CREATE ON SCHEMA public TO PUBLIC`)
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN CREATEROLE`,
limitedAdminUsername,
limitedAdminPassword,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
container.Database,
limitedAdminUsername,
))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(
fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, limitedAdminUsername),
)
_, _ = container.DB.Exec(
fmt.Sprintf(`DROP USER IF EXISTS "%s"`, limitedAdminUsername),
)
_, _ = container.DB.Exec(`REVOKE CREATE ON SCHEMA public FROM PUBLIC`)
}()
limitedAdminDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
limitedAdminUsername,
limitedAdminPassword,
container.Database,
)
limitedAdminConn, err := sqlx.Connect("postgres", limitedAdminDSN)
assert.NoError(t, err)
defer limitedAdminConn.Close()
pgModel := &PostgresqlDatabase{
Version: tools.GetPostgresqlVersionEnum(tc.version),
Host: container.Host,
Port: container.Port,
Username: limitedAdminUsername,
Password: limitedAdminPassword,
Database: &container.Database,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.Error(
t,
err,
"CreateReadOnlyUser should fail when admin lacks permissions to secure public schema",
)
if err != nil {
errorMsg := err.Error()
hasExpectedError := strings.Contains(
errorMsg,
"failed to revoke CREATE from PUBLIC on existing public schema",
) ||
strings.Contains(errorMsg, "permission denied for schema public") ||
strings.Contains(errorMsg, "failed to grant")
assert.True(
t,
hasExpectedError,
"Error should indicate permission issues with public schema, got: %s",
errorMsg,
)
}
assert.Empty(t, username)
assert.Empty(t, password)
})
}
}
func Test_Validate_WhenLocalhostAndDatabasus_ReturnsError(t *testing.T) {
testCases := []struct {
name string
@@ -981,7 +1323,7 @@ func connectToPostgresContainer(t *testing.T, port string) *PostgresContainer {
dbName := "testdb"
password := "testpassword"
username := "testuser"
host := "localhost"
host := config.GetEnv().TestLocalhost
portInt, err := strconv.Atoi(port)
assert.NoError(t, err)

View File

@@ -1,6 +1,9 @@
package databases
import (
"sync"
"sync/atomic"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/notifiers"
users_services "databasus-backend/internal/features/users/services"
@@ -37,7 +40,22 @@ func GetDatabaseController() *DatabaseController {
return databaseController
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService)
notifiers.GetNotifierService().SetNotifierDatabaseCounter(databaseService)
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService)
notifiers.GetNotifierService().SetNotifierDatabaseCounter(databaseService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}

View File

@@ -1,6 +1,7 @@
package databases
import (
"context"
"databasus-backend/internal/features/databases/databases/mariadb"
"databasus-backend/internal/features/databases/databases/mongodb"
"databasus-backend/internal/features/databases/databases/mysql"
@@ -84,6 +85,25 @@ func (d *Database) TestConnection(
return d.getSpecificDatabase().TestConnection(logger, encryptor, d.ID)
}
func (d *Database) IsUserReadOnly(
ctx context.Context,
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
) (bool, []string, error) {
switch d.Type {
case DatabaseTypePostgres:
return d.Postgresql.IsUserReadOnly(ctx, logger, encryptor, d.ID)
case DatabaseTypeMysql:
return d.Mysql.IsUserReadOnly(ctx, logger, encryptor, d.ID)
case DatabaseTypeMariadb:
return d.Mariadb.IsUserReadOnly(ctx, logger, encryptor, d.ID)
case DatabaseTypeMongodb:
return d.Mongodb.IsUserReadOnly(ctx, logger, encryptor, d.ID)
default:
return false, nil, errors.New("read-only check not supported for this database type")
}
}
func (d *Database) HideSensitiveData() {
d.getSpecificDatabase().HideSensitiveData()
}

View File

@@ -7,6 +7,7 @@ import (
"log/slog"
"time"
"databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/databases/databases/mariadb"
"databasus-backend/internal/features/databases/databases/mongodb"
@@ -86,6 +87,23 @@ func (s *DatabaseService) CreateDatabase(
return nil, fmt.Errorf("failed to auto-detect database data: %w", err)
}
if config.GetEnv().IsCloud {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
isReadOnly, permissions, err := database.IsUserReadOnly(ctx, s.logger, s.fieldEncryptor)
if err != nil {
return nil, fmt.Errorf("failed to verify user permissions: %w", err)
}
if !isReadOnly {
return nil, fmt.Errorf(
"in cloud mode, only read-only database users are allowed (user has permissions: %v)",
permissions,
)
}
}
if err := database.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
return nil, fmt.Errorf("failed to encrypt sensitive fields: %w", err)
}
@@ -153,6 +171,27 @@ func (s *DatabaseService) UpdateDatabase(
return fmt.Errorf("failed to auto-detect database data: %w", err)
}
if config.GetEnv().IsCloud {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
isReadOnly, permissions, err := existingDatabase.IsUserReadOnly(
ctx,
s.logger,
s.fieldEncryptor,
)
if err != nil {
return fmt.Errorf("failed to verify user permissions: %w", err)
}
if !isReadOnly {
return fmt.Errorf(
"in cloud mode, only read-only database users are allowed (user has permissions: %v)",
permissions,
)
}
}
if err := existingDatabase.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
return fmt.Errorf("failed to encrypt sensitive fields: %w", err)
}
@@ -649,38 +688,7 @@ func (s *DatabaseService) IsUserReadOnly(
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
switch usingDatabase.Type {
case DatabaseTypePostgres:
return usingDatabase.Postgresql.IsUserReadOnly(
ctx,
s.logger,
s.fieldEncryptor,
usingDatabase.ID,
)
case DatabaseTypeMysql:
return usingDatabase.Mysql.IsUserReadOnly(
ctx,
s.logger,
s.fieldEncryptor,
usingDatabase.ID,
)
case DatabaseTypeMariadb:
return usingDatabase.Mariadb.IsUserReadOnly(
ctx,
s.logger,
s.fieldEncryptor,
usingDatabase.ID,
)
case DatabaseTypeMongodb:
return usingDatabase.Mongodb.IsUserReadOnly(
ctx,
s.logger,
s.fieldEncryptor,
usingDatabase.ID,
)
default:
return false, nil, errors.New("read-only check not supported for this database type")
}
return usingDatabase.IsUserReadOnly(ctx, s.logger, s.fieldEncryptor)
}
func (s *DatabaseService) CreateReadOnlyUser(

View File

@@ -10,6 +10,7 @@ import (
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
"databasus-backend/internal/storage"
"databasus-backend/internal/util/tools"
"github.com/google/uuid"
@@ -25,7 +26,7 @@ func GetTestPostgresConfig() *postgresql.PostgresqlDatabase {
testDbName := "testdb"
return &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",
@@ -48,7 +49,7 @@ func GetTestMariadbConfig() *mariadb.MariadbDatabase {
testDbName := "testdb"
return &mariadb.MariadbDatabase{
Version: tools.MariadbVersion1011,
Host: "localhost",
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",
@@ -69,7 +70,7 @@ func GetTestMongodbConfig() *mongodb.MongodbDatabase {
return &mongodb.MongodbDatabase{
Version: tools.MongodbVersion7,
Host: "localhost",
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "root",
Password: "rootpassword",
@@ -104,6 +105,19 @@ func CreateTestDatabase(
}
func RemoveTestDatabase(database *Database) {
// Delete backups and backup configs associated with this database
// We hardcode SQL here because we cannot call backups feature due to DI inversion
// (databases package cannot import backups package as backups already imports databases)
db := storage.GetDb()
if err := db.Exec("DELETE FROM backups WHERE database_id = ?", database.ID).Error; err != nil {
panic(fmt.Sprintf("failed to delete backups: %v", err))
}
if err := db.Exec("DELETE FROM backup_configs WHERE database_id = ?", database.ID).Error; err != nil {
panic(fmt.Sprintf("failed to delete backup config: %v", err))
}
err := databaseRepository.Delete(database.ID)
if err != nil {
panic(err)

View File

@@ -12,6 +12,15 @@ import (
type DiskService struct{}
func (s *DiskService) GetDiskUsage() (*DiskUsage, error) {
if config.GetEnv().IsCloud {
return &DiskUsage{
Platform: PlatformLinux,
TotalSpaceBytes: 100,
UsedSpaceBytes: 0,
FreeSpaceBytes: 100,
}, nil
}
platform := s.detectPlatform()
var path string

View File

@@ -0,0 +1,22 @@
package email
import (
"databasus-backend/internal/config"
"databasus-backend/internal/util/logger"
)
var env = config.GetEnv()
var log = logger.GetLogger()
var emailSMTPSender = &EmailSMTPSender{
log,
env.SMTPHost,
env.SMTPPort,
env.SMTPUser,
env.SMTPPassword,
env.SMTPHost != "" && env.SMTPPort != 0,
}
func GetEmailSMTPSender() *EmailSMTPSender {
return emailSMTPSender
}

View File

@@ -0,0 +1,245 @@
package email
import (
"crypto/tls"
"fmt"
"log/slog"
"mime"
"net"
"net/smtp"
"time"
)
const (
ImplicitTLSPort = 465
DefaultTimeout = 5 * time.Second
DefaultHelloName = "localhost"
MIMETypeHTML = "text/html"
MIMECharsetUTF8 = "UTF-8"
)
type EmailSMTPSender struct {
logger *slog.Logger
smtpHost string
smtpPort int
smtpUser string
smtpPassword string
isConfigured bool
}
func (s *EmailSMTPSender) SendEmail(to, subject, body string) error {
if !s.isConfigured {
s.logger.Warn("Skipping email send, SMTP not initialized", "to", to, "subject", subject)
return nil
}
from := s.smtpUser
if from == "" {
from = "noreply@" + s.smtpHost
}
emailContent := s.buildEmailContent(to, subject, body, from)
isAuthRequired := s.smtpUser != "" && s.smtpPassword != ""
if s.smtpPort == ImplicitTLSPort {
return s.sendImplicitTLS(to, from, emailContent, isAuthRequired)
}
return s.sendStartTLS(to, from, emailContent, isAuthRequired)
}
func (s *EmailSMTPSender) buildEmailContent(to, subject, body, from string) []byte {
// Encode Subject header using RFC 2047 to avoid SMTPUTF8 requirement
encodedSubject := encodeRFC2047(subject)
subjectHeader := fmt.Sprintf("Subject: %s\r\n", encodedSubject)
dateHeader := fmt.Sprintf("Date: %s\r\n", time.Now().UTC().Format(time.RFC1123Z))
mimeHeaders := fmt.Sprintf(
"MIME-version: 1.0;\nContent-Type: %s; charset=\"%s\";\n\n",
MIMETypeHTML,
MIMECharsetUTF8,
)
// Encode From header display name if it contains non-ASCII
encodedFrom := encodeRFC2047(from)
fromHeader := fmt.Sprintf("From: %s\r\n", encodedFrom)
toHeader := fmt.Sprintf("To: %s\r\n", to)
return []byte(fromHeader + toHeader + subjectHeader + dateHeader + mimeHeaders + body)
}
func (s *EmailSMTPSender) sendImplicitTLS(
to, from string,
emailContent []byte,
isAuthRequired bool,
) error {
createClient := func() (*smtp.Client, func(), error) {
return s.createImplicitTLSClient()
}
client, cleanup, err := s.authenticateWithRetry(createClient, isAuthRequired)
if err != nil {
return err
}
defer cleanup()
return s.sendEmail(client, to, from, emailContent)
}
func (s *EmailSMTPSender) sendStartTLS(
to, from string,
emailContent []byte,
isAuthRequired bool,
) error {
createClient := func() (*smtp.Client, func(), error) {
return s.createStartTLSClient()
}
client, cleanup, err := s.authenticateWithRetry(createClient, isAuthRequired)
if err != nil {
return err
}
defer cleanup()
return s.sendEmail(client, to, from, emailContent)
}
func (s *EmailSMTPSender) createImplicitTLSClient() (*smtp.Client, func(), error) {
addr := net.JoinHostPort(s.smtpHost, fmt.Sprintf("%d", s.smtpPort))
tlsConfig := &tls.Config{ServerName: s.smtpHost}
dialer := &net.Dialer{Timeout: DefaultTimeout}
conn, err := tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
if err != nil {
return nil, nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
}
client, err := smtp.NewClient(conn, s.smtpHost)
if err != nil {
_ = conn.Close()
return nil, nil, fmt.Errorf("failed to create SMTP client: %w", err)
}
return client, func() { _ = client.Quit() }, nil
}
func (s *EmailSMTPSender) createStartTLSClient() (*smtp.Client, func(), error) {
addr := net.JoinHostPort(s.smtpHost, fmt.Sprintf("%d", s.smtpPort))
dialer := &net.Dialer{Timeout: DefaultTimeout}
conn, err := dialer.Dial("tcp", addr)
if err != nil {
return nil, nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
}
client, err := smtp.NewClient(conn, s.smtpHost)
if err != nil {
_ = conn.Close()
return nil, nil, fmt.Errorf("failed to create SMTP client: %w", err)
}
if err := client.Hello(DefaultHelloName); err != nil {
_ = client.Quit()
_ = conn.Close()
return nil, nil, fmt.Errorf("SMTP hello failed: %w", err)
}
if ok, _ := client.Extension("STARTTLS"); ok {
if err := client.StartTLS(&tls.Config{ServerName: s.smtpHost}); err != nil {
_ = client.Quit()
_ = conn.Close()
return nil, nil, fmt.Errorf("STARTTLS failed: %w", err)
}
}
return client, func() { _ = client.Quit() }, nil
}
func (s *EmailSMTPSender) authenticateWithRetry(
createClient func() (*smtp.Client, func(), error),
isAuthRequired bool,
) (*smtp.Client, func(), error) {
client, cleanup, err := createClient()
if err != nil {
return nil, nil, err
}
if !isAuthRequired {
return client, cleanup, nil
}
// Try PLAIN auth first
plainAuth := smtp.PlainAuth("", s.smtpUser, s.smtpPassword, s.smtpHost)
if err := client.Auth(plainAuth); err == nil {
return client, cleanup, nil
}
// PLAIN auth failed, connection may be closed - recreate and try LOGIN auth
cleanup()
client, cleanup, err = createClient()
if err != nil {
return nil, nil, err
}
loginAuth := &loginAuth{username: s.smtpUser, password: s.smtpPassword}
if err := client.Auth(loginAuth); err != nil {
cleanup()
return nil, nil, fmt.Errorf("SMTP authentication failed: %w", err)
}
return client, cleanup, nil
}
func (s *EmailSMTPSender) sendEmail(client *smtp.Client, to, from string, content []byte) error {
if err := client.Mail(from); err != nil {
return fmt.Errorf("failed to set sender: %w", err)
}
if err := client.Rcpt(to); err != nil {
return fmt.Errorf("failed to set recipient: %w", err)
}
writer, err := client.Data()
if err != nil {
return fmt.Errorf("failed to get data writer: %w", err)
}
if _, err = writer.Write(content); err != nil {
return fmt.Errorf("failed to write email content: %w", err)
}
if err = writer.Close(); err != nil {
return fmt.Errorf("failed to close data writer: %w", err)
}
return nil
}
func encodeRFC2047(s string) string {
return mime.QEncoding.Encode("UTF-8", s)
}
type loginAuth struct {
username string
password string
}
func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) {
return "LOGIN", []byte{}, nil
}
func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
if more {
switch string(fromServer) {
case "Username:", "User Name\x00":
return []byte(a.username), nil
case "Password:", "Password\x00":
return []byte(a.password), nil
default:
return []byte(a.username), nil
}
}
return nil, nil
}

View File

@@ -2,30 +2,47 @@ package healthcheck_attempt
import (
"context"
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
)
type HealthcheckAttemptBackgroundService struct {
healthcheckConfigService *healthcheck_config.HealthcheckConfigService
checkDatabaseHealthUseCase *CheckDatabaseHealthUseCase
logger *slog.Logger
runOnce sync.Once
hasRun atomic.Bool
}
func (s *HealthcheckAttemptBackgroundService) Run(ctx context.Context) {
// first healthcheck immediately
s.checkDatabases()
wasAlreadyRun := s.hasRun.Load()
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.checkDatabases()
s.runOnce.Do(func() {
s.hasRun.Store(true)
// first healthcheck immediately
s.checkDatabases()
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.checkDatabases()
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
}

View File

@@ -144,6 +144,10 @@ func Test_GetAttemptsByDatabase_PermissionsEnforced(t *testing.T) {
)
assert.Contains(t, string(testResp.Body), "forbidden")
}
// Cleanup
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -181,6 +185,10 @@ func Test_GetAttemptsByDatabase_FiltersByAfterDate(t *testing.T) {
for _, attempt := range response {
assert.True(t, attempt.CreatedAt.After(afterDate) || attempt.CreatedAt.Equal(afterDate))
}
// Cleanup
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_GetAttemptsByDatabase_ReturnsEmptyListForNewDatabase(t *testing.T) {
@@ -201,6 +209,10 @@ func Test_GetAttemptsByDatabase_ReturnsEmptyListForNewDatabase(t *testing.T) {
)
assert.Equal(t, 0, len(response))
// Cleanup
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func createTestDatabaseViaAPI(

View File

@@ -1,6 +1,9 @@
package healthcheck_attempt
import (
"sync"
"sync/atomic"
"databasus-backend/internal/features/databases"
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
"databasus-backend/internal/features/notifiers"
@@ -22,9 +25,11 @@ var checkDatabaseHealthUseCase = &CheckDatabaseHealthUseCase{
}
var healthcheckAttemptBackgroundService = &HealthcheckAttemptBackgroundService{
healthcheck_config.GetHealthcheckConfigService(),
checkDatabaseHealthUseCase,
logger.GetLogger(),
healthcheckConfigService: healthcheck_config.GetHealthcheckConfigService(),
checkDatabaseHealthUseCase: checkDatabaseHealthUseCase,
logger: logger.GetLogger(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
var healthcheckAttemptController = &HealthcheckAttemptController{
healthcheckAttemptService,

View File

@@ -130,6 +130,10 @@ func Test_SaveHealthcheckConfig_PermissionsEnforced(t *testing.T) {
)
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
// Cleanup
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -162,6 +166,10 @@ func Test_SaveHealthcheckConfig_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t
)
assert.Contains(t, string(testResp.Body), "insufficient permissions")
// Cleanup
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_GetHealthcheckConfig_PermissionsEnforced(t *testing.T) {
@@ -268,6 +276,10 @@ func Test_GetHealthcheckConfig_PermissionsEnforced(t *testing.T) {
)
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
// Cleanup
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -295,6 +307,10 @@ func Test_GetHealthcheckConfig_ReturnsDefaultConfigForNewDatabase(t *testing.T)
assert.Equal(t, 1, response.IntervalMinutes)
assert.Equal(t, 3, response.AttemptsBeforeConcideredAsDown)
assert.Equal(t, 7, response.StoreAttemptsDays)
// Cleanup
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func createTestDatabaseViaAPI(

View File

@@ -1,6 +1,9 @@
package healthcheck_config
import (
"sync"
"sync/atomic"
"databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/databases"
workspaces_services "databasus-backend/internal/features/workspaces/services"
@@ -27,8 +30,23 @@ func GetHealthcheckConfigController() *HealthcheckConfigController {
return healthcheckConfigController
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
databases.
GetDatabaseService().
AddDbCreationListener(healthcheckConfigService)
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
databases.
GetDatabaseService().
AddDbCreationListener(healthcheckConfigService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}

View File

@@ -1,6 +1,9 @@
package notifiers
import (
"sync"
"sync/atomic"
audit_logs "databasus-backend/internal/features/audit_logs"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/encryption"
@@ -32,6 +35,22 @@ func GetNotifierService() *NotifierService {
func GetNotifierRepository() *NotifierRepository {
return notifierRepository
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService)
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}

View File

@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"log/slog"
"mime"
"net"
"net/smtp"
"time"
@@ -115,16 +116,35 @@ func (e *EmailNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor
return nil
}
// encodeRFC2047 encodes a string using RFC 2047 MIME encoding for email headers
// This ensures compatibility with SMTP servers that don't support SMTPUTF8
func encodeRFC2047(s string) string {
// mime.QEncoding handles UTF-8 → =?UTF-8?Q?...?= encoding
// This allows non-ASCII characters (emojis, accents, etc.) in email headers
// while maintaining compatibility with all SMTP servers
return mime.QEncoding.Encode("UTF-8", s)
}
func (e *EmailNotifier) buildEmailContent(heading, message, from string) []byte {
subject := fmt.Sprintf("Subject: %s\r\n", heading)
mime := fmt.Sprintf(
// Encode Subject header using RFC 2047 to avoid SMTPUTF8 requirement
// This ensures compatibility with SMTP servers that don't support SMTPUTF8
encodedSubject := encodeRFC2047(heading)
subject := fmt.Sprintf("Subject: %s\r\n", encodedSubject)
dateHeader := fmt.Sprintf("Date: %s\r\n", time.Now().UTC().Format(time.RFC1123Z))
mimeHeaders := fmt.Sprintf(
"MIME-version: 1.0;\nContent-Type: %s; charset=\"%s\";\n\n",
MIMETypeHTML,
MIMECharsetUTF8,
)
fromHeader := fmt.Sprintf("From: %s\r\n", from)
// Encode From header display name if it contains non-ASCII
encodedFrom := encodeRFC2047(from)
fromHeader := fmt.Sprintf("From: %s\r\n", encodedFrom)
toHeader := fmt.Sprintf("To: %s\r\n", e.TargetEmail)
return []byte(fromHeader + toHeader + subject + mime + message)
return []byte(fromHeader + toHeader + subject + dateHeader + mimeHeaders + message)
}
func (e *EmailNotifier) sendImplicitTLS(

View File

@@ -0,0 +1,20 @@
package plans
import (
"databasus-backend/internal/util/logger"
)
var databasePlanRepository = &DatabasePlanRepository{}
var databasePlanService = &DatabasePlanService{
databasePlanRepository,
logger.GetLogger(),
}
func GetDatabasePlanService() *DatabasePlanService {
return databasePlanService
}
func GetDatabasePlanRepository() *DatabasePlanRepository {
return databasePlanRepository
}

View File

@@ -0,0 +1,19 @@
package plans
import (
"databasus-backend/internal/util/period"
"github.com/google/uuid"
)
type DatabasePlan struct {
DatabaseID uuid.UUID `json:"databaseId" gorm:"column:database_id;type:uuid;primaryKey;not null"`
MaxBackupSizeMB int64 `json:"maxBackupSizeMb" gorm:"column:max_backup_size_mb;type:int;not null"`
MaxBackupsTotalSizeMB int64 `json:"maxBackupsTotalSizeMb" gorm:"column:max_backups_total_size_mb;type:int;not null"`
MaxStoragePeriod period.Period `json:"maxStoragePeriod" gorm:"column:max_storage_period;type:text;not null"`
}
func (p *DatabasePlan) TableName() string {
return "database_plans"
}

View File

@@ -0,0 +1,27 @@
package plans
import (
"databasus-backend/internal/storage"
"github.com/google/uuid"
)
type DatabasePlanRepository struct{}
func (r *DatabasePlanRepository) GetDatabasePlan(databaseID uuid.UUID) (*DatabasePlan, error) {
var databasePlan DatabasePlan
if err := storage.GetDb().Where("database_id = ?", databaseID).First(&databasePlan).Error; err != nil {
if err.Error() == "record not found" {
return nil, nil
}
return nil, err
}
return &databasePlan, nil
}
func (r *DatabasePlanRepository) CreateDatabasePlan(databasePlan *DatabasePlan) error {
return storage.GetDb().Create(&databasePlan).Error
}

View File

@@ -0,0 +1,67 @@
package plans
import (
"databasus-backend/internal/config"
"databasus-backend/internal/util/period"
"log/slog"
"github.com/google/uuid"
)
type DatabasePlanService struct {
databasePlanRepository *DatabasePlanRepository
logger *slog.Logger
}
func (s *DatabasePlanService) GetDatabasePlan(databaseID uuid.UUID) (*DatabasePlan, error) {
plan, err := s.databasePlanRepository.GetDatabasePlan(databaseID)
if err != nil {
return nil, err
}
if plan == nil {
s.logger.Info("no database plan found, creating default plan", "databaseID", databaseID)
defaultPlan := s.createDefaultDatabasePlan(databaseID)
err := s.databasePlanRepository.CreateDatabasePlan(defaultPlan)
if err != nil {
s.logger.Error("failed to create default database plan", "error", err)
return nil, err
}
return defaultPlan, nil
}
return plan, nil
}
func (s *DatabasePlanService) createDefaultDatabasePlan(databaseID uuid.UUID) *DatabasePlan {
var plan DatabasePlan
isCloud := config.GetEnv().IsCloud
if isCloud {
s.logger.Info("creating default database plan for cloud", "databaseID", databaseID)
// for playground we set limited storages enough to test,
// but not too expensive to provide it for Databasus
plan = DatabasePlan{
DatabaseID: databaseID,
MaxBackupSizeMB: 100, // ~ 1.5GB database
MaxBackupsTotalSizeMB: 4000, // ~ 30 daily backups + 10 manual backups
MaxStoragePeriod: period.PeriodWeek,
}
} else {
s.logger.Info("creating default database plan for self hosted", "databaseID", databaseID)
// by default - everything is unlimited in self hosted mode
plan = DatabasePlan{
DatabaseID: databaseID,
MaxBackupSizeMB: 0,
MaxBackupsTotalSizeMB: 0,
MaxStoragePeriod: period.PeriodForever,
}
}
return &plan
}

View File

@@ -16,6 +16,7 @@ type RestoreController struct {
func (c *RestoreController) RegisterRoutes(router *gin.RouterGroup) {
router.GET("/restores/:backupId", c.GetRestores)
router.POST("/restores/:backupId/restore", c.RestoreBackup)
router.POST("/restores/cancel/:restoreId", c.CancelRestore)
}
// GetRestores
@@ -85,3 +86,33 @@ func (c *RestoreController) RestoreBackup(ctx *gin.Context) {
ctx.JSON(http.StatusOK, gin.H{"message": "restore started successfully"})
}
// CancelRestore
// @Summary Cancel an in-progress restore
// @Description Cancel a restore that is currently in progress
// @Tags restores
// @Param restoreId path string true "Restore ID"
// @Success 204
// @Failure 400
// @Failure 401
// @Router /restores/cancel/{restoreId} [post]
func (c *RestoreController) CancelRestore(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
restoreID, err := uuid.Parse(ctx.Param("restoreId"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid restore ID"})
return
}
if err := c.restoreService.CancelRestore(user, restoreID); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.Status(http.StatusNoContent)
}

View File

@@ -16,22 +16,27 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
env_config "databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/mysql"
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/notifiers"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/restores/restoring"
"databasus-backend/internal/features/storages"
local_storage "databasus-backend/internal/features/storages/models/local"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
users_dto "databasus-backend/internal/features/users/dto"
users_enums "databasus-backend/internal/features/users/enums"
users_services "databasus-backend/internal/features/users/services"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_models "databasus-backend/internal/features/workspaces/models"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
cache_utils "databasus-backend/internal/util/cache"
util_encryption "databasus-backend/internal/util/encryption"
test_utils "databasus-backend/internal/util/testing"
"databasus-backend/internal/util/tools"
@@ -41,8 +46,10 @@ func Test_GetRestores_WhenUserIsWorkspaceMember_RestoresReturned(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
defer cleanupDatabaseWithBackup(database, backup)
var restores []*restores_core.Restore
test_utils.MakeGetRequestAndUnmarshal(
@@ -63,8 +70,10 @@ func Test_GetRestores_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing.T
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
defer cleanupDatabaseWithBackup(database, backup)
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -83,8 +92,10 @@ func Test_GetRestores_WhenUserIsGlobalAdmin_RestoresReturned(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
defer cleanupDatabaseWithBackup(database, backup)
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
@@ -109,13 +120,15 @@ func Test_RestoreBackup_WhenUserIsWorkspaceMember_RestoreInitiated(t *testing.T)
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
defer cleanupDatabaseWithBackup(database, backup)
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Host: env_config.GetEnv().TestLocalhost,
Port: 5432,
Username: "postgres",
Password: "postgres",
@@ -138,15 +151,17 @@ func Test_RestoreBackup_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
defer cleanupDatabaseWithBackup(database, backup)
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Host: env_config.GetEnv().TestLocalhost,
Port: 5432,
Username: "postgres",
Password: "postgres",
@@ -173,13 +188,15 @@ func Test_RestoreBackup_WithIsExcludeExtensions_FlagPassedCorrectly(t *testing.T
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
defer cleanupDatabaseWithBackup(database, backup)
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Host: env_config.GetEnv().TestLocalhost,
Port: 5432,
Username: "postgres",
Password: "postgres",
@@ -207,13 +224,15 @@ func Test_RestoreBackup_AuditLogWritten(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
defer cleanupDatabaseWithBackup(database, backup)
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Host: env_config.GetEnv().TestLocalhost,
Port: 5432,
Username: "postgres",
Password: "postgres",
@@ -291,16 +310,20 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
var database *databases.Database
var backup *backups_core.Backup
var storage *storages.Storage
var request restores_core.RestoreBackupRequest
if tc.dbType == databases.DatabaseTypePostgres {
_, backup = createTestDatabaseWithBackupForRestore(workspace, owner, router)
database, backup = createTestDatabaseWithBackupForRestore(workspace, owner, router)
defer cleanupDatabaseWithBackup(database, backup)
request = restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Host: env_config.GetEnv().TestLocalhost,
Port: 5432,
Username: "postgres",
Password: "postgres",
@@ -314,7 +337,16 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
owner.Token,
router,
)
storage := createTestStorage(workspace.ID)
database = mysqlDB
storage = createTestStorage(workspace.ID)
defer func() {
// Cleanup in dependency order: backup -> database -> storage
cleanupBackup(backup)
databases.RemoveTestDatabase(mysqlDB)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
}()
configService := backups_config.GetBackupConfigService()
config, err := configService.GetBackupConfigByDbId(mysqlDB.ID)
@@ -327,10 +359,11 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
assert.NoError(t, err)
backup = createTestBackup(mysqlDB, owner)
request = restores_core.RestoreBackupRequest{
MysqlDatabase: &mysql.MysqlDatabase{
Version: tools.MysqlVersion80,
Host: "localhost",
Host: env_config.GetEnv().TestLocalhost,
Port: 3306,
Username: "root",
Password: "password",
@@ -370,6 +403,187 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
}
}
func Test_CancelRestore_InProgressRestore_SuccessfullyCancelled(t *testing.T) {
cache_utils.ClearAllCache()
tasks_cancellation.SetupDependencies()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := createTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backupRepo := backups_core.BackupRepository{}
backups, _ := backupRepo.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepo.DeleteByID(backup.ID)
}
restoreRepo := restores_core.RestoreRepository{}
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
for _, restore := range restores {
restoreRepo.DeleteByID(restore.ID)
}
restores, _ = restoreRepo.FindByStatus(restores_core.RestoreStatusCanceled)
for _, restore := range restores {
restoreRepo.DeleteByID(restore.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
cache_utils.ClearAllCache()
}()
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backup := backups.CreateTestBackup(database.ID, storage.ID)
mockUsecase := &restoring.MockBlockingRestoreUsecase{
StartedChan: make(chan bool, 1),
}
restorerNode := restoring.CreateTestRestorerNodeWithUsecase(mockUsecase)
cancelNode := restoring.StartRestorerNodeForTest(t, restorerNode)
defer cancelNode()
time.Sleep(200 * time.Millisecond)
restoreRequest := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: env_config.GetEnv().TestLocalhost,
Port: 5432,
Username: "postgres",
Password: "postgres",
},
}
var restoreResponse map[string]interface{}
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
"Bearer "+user.Token,
restoreRequest,
http.StatusOK,
&restoreResponse,
)
select {
case <-mockUsecase.StartedChan:
t.Log("Restore started and is blocking")
case <-time.After(2 * time.Second):
t.Fatal("Restore did not start within timeout")
}
restoreRepo := &restores_core.RestoreRepository{}
restores, err := restoreRepo.FindByBackupID(backup.ID)
assert.NoError(t, err)
assert.Greater(t, len(restores), 0, "At least one restore should exist")
var restoreID uuid.UUID
for _, r := range restores {
if r.Status == restores_core.RestoreStatusInProgress {
restoreID = r.ID
break
}
}
assert.NotEqual(t, uuid.Nil, restoreID, "Should find an in-progress restore")
resp := test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/restores/cancel/%s", restoreID.String()),
"Bearer "+user.Token,
nil,
http.StatusNoContent,
)
assert.Equal(t, http.StatusNoContent, resp.StatusCode)
deadline := time.Now().UTC().Add(3 * time.Second)
var restore *restores_core.Restore
for time.Now().UTC().Before(deadline) {
restore, err = restoreRepo.FindByID(restoreID)
assert.NoError(t, err)
if restore.Status == restores_core.RestoreStatusCanceled {
break
}
time.Sleep(100 * time.Millisecond)
}
assert.Equal(t, restores_core.RestoreStatusCanceled, restore.Status)
auditLogService := audit_logs.GetAuditLogService()
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
workspace.ID,
&audit_logs.GetAuditLogsRequest{Limit: 100, Offset: 0},
)
assert.NoError(t, err)
foundCancelLog := false
for _, log := range auditLogs.AuditLogs {
if strings.Contains(log.Message, "Restore cancelled") &&
strings.Contains(log.Message, database.Name) {
foundCancelLog = true
break
}
}
assert.True(t, foundCancelLog, "Cancel audit log should be created")
time.Sleep(200 * time.Millisecond)
}
func Test_RestoreBackup_WithParallelRestoreInProgress_ReturnsError(t *testing.T) {
router := createTestRouter()
_, cleanup := SetupMockRestoreNode(t)
defer cleanup()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
defer cleanupDatabaseWithBackup(database, backup)
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: env_config.GetEnv().TestLocalhost,
Port: 5432,
Username: "postgres",
Password: "postgres",
},
}
testResp := test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
"Bearer "+owner.Token,
request,
http.StatusOK,
)
assert.Contains(t, string(testResp.Body), "restore started successfully")
testResp2 := test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
"Bearer "+owner.Token,
request,
http.StatusBadRequest,
)
assert.Contains(t, string(testResp2.Body), "another restore is already in progress")
}
func createTestRouter() *gin.Engine {
return CreateTestRouter()
}
@@ -442,7 +656,7 @@ func createTestMySQLDatabase(
token string,
router *gin.Engine,
) *databases.Database {
env := config.GetEnv()
env := env_config.GetEnv()
portStr := env.TestMysql80Port
if portStr == "" {
portStr = "33080"
@@ -460,7 +674,7 @@ func createTestMySQLDatabase(
Type: databases.DatabaseTypeMysql,
Mysql: &mysql.MysqlDatabase{
Version: tools.MysqlVersion80,
Host: "localhost",
Host: env_config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",
@@ -557,3 +771,22 @@ func createTestBackup(
return backup
}
func cleanupDatabaseWithBackup(database *databases.Database, backup *backups_core.Backup) {
// Clean up in reverse dependency order
cleanupBackup(backup)
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
// Clean up storage last (after database and backup are removed)
configService := backups_config.GetBackupConfigService()
config, err := configService.GetBackupConfigByDbId(database.ID)
if err == nil && config.StorageID != nil {
storages.RemoveTestStorage(*config.StorageID)
}
}
func cleanupBackup(backup *backups_core.Backup) {
repo := &backups_core.BackupRepository{}
repo.DeleteByID(backup.ID)
}

View File

@@ -6,4 +6,5 @@ const (
RestoreStatusInProgress RestoreStatus = "IN_PROGRESS"
RestoreStatusCompleted RestoreStatus = "COMPLETED"
RestoreStatusFailed RestoreStatus = "FAILED"
RestoreStatusCanceled RestoreStatus = "CANCELED"
)

View File

@@ -1,6 +1,8 @@
package restores_core
import (
"context"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -9,6 +11,7 @@ import (
type RestoreBackupUsecase interface {
Execute(
ctx context.Context,
backupConfig *backups_config.BackupConfig,
restore Restore,
originalDB *databases.Database,

View File

@@ -68,6 +68,24 @@ func (r *RestoreRepository) FindByStatus(status RestoreStatus) ([]*Restore, erro
return restores, nil
}
func (r *RestoreRepository) FindInProgressRestoresByDatabaseID(
databaseID uuid.UUID,
) ([]*Restore, error) {
var restores []*Restore
if err := storage.
GetDb().
Preload("Backup").
Joins("JOIN backups ON backups.id = restores.backup_id").
Where("backups.database_id = ? AND restores.status = ?", databaseID, RestoreStatusInProgress).
Order("restores.created_at DESC").
Find(&restores).Error; err != nil {
return nil, err
}
return restores, nil
}
func (r *RestoreRepository) DeleteByID(id uuid.UUID) error {
return storage.GetDb().Delete(&Restore{}, "id = ?", id).Error
}

View File

@@ -1,14 +1,19 @@
package restores
import (
"sync"
"sync/atomic"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/features/backups/backups/backuping"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/disk"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/restores/usecases"
"databasus-backend/internal/features/storages"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
@@ -27,6 +32,7 @@ var restoreService = &RestoreService{
audit_logs.GetAuditLogService(),
encryption.GetFieldEncryptor(),
disk.GetDiskService(),
tasks_cancellation.GetTaskCancelManager(),
}
var restoreController = &RestoreController{
restoreService,
@@ -36,6 +42,22 @@ func GetRestoreController() *RestoreController {
return restoreController
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
backups.GetBackupService().AddBackupRemoveListener(restoreService)
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
backups.GetBackupService().AddBackupRemoveListener(restoreService)
backuping.GetBackupCleaner().AddBackupRemoveListener(restoreService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}

View File

@@ -1,6 +1,8 @@
package restoring
import (
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
@@ -11,6 +13,7 @@ import (
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/restores/usecases"
"databasus-backend/internal/features/storages"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
@@ -19,11 +22,13 @@ import (
var restoreRepository = &restores_core.RestoreRepository{}
var restoreNodesRegistry = &RestoreNodesRegistry{
cache_utils.GetValkeyClient(),
logger.GetLogger(),
cache_utils.DefaultCacheTimeout,
cache_utils.NewPubSubManager(),
cache_utils.NewPubSubManager(),
client: cache_utils.GetValkeyClient(),
logger: logger.GetLogger(),
timeout: cache_utils.DefaultCacheTimeout,
pubsubRestores: cache_utils.NewPubSubManager(),
pubsubCompletions: cache_utils.NewPubSubManager(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
var restoreDatabaseCache = cache_utils.NewCacheUtil[RestoreDatabaseCache](
@@ -31,19 +36,24 @@ var restoreDatabaseCache = cache_utils.NewCacheUtil[RestoreDatabaseCache](
"restore_db:",
)
var restoreCancelManager = tasks_cancellation.GetTaskCancelManager()
var restorerNode = &RestorerNode{
uuid.New(),
databases.GetDatabaseService(),
backups.GetBackupService(),
encryption.GetFieldEncryptor(),
restoreRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
restoreNodesRegistry,
logger.GetLogger(),
usecases.GetRestoreBackupUsecase(),
restoreDatabaseCache,
time.Time{},
nodeID: uuid.New(),
databaseService: databases.GetDatabaseService(),
backupService: backups.GetBackupService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
restoreRepository: restoreRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
restoreNodesRegistry: restoreNodesRegistry,
logger: logger.GetLogger(),
restoreBackupUsecase: usecases.GetRestoreBackupUsecase(),
cacheUtil: restoreDatabaseCache,
restoreCancelManager: restoreCancelManager,
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
var restoresScheduler = &RestoresScheduler{
@@ -58,6 +68,8 @@ var restoresScheduler = &RestoresScheduler{
restorerNode: restorerNode,
cacheUtil: restoreDatabaseCache,
completionSubscriptionID: uuid.Nil,
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
func GetRestoresScheduler() *RestoresScheduler {

View File

@@ -1,6 +1,7 @@
package restoring
import (
"context"
"errors"
backups_core "databasus-backend/internal/features/backups/backups/core"
@@ -13,6 +14,7 @@ import (
type MockSuccessRestoreUsecase struct{}
func (uc *MockSuccessRestoreUsecase) Execute(
ctx context.Context,
backupConfig *backups_config.BackupConfig,
restore restores_core.Restore,
originalDB *databases.Database,
@@ -27,6 +29,7 @@ func (uc *MockSuccessRestoreUsecase) Execute(
type MockFailedRestoreUsecase struct{}
func (uc *MockFailedRestoreUsecase) Execute(
ctx context.Context,
backupConfig *backups_config.BackupConfig,
restore restores_core.Restore,
originalDB *databases.Database,
@@ -44,6 +47,7 @@ type MockCaptureCredentialsRestoreUsecase struct {
}
func (uc *MockCaptureCredentialsRestoreUsecase) Execute(
ctx context.Context,
backupConfig *backups_config.BackupConfig,
restore restores_core.Restore,
originalDB *databases.Database,
@@ -59,3 +63,26 @@ func (uc *MockCaptureCredentialsRestoreUsecase) Execute(
}
return errors.New("mock restore failed")
}
type MockBlockingRestoreUsecase struct {
StartedChan chan bool
}
func (uc *MockBlockingRestoreUsecase) Execute(
ctx context.Context,
backupConfig *backups_config.BackupConfig,
restore restores_core.Restore,
originalDB *databases.Database,
restoringToDB *databases.Database,
backup *backups_core.Backup,
storage *storages.Storage,
isExcludeExtensions bool,
) error {
if uc.StartedChan != nil {
uc.StartedChan <- true
}
<-ctx.Done()
return ctx.Err()
}

View File

@@ -6,6 +6,8 @@ import (
"fmt"
"log/slog"
"strings"
"sync"
"sync/atomic"
"time"
cache_utils "databasus-backend/internal/util/cache"
@@ -47,24 +49,37 @@ type RestoreNodesRegistry struct {
timeout time.Duration
pubsubRestores *cache_utils.PubSubManager
pubsubCompletions *cache_utils.PubSubManager
runOnce sync.Once
hasRun atomic.Bool
}
func (r *RestoreNodesRegistry) Run(ctx context.Context) {
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
}
wasAlreadyRun := r.hasRun.Load()
ticker := time.NewTicker(cleanupTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes", "error", err)
r.runOnce.Do(func() {
r.hasRun.Store(true)
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
}
ticker := time.NewTicker(cleanupTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes", "error", err)
}
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", r))
}
}

View File

@@ -4,6 +4,8 @@ import (
"context"
"encoding/json"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
@@ -594,11 +596,13 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) {
func createTestRegistry() *RestoreNodesRegistry {
return &RestoreNodesRegistry{
cache_utils.GetValkeyClient(),
logger.GetLogger(),
cache_utils.DefaultCacheTimeout,
cache_utils.NewPubSubManager(),
cache_utils.NewPubSubManager(),
client: cache_utils.GetValkeyClient(),
logger: logger.GetLogger(),
timeout: cache_utils.DefaultCacheTimeout,
pubsubRestores: cache_utils.NewPubSubManager(),
pubsubCompletions: cache_utils.NewPubSubManager(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}

View File

@@ -2,8 +2,12 @@ package restoring
import (
"context"
"errors"
"fmt"
"log/slog"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
@@ -14,6 +18,7 @@ import (
"databasus-backend/internal/features/databases"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
cache_utils "databasus-backend/internal/util/cache"
util_encryption "databasus-backend/internal/util/encryption"
)
@@ -36,70 +41,84 @@ type RestorerNode struct {
logger *slog.Logger
restoreBackupUsecase restores_core.RestoreBackupUsecase
cacheUtil *cache_utils.CacheUtil[RestoreDatabaseCache]
restoreCancelManager *tasks_cancellation.TaskCancelManager
lastHeartbeat time.Time
runOnce sync.Once
hasRun atomic.Bool
}
func (n *RestorerNode) Run(ctx context.Context) {
n.lastHeartbeat = time.Now().UTC()
wasAlreadyRun := n.hasRun.Load()
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
n.runOnce.Do(func() {
n.hasRun.Store(true)
restoreNode := RestoreNode{
ID: n.nodeID,
ThroughputMBs: throughputMBs,
}
n.lastHeartbeat = time.Now().UTC()
if err := n.restoreNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), restoreNode); err != nil {
n.logger.Error("Failed to register node in registry", "error", err)
panic(err)
}
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
restoreHandler := func(restoreID uuid.UUID, isCallNotifier bool) {
n.MakeRestore(restoreID)
if err := n.restoreNodesRegistry.PublishRestoreCompletion(n.nodeID, restoreID); err != nil {
n.logger.Error(
"Failed to publish restore completion",
"error",
err,
"restoreID",
restoreID,
)
restoreNode := RestoreNode{
ID: n.nodeID,
ThroughputMBs: throughputMBs,
}
}
err := n.restoreNodesRegistry.SubscribeNodeForRestoresAssignment(
n.nodeID,
restoreHandler,
)
if err != nil {
n.logger.Error("Failed to subscribe to restore assignments", "error", err)
panic(err)
}
defer func() {
if err := n.restoreNodesRegistry.UnsubscribeNodeForRestoresAssignments(); err != nil {
n.logger.Error("Failed to unsubscribe from restore assignments", "error", err)
if err := n.restoreNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), restoreNode); err != nil {
n.logger.Error("Failed to register node in registry", "error", err)
panic(err)
}
}()
ticker := time.NewTicker(heartbeatTickerInterval)
defer ticker.Stop()
n.logger.Info("Restore node started", "nodeID", n.nodeID, "throughput", throughputMBs)
for {
select {
case <-ctx.Done():
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
if err := n.restoreNodesRegistry.UnregisterNodeFromRegistry(restoreNode); err != nil {
n.logger.Error("Failed to unregister node from registry", "error", err)
restoreHandler := func(restoreID uuid.UUID, isCallNotifier bool) {
n.MakeRestore(restoreID)
if err := n.restoreNodesRegistry.PublishRestoreCompletion(n.nodeID, restoreID); err != nil {
n.logger.Error(
"Failed to publish restore completion",
"error",
err,
"restoreID",
restoreID,
)
}
return
case <-ticker.C:
n.sendHeartbeat(&restoreNode)
}
err := n.restoreNodesRegistry.SubscribeNodeForRestoresAssignment(
n.nodeID,
restoreHandler,
)
if err != nil {
n.logger.Error("Failed to subscribe to restore assignments", "error", err)
panic(err)
}
defer func() {
if err := n.restoreNodesRegistry.UnsubscribeNodeForRestoresAssignments(); err != nil {
n.logger.Error("Failed to unsubscribe from restore assignments", "error", err)
}
}()
ticker := time.NewTicker(heartbeatTickerInterval)
defer ticker.Stop()
n.logger.Info("Restore node started", "nodeID", n.nodeID, "throughput", throughputMBs)
for {
select {
case <-ctx.Done():
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
if err := n.restoreNodesRegistry.UnregisterNodeFromRegistry(restoreNode); err != nil {
n.logger.Error("Failed to unregister node from registry", "error", err)
}
return
case <-ticker.C:
n.sendHeartbeat(&restoreNode)
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", n))
}
}
@@ -176,6 +195,11 @@ func (n *RestorerNode) MakeRestore(restoreID uuid.UUID) {
start := time.Now().UTC()
// Create cancellable context
ctx, cancel := context.WithCancel(context.Background())
n.restoreCancelManager.RegisterTask(restore.ID, cancel)
defer n.restoreCancelManager.UnregisterTask(restore.ID)
// Create restoring database from cached credentials
restoringToDB := &databases.Database{
Type: database.Type,
@@ -204,6 +228,7 @@ func (n *RestorerNode) MakeRestore(restoreID uuid.UUID) {
}
err = n.restoreBackupUsecase.Execute(
ctx,
backupConfig,
*restore,
database,
@@ -216,6 +241,29 @@ func (n *RestorerNode) MakeRestore(restoreID uuid.UUID) {
if err != nil {
errMsg := err.Error()
// Check if restore was cancelled
isCancelled := strings.Contains(errMsg, "restore cancelled") ||
strings.Contains(errMsg, "context canceled") ||
errors.Is(err, context.Canceled)
isShutdown := strings.Contains(errMsg, "shutdown")
if isCancelled && !isShutdown {
n.logger.Warn("Restore was cancelled by user or system",
"restoreId", restore.ID,
"isCancelled", isCancelled,
"isShutdown", isShutdown,
)
restore.Status = restores_core.RestoreStatusCanceled
restore.RestoreDurationMs = time.Since(start).Milliseconds()
if err := n.restoreRepository.Save(restore); err != nil {
n.logger.Error("Failed to save cancelled restore", "error", err)
}
return
}
n.logger.Error("Restore execution failed",
"restoreId", restore.ID,
"backupId", backup.ID,

View File

@@ -6,6 +6,7 @@ import (
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
@@ -139,7 +140,7 @@ func Test_MakeRestore_WhenTaskStarts_CacheDeletedImmediately(t *testing.T) {
// Cache DB credentials separately
dbCache := &RestoreDatabaseCache{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Host: "localhost",
Host: config.GetEnv().TestLocalhost,
Port: 5432,
Username: "test",
Password: "test",

View File

@@ -4,6 +4,8 @@ import (
"context"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
@@ -34,51 +36,64 @@ type RestoresScheduler struct {
restorerNode *RestorerNode
cacheUtil *cache_utils.CacheUtil[RestoreDatabaseCache]
completionSubscriptionID uuid.UUID
runOnce sync.Once
hasRun atomic.Bool
}
func (s *RestoresScheduler) Run(ctx context.Context) {
s.lastCheckTime = time.Now().UTC()
wasAlreadyRun := s.hasRun.Load()
if config.GetEnv().IsManyNodesMode {
// wait other nodes to start
time.Sleep(schedulerStartupDelay)
}
s.runOnce.Do(func() {
s.hasRun.Store(true)
if err := s.failRestoresInProgress(); err != nil {
s.logger.Error("Failed to fail restores in progress", "error", err)
panic(err)
}
s.lastCheckTime = time.Now().UTC()
err := s.restoreNodesRegistry.SubscribeForRestoresCompletions(s.onRestoreCompleted)
if err != nil {
s.logger.Error("Failed to subscribe to restore completions", "error", err)
panic(err)
}
defer func() {
if err := s.restoreNodesRegistry.UnsubscribeForRestoresCompletions(); err != nil {
s.logger.Error("Failed to unsubscribe from restore completions", "error", err)
if config.GetEnv().IsManyNodesMode {
// wait other nodes to start
time.Sleep(schedulerStartupDelay)
}
}()
if ctx.Err() != nil {
return
}
if err := s.failRestoresInProgress(); err != nil {
s.logger.Error("Failed to fail restores in progress", "error", err)
panic(err)
}
ticker := time.NewTicker(schedulerTickerInterval)
defer ticker.Stop()
err := s.restoreNodesRegistry.SubscribeForRestoresCompletions(s.onRestoreCompleted)
if err != nil {
s.logger.Error("Failed to subscribe to restore completions", "error", err)
panic(err)
}
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.checkDeadNodesAndFailRestores(); err != nil {
s.logger.Error("Failed to check dead nodes and fail restores", "error", err)
defer func() {
if err := s.restoreNodesRegistry.UnsubscribeForRestoresCompletions(); err != nil {
s.logger.Error("Failed to unsubscribe from restore completions", "error", err)
}
}()
s.lastCheckTime = time.Now().UTC()
if ctx.Err() != nil {
return
}
ticker := time.NewTicker(schedulerTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.checkDeadNodesAndFailRestores(); err != nil {
s.logger.Error("Failed to check dead nodes and fail restores", "error", err)
}
s.lastCheckTime = time.Now().UTC()
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
}

View File

@@ -4,6 +4,7 @@ import (
"testing"
"time"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
@@ -424,7 +425,8 @@ func Test_StartRestore_RestoreCompletes_DecrementsActiveTaskCount(t *testing.T)
cache_utils.ClearAllCache()
// Start scheduler so it can handle task completions
schedulerCancel := StartSchedulerForTest(t)
scheduler := CreateTestRestoresScheduler()
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
restorerNode := CreateTestRestorerNode()
@@ -485,7 +487,7 @@ func Test_StartRestore_RestoreCompletes_DecrementsActiveTaskCount(t *testing.T)
err = restoreRepository.Save(restore)
assert.NoError(t, err)
err = GetRestoresScheduler().StartRestore(restore.ID, nil)
err = scheduler.StartRestore(restore.ID, nil)
assert.NoError(t, err)
// Wait for restore to complete
@@ -524,7 +526,8 @@ func Test_StartRestore_RestoreFails_DecrementsActiveTaskCount(t *testing.T) {
cache_utils.ClearAllCache()
// Start scheduler so it can handle task completions
schedulerCancel := StartSchedulerForTest(t)
scheduler := CreateTestRestoresScheduler()
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
restorerNode := CreateTestRestorerNode()
@@ -585,7 +588,7 @@ func Test_StartRestore_RestoreFails_DecrementsActiveTaskCount(t *testing.T) {
err = restoreRepository.Save(restore)
assert.NoError(t, err)
err = GetRestoresScheduler().StartRestore(restore.ID, nil)
err = scheduler.StartRestore(restore.ID, nil)
assert.NoError(t, err)
// Wait for restore to fail
@@ -679,7 +682,7 @@ func Test_StartRestore_CredentialsStoredEncryptedInCache(t *testing.T) {
// Create PostgreSQL database credentials with plaintext password
postgresDB := &postgresql.PostgresqlDatabase{
Host: "localhost",
Host: config.GetEnv().TestLocalhost,
Port: 5432,
Username: "testuser",
Password: plaintextPassword,
@@ -717,7 +720,7 @@ func Test_StartRestore_CredentialsStoredEncryptedInCache(t *testing.T) {
"Cached password should match the encrypted version")
// Verify other fields are present
assert.Equal(t, "localhost", cachedData.PostgresqlDatabase.Host)
assert.Equal(t, config.GetEnv().TestLocalhost, cachedData.PostgresqlDatabase.Host)
assert.Equal(t, 5432, cachedData.PostgresqlDatabase.Port)
assert.Equal(t, "testuser", cachedData.PostgresqlDatabase.Username)
assert.Equal(t, "testdb", *cachedData.PostgresqlDatabase.Database)
@@ -729,7 +732,8 @@ func Test_StartRestore_CredentialsRemovedAfterRestoreStarts(t *testing.T) {
cache_utils.ClearAllCache()
// Start scheduler so it can handle task assignments
schedulerCancel := StartSchedulerForTest(t)
scheduler := CreateTestRestoresScheduler()
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
// Create mock restorer node with credential capture usecase
@@ -789,7 +793,7 @@ func Test_StartRestore_CredentialsRemovedAfterRestoreStarts(t *testing.T) {
// Create PostgreSQL database credentials
// Database field is nil to avoid PopulateDbData trying to connect
postgresDB := &postgresql.PostgresqlDatabase{
Host: "localhost",
Host: config.GetEnv().TestLocalhost,
Port: 5432,
Username: "testuser",
Password: plaintextPassword,
@@ -810,7 +814,7 @@ func Test_StartRestore_CredentialsRemovedAfterRestoreStarts(t *testing.T) {
}
// Call StartRestore to cache credentials and trigger restore
err = GetRestoresScheduler().StartRestore(restore.ID, dbCache)
err = scheduler.StartRestore(restore.ID, dbCache)
assert.NoError(t, err)
// Wait for mock usecase to be called (with timeout)
@@ -829,7 +833,7 @@ func Test_StartRestore_CredentialsRemovedAfterRestoreStarts(t *testing.T) {
// Verify mock received valid credentials
assert.NotNil(t, capturedDB, "Captured database should not be nil")
assert.NotNil(t, capturedDB.Postgresql, "PostgreSQL credentials should be provided to usecase")
assert.Equal(t, "localhost", capturedDB.Postgresql.Host)
assert.Equal(t, config.GetEnv().TestLocalhost, capturedDB.Postgresql.Host)
assert.Equal(t, 5432, capturedDB.Postgresql.Port)
assert.Equal(t, "testuser", capturedDB.Postgresql.Username)
assert.NotEmpty(t, capturedDB.Postgresql.Password, "Password should be provided to usecase")

View File

@@ -3,12 +3,15 @@ package restoring
import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
@@ -17,6 +20,7 @@ import (
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/restores/usecases"
"databasus-backend/internal/features/storages"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/util/encryption"
@@ -36,18 +40,59 @@ func CreateTestRouter() *gin.Engine {
func CreateTestRestorerNode() *RestorerNode {
return &RestorerNode{
uuid.New(),
databases.GetDatabaseService(),
backups.GetBackupService(),
encryption.GetFieldEncryptor(),
nodeID: uuid.New(),
databaseService: databases.GetDatabaseService(),
backupService: backups.GetBackupService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
restoreRepository: restoreRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
restoreNodesRegistry: restoreNodesRegistry,
logger: logger.GetLogger(),
restoreBackupUsecase: usecases.GetRestoreBackupUsecase(),
cacheUtil: restoreDatabaseCache,
restoreCancelManager: tasks_cancellation.GetTaskCancelManager(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}
func CreateTestRestorerNodeWithUsecase(usecase restores_core.RestoreBackupUsecase) *RestorerNode {
return &RestorerNode{
nodeID: uuid.New(),
databaseService: databases.GetDatabaseService(),
backupService: backups.GetBackupService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
restoreRepository: restoreRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
restoreNodesRegistry: restoreNodesRegistry,
logger: logger.GetLogger(),
restoreBackupUsecase: usecase,
cacheUtil: restoreDatabaseCache,
restoreCancelManager: tasks_cancellation.GetTaskCancelManager(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}
func CreateTestRestoresScheduler() *RestoresScheduler {
return &RestoresScheduler{
restoreRepository,
backups_config.GetBackupConfigService(),
backups.GetBackupService(),
storages.GetStorageService(),
backups_config.GetBackupConfigService(),
restoreNodesRegistry,
time.Now().UTC(),
logger.GetLogger(),
usecases.GetRestoreBackupUsecase(),
make(map[uuid.UUID]RestoreToNodeRelation),
restorerNode,
restoreDatabaseCache,
time.Time{},
uuid.Nil,
sync.Once{},
atomic.Bool{},
}
}
@@ -128,13 +173,13 @@ func StartRestorerNodeForTest(t *testing.T, restorerNode *RestorerNode) context.
// StartSchedulerForTest starts the RestoresScheduler in a goroutine for testing.
// The scheduler subscribes to task completions and manages restore lifecycle.
// Returns a context cancel function that should be deferred to stop the scheduler.
func StartSchedulerForTest(t *testing.T) context.CancelFunc {
func StartSchedulerForTest(t *testing.T, scheduler *RestoresScheduler) context.CancelFunc {
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
GetRestoresScheduler().Run(ctx)
scheduler.Run(ctx)
close(done)
}()
@@ -275,7 +320,7 @@ func CreateTestRestore(
BackupID: backup.ID,
Status: status,
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Host: "localhost",
Host: config.GetEnv().TestLocalhost,
Port: 5432,
Username: "test",
Password: "test",

View File

@@ -1,6 +1,7 @@
package restores
import (
"databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
@@ -11,6 +12,7 @@ import (
"databasus-backend/internal/features/restores/restoring"
"databasus-backend/internal/features/restores/usecases"
"databasus-backend/internal/features/storages"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
users_models "databasus-backend/internal/features/users/models"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/encryption"
@@ -35,6 +37,7 @@ type RestoreService struct {
auditLogService *audit_logs.AuditLogService
fieldEncryptor encryption.FieldEncryptor
diskService *disk.DiskService
taskCancelManager *tasks_cancellation.TaskCancelManager
}
func (s *RestoreService) OnBeforeBackupRemove(backup *backups_core.Backup) error {
@@ -125,6 +128,13 @@ func (s *RestoreService) RestoreBackupWithAuth(
return err
}
if config.GetEnv().IsCloud {
// in cloud mode we use only single thread mode,
// because otherwise we will exhaust local storage
// space (instead of streaming from S3 directly to DB)
requestDTO.PostgresqlDatabase.CpuCount = 1
}
if err := s.validateVersionCompatibility(backupDatabase, requestDTO); err != nil {
return err
}
@@ -134,6 +144,11 @@ func (s *RestoreService) RestoreBackupWithAuth(
return err
}
// Validate no parallel restores for the same database
if err := s.validateNoParallelRestores(backup.DatabaseID); err != nil {
return err
}
// Create restore record with the request configuration
restore := restores_core.Restore{
ID: uuid.New(),
@@ -340,3 +355,71 @@ func (s *RestoreService) validateDiskSpace(
return nil
}
func (s *RestoreService) validateNoParallelRestores(databaseID uuid.UUID) error {
inProgressRestores, err := s.restoreRepository.FindInProgressRestoresByDatabaseID(databaseID)
if err != nil {
return fmt.Errorf("failed to check for in-progress restores: %w", err)
}
isInProgress := len(inProgressRestores) > 0
if isInProgress {
return errors.New(
"another restore is already in progress for this database. Please wait for it to complete or cancel it before starting a new restore",
)
}
return nil
}
func (s *RestoreService) CancelRestore(
user *users_models.User,
restoreID uuid.UUID,
) error {
restore, err := s.restoreRepository.FindByID(restoreID)
if err != nil {
return err
}
backup, err := s.backupService.GetBackup(restore.BackupID)
if err != nil {
return err
}
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return err
}
if database.WorkspaceID == nil {
return errors.New("cannot cancel restore for database without workspace")
}
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
if err != nil {
return err
}
if !canManage {
return errors.New("insufficient permissions to cancel restore for this database")
}
if restore.Status != restores_core.RestoreStatusInProgress {
return errors.New("restore is not in progress")
}
if err := s.taskCancelManager.CancelTask(restoreID); err != nil {
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Restore cancelled for database: %s (ID: %s)",
database.Name,
restoreID.String(),
),
&user.ID,
database.WorkspaceID,
)
return nil
}

View File

@@ -36,6 +36,7 @@ type RestoreMariadbBackupUsecase struct {
}
func (uc *RestoreMariadbBackupUsecase) Execute(
parentCtx context.Context,
originalDB *databases.Database,
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
@@ -79,6 +80,7 @@ func (uc *RestoreMariadbBackupUsecase) Execute(
}
return uc.restoreFromStorage(
parentCtx,
originalDB,
tools.GetMariadbExecutable(
tools.MariadbExecutableMariadb,
@@ -95,6 +97,7 @@ func (uc *RestoreMariadbBackupUsecase) Execute(
}
func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
parentCtx context.Context,
database *databases.Database,
mariadbBin string,
args []string,
@@ -103,7 +106,7 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
storage *storages.Storage,
mdbConfig *mariadbtypes.MariadbDatabase,
) error {
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
defer cancel()
go func() {
@@ -113,6 +116,9 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
select {
case <-ctx.Done():
return
case <-parentCtx.Done():
cancel()
return
case <-ticker.C:
if config.IsShouldShutdown() {
cancel()
@@ -213,6 +219,15 @@ func (uc *RestoreMariadbBackupUsecase) executeMariadbRestore(
waitErr := cmd.Wait()
stderrOutput := <-stderrCh
// Check for cancellation
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
return fmt.Errorf("restore cancelled")
}
default:
}
if config.IsShouldShutdown() {
return fmt.Errorf("restore cancelled due to shutdown")
}

View File

@@ -36,6 +36,7 @@ type RestoreMongodbBackupUsecase struct {
}
func (uc *RestoreMongodbBackupUsecase) Execute(
parentCtx context.Context,
originalDB *databases.Database,
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
@@ -76,6 +77,7 @@ func (uc *RestoreMongodbBackupUsecase) Execute(
args := uc.buildMongorestoreArgs(mdb, decryptedPassword, sourceDatabase)
return uc.restoreFromStorage(
parentCtx,
tools.GetMongodbExecutable(
tools.MongodbExecutableMongorestore,
config.GetEnv().EnvMode,
@@ -122,12 +124,13 @@ func (uc *RestoreMongodbBackupUsecase) buildMongorestoreArgs(
}
func (uc *RestoreMongodbBackupUsecase) restoreFromStorage(
parentCtx context.Context,
mongorestoreBin string,
args []string,
backup *backups_core.Backup,
storage *storages.Storage,
) error {
ctx, cancel := context.WithTimeout(context.Background(), restoreTimeout)
ctx, cancel := context.WithTimeout(parentCtx, restoreTimeout)
defer cancel()
go func() {
@@ -137,6 +140,9 @@ func (uc *RestoreMongodbBackupUsecase) restoreFromStorage(
select {
case <-ctx.Done():
return
case <-parentCtx.Done():
cancel()
return
case <-ticker.C:
if config.IsShouldShutdown() {
cancel()
@@ -218,6 +224,15 @@ func (uc *RestoreMongodbBackupUsecase) executeMongoRestore(
waitErr := cmd.Wait()
stderrOutput := <-stderrCh
// Check for cancellation
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
return fmt.Errorf("restore cancelled")
}
default:
}
if config.IsShouldShutdown() {
return fmt.Errorf("restore cancelled due to shutdown")
}

View File

@@ -36,6 +36,7 @@ type RestoreMysqlBackupUsecase struct {
}
func (uc *RestoreMysqlBackupUsecase) Execute(
parentCtx context.Context,
originalDB *databases.Database,
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
@@ -78,6 +79,7 @@ func (uc *RestoreMysqlBackupUsecase) Execute(
}
return uc.restoreFromStorage(
parentCtx,
originalDB,
tools.GetMysqlExecutable(
my.Version,
@@ -94,6 +96,7 @@ func (uc *RestoreMysqlBackupUsecase) Execute(
}
func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
parentCtx context.Context,
database *databases.Database,
mysqlBin string,
args []string,
@@ -102,7 +105,7 @@ func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
storage *storages.Storage,
myConfig *mysqltypes.MysqlDatabase,
) error {
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
defer cancel()
go func() {
@@ -112,6 +115,9 @@ func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
select {
case <-ctx.Done():
return
case <-parentCtx.Done():
cancel()
return
case <-ticker.C:
if config.IsShouldShutdown() {
cancel()
@@ -204,6 +210,15 @@ func (uc *RestoreMysqlBackupUsecase) executeMysqlRestore(
waitErr := cmd.Wait()
stderrOutput := <-stderrCh
// Check for cancellation
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
return fmt.Errorf("restore cancelled")
}
default:
}
if config.IsShouldShutdown() {
return fmt.Errorf("restore cancelled due to shutdown")
}

View File

@@ -35,6 +35,7 @@ type RestorePostgresqlBackupUsecase struct {
}
func (uc *RestorePostgresqlBackupUsecase) Execute(
parentCtx context.Context,
originalDB *databases.Database,
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
@@ -64,6 +65,13 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
return fmt.Errorf("target database name is required for pg_restore")
}
// Validate CPU count constraint for cloud environments
if config.GetEnv().IsCloud && pg.CpuCount > 1 {
return fmt.Errorf(
"parallel restore (CPU count > 1) is not supported in cloud mode due to storage constraints. Please use CPU count = 1",
)
}
pgBin := tools.GetPostgresqlExecutable(
pg.Version,
"pg_restore",
@@ -73,6 +81,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
// All PostgreSQL backups are now custom format (-Fc)
return uc.restoreCustomType(
parentCtx,
originalDB,
pgBin,
backup,
@@ -84,6 +93,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
// restoreCustomType restores a backup in custom type (-Fc)
func (uc *RestorePostgresqlBackupUsecase) restoreCustomType(
parentCtx context.Context,
originalDB *databases.Database,
pgBin string,
backup *backups_core.Backup,
@@ -102,15 +112,24 @@ func (uc *RestorePostgresqlBackupUsecase) restoreCustomType(
// If excluding extensions, we must use file-based restore (requires TOC file generation)
// Also use file-based restore for parallel jobs (multiple CPUs)
if isExcludeExtensions || pg.CpuCount > 1 {
return uc.restoreViaFile(originalDB, pgBin, backup, storage, pg, isExcludeExtensions)
return uc.restoreViaFile(
parentCtx,
originalDB,
pgBin,
backup,
storage,
pg,
isExcludeExtensions,
)
}
// Single CPU without extension exclusion: stream directly via stdin
return uc.restoreViaStdin(originalDB, pgBin, backup, storage, pg)
return uc.restoreViaStdin(parentCtx, originalDB, pgBin, backup, storage, pg)
}
// restoreViaStdin streams backup via stdin for single CPU restore
func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
parentCtx context.Context,
originalDB *databases.Database,
pgBin string,
backup *backups_core.Backup,
@@ -133,10 +152,10 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
"--no-acl",
}
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
defer cancel()
// Monitor for shutdown and cancel context if needed
// Monitor for shutdown and parent cancellation
go func() {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
@@ -145,6 +164,9 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
select {
case <-ctx.Done():
return
case <-parentCtx.Done():
cancel()
return
case <-ticker.C:
if config.IsShouldShutdown() {
cancel()
@@ -296,6 +318,15 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
stderrOutput := <-stderrCh
copyErr := <-copyErrCh
// Check for cancellation
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
return fmt.Errorf("restore cancelled")
}
default:
}
// Check for shutdown before finalizing
if config.IsShouldShutdown() {
return fmt.Errorf("restore cancelled due to shutdown")
@@ -307,6 +338,15 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
}
if waitErr != nil {
// Check for cancellation again
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
return fmt.Errorf("restore cancelled")
}
default:
}
if config.IsShouldShutdown() {
return fmt.Errorf("restore cancelled due to shutdown")
}
@@ -319,6 +359,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
// restoreViaFile downloads backup and uses parallel jobs for multi-CPU restore
func (uc *RestorePostgresqlBackupUsecase) restoreViaFile(
parentCtx context.Context,
originalDB *databases.Database,
pgBin string,
backup *backups_core.Backup,
@@ -354,6 +395,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaFile(
}
return uc.restoreFromStorage(
parentCtx,
originalDB,
pgBin,
args,
@@ -367,6 +409,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaFile(
// restoreFromStorage restores backup data from storage using pg_restore
func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
parentCtx context.Context,
database *databases.Database,
pgBin string,
args []string,
@@ -386,10 +429,10 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
isExcludeExtensions,
)
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
defer cancel()
// Monitor for shutdown and cancel context if needed
// Monitor for shutdown and parent cancellation
go func() {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
@@ -398,6 +441,9 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
select {
case <-ctx.Done():
return
case <-parentCtx.Done():
cancel()
return
case <-ticker.C:
if config.IsShouldShutdown() {
cancel()
@@ -624,12 +670,30 @@ func (uc *RestorePostgresqlBackupUsecase) executePgRestore(
waitErr := cmd.Wait()
stderrOutput := <-stderrCh
// Check for cancellation
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
return fmt.Errorf("restore cancelled")
}
default:
}
// Check for shutdown before finalizing
if config.IsShouldShutdown() {
return fmt.Errorf("restore cancelled due to shutdown")
}
if waitErr != nil {
// Check for cancellation again
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
return fmt.Errorf("restore cancelled")
}
default:
}
if config.IsShouldShutdown() {
return fmt.Errorf("restore cancelled due to shutdown")
}

View File

@@ -1,6 +1,7 @@
package usecases
import (
"context"
"errors"
backups_core "databasus-backend/internal/features/backups/backups/core"
@@ -22,6 +23,7 @@ type RestoreBackupUsecase struct {
}
func (uc *RestoreBackupUsecase) Execute(
ctx context.Context,
backupConfig *backups_config.BackupConfig,
restore restores_core.Restore,
originalDB *databases.Database,
@@ -33,6 +35,7 @@ func (uc *RestoreBackupUsecase) Execute(
switch originalDB.Type {
case databases.DatabaseTypePostgres:
return uc.restorePostgresqlBackupUsecase.Execute(
ctx,
originalDB,
restoringToDB,
backupConfig,
@@ -43,6 +46,7 @@ func (uc *RestoreBackupUsecase) Execute(
)
case databases.DatabaseTypeMysql:
return uc.restoreMysqlBackupUsecase.Execute(
ctx,
originalDB,
restoringToDB,
backupConfig,
@@ -52,6 +56,7 @@ func (uc *RestoreBackupUsecase) Execute(
)
case databases.DatabaseTypeMariadb:
return uc.restoreMariadbBackupUsecase.Execute(
ctx,
originalDB,
restoringToDB,
backupConfig,
@@ -61,6 +66,7 @@ func (uc *RestoreBackupUsecase) Execute(
)
case databases.DatabaseTypeMongodb:
return uc.restoreMongodbBackupUsecase.Execute(
ctx,
originalDB,
restoringToDB,
backupConfig,

View File

@@ -58,7 +58,8 @@ func (c *StorageController) SaveStorage(ctx *gin.Context) {
}
if err := c.storageService.SaveStorage(user, request.WorkspaceID, &request); err != nil {
if errors.Is(err, ErrInsufficientPermissionsToManageStorage) {
if errors.Is(err, ErrInsufficientPermissionsToManageStorage) ||
errors.Is(err, ErrLocalStorageNotAllowedInCloudMode) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
@@ -325,7 +326,11 @@ func (c *StorageController) TestStorageConnectionDirect(ctx *gin.Context) {
return
}
if err := c.storageService.TestStorageConnectionDirect(&request); err != nil {
if err := c.storageService.TestStorageConnectionDirect(user, &request); err != nil {
if errors.Is(err, ErrLocalStorageNotAllowedInCloudMode) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}

View File

@@ -6,6 +6,7 @@ import (
"strings"
"testing"
"databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
azure_blob_storage "databasus-backend/internal/features/storages/models/azure_blob"
ftp_storage "databasus-backend/internal/features/storages/models/ftp"
@@ -83,7 +84,7 @@ func Test_SaveNewStorage_StorageReturnedViaGet(t *testing.T) {
assert.Contains(t, storages, savedStorage)
deleteStorage(t, router, savedStorage.ID, workspace.ID, owner.Token)
deleteStorage(t, router, savedStorage.ID, owner.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
@@ -121,7 +122,169 @@ func Test_UpdateExistingStorage_UpdatedStorageReturnedViaGet(t *testing.T) {
assert.Equal(t, updatedName, updatedStorage.Name)
assert.Equal(t, savedStorage.ID, updatedStorage.ID)
deleteStorage(t, router, updatedStorage.ID, workspace.ID, owner.Token)
deleteStorage(t, router, updatedStorage.ID, owner.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_CreateSystemStorage_OnlyAdminCanCreate_MemberGetsForbidden(t *testing.T) {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", admin, router)
// Admin can create system storage
systemStorage := createNewStorage(workspace.ID)
systemStorage.IsSystem = true
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+admin.Token,
*systemStorage,
http.StatusOK,
&savedStorage,
)
assert.True(t, savedStorage.IsSystem)
assert.Equal(t, systemStorage.Name, savedStorage.Name)
// Member cannot create system storage
memberSystemStorage := createNewStorage(workspace.ID)
memberSystemStorage.IsSystem = true
resp := test_utils.MakePostRequest(
t,
router,
"/api/v1/storages",
"Bearer "+member.Token,
*memberSystemStorage,
http.StatusForbidden,
)
assert.Contains(t, string(resp.Body), "insufficient permissions")
deleteStorage(t, router, savedStorage.ID, admin.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_UpdateStorageIsSystem_OnlyAdminCanUpdate_MemberGetsForbidden(t *testing.T) {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", admin, router)
// Create a regular storage
storage := createNewStorage(workspace.ID)
storage.IsSystem = false
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+admin.Token,
*storage,
http.StatusOK,
&savedStorage,
)
assert.False(t, savedStorage.IsSystem)
// Admin can update to system
savedStorage.IsSystem = true
var updatedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+admin.Token,
savedStorage,
http.StatusOK,
&updatedStorage,
)
assert.True(t, updatedStorage.IsSystem)
// Member cannot update system storage
updatedStorage.Name = "Updated by member"
resp := test_utils.MakePostRequest(
t,
router,
"/api/v1/storages",
"Bearer "+member.Token,
updatedStorage,
http.StatusForbidden,
)
assert.Contains(t, string(resp.Body), "insufficient permissions")
deleteStorage(t, router, updatedStorage.ID, admin.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_UpdateSystemStorage_CannotChangeToPrivate_ReturnsBadRequest(t *testing.T) {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", admin, router)
// Create system storage
storage := createNewStorage(workspace.ID)
storage.IsSystem = true
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+admin.Token,
*storage,
http.StatusOK,
&savedStorage,
)
assert.True(t, savedStorage.IsSystem)
// Attempt to change system storage to non-system (should fail)
savedStorage.IsSystem = false
resp := test_utils.MakePostRequest(
t,
router,
"/api/v1/storages",
"Bearer "+admin.Token,
savedStorage,
http.StatusBadRequest,
)
assert.Contains(t, string(resp.Body), "system storage cannot be changed to non-system")
// Verify storage is still system
var retrievedStorage Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
"Bearer "+admin.Token,
http.StatusOK,
&retrievedStorage,
)
assert.True(t, retrievedStorage.IsSystem)
// Admin can update other fields while keeping IsSystem=true
savedStorage.IsSystem = true
savedStorage.Name = "Updated System Storage"
var updatedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+admin.Token,
savedStorage,
http.StatusOK,
&updatedStorage,
)
assert.True(t, updatedStorage.IsSystem)
assert.Equal(t, "Updated System Storage", updatedStorage.Name)
deleteStorage(t, router, updatedStorage.ID, admin.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
@@ -204,7 +367,7 @@ func Test_TestExistingStorageConnection_ConnectionEstablished(t *testing.T) {
assert.Contains(t, string(response.Body), "successful")
deleteStorage(t, router, savedStorage.ID, workspace.ID, owner.Token)
deleteStorage(t, router, savedStorage.ID, owner.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
@@ -300,7 +463,14 @@ func Test_WorkspaceRolePermissions(t *testing.T) {
fmt.Sprintf("/api/v1/storages?workspace_id=%s", workspace.ID.String()),
"Bearer "+testUserToken, http.StatusOK, &storages,
)
assert.Len(t, storages, 1)
// Count only non-system storages for this workspace
nonSystemStorages := 0
for _, s := range storages {
if !s.IsSystem {
nonSystemStorages++
}
}
assert.Equal(t, 1, nonSystemStorages)
// Test CREATE storage
createStatusCode := http.StatusOK
@@ -355,16 +525,514 @@ func Test_WorkspaceRolePermissions(t *testing.T) {
// Cleanup
if tt.canCreate {
deleteStorage(t, router, savedStorage.ID, workspace.ID, owner.Token)
deleteStorage(t, router, savedStorage.ID, owner.Token)
}
if !tt.canDelete {
deleteStorage(t, router, ownerStorage.ID, workspace.ID, owner.Token)
deleteStorage(t, router, ownerStorage.ID, owner.Token)
}
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
func Test_SystemStorage_AdminOnlyOperations(t *testing.T) {
tests := []struct {
name string
operation string
isAdmin bool
expectSuccess bool
expectedStatus int
}{
{
name: "admin can create system storage",
operation: "create",
isAdmin: true,
expectSuccess: true,
expectedStatus: http.StatusOK,
},
{
name: "member cannot create system storage",
operation: "create",
isAdmin: false,
expectSuccess: false,
expectedStatus: http.StatusForbidden,
},
{
name: "admin can update storage to make it system",
operation: "update_to_system",
isAdmin: true,
expectSuccess: true,
expectedStatus: http.StatusOK,
},
{
name: "member cannot update storage to make it system",
operation: "update_to_system",
isAdmin: false,
expectSuccess: false,
expectedStatus: http.StatusForbidden,
},
{
name: "admin can update system storage",
operation: "update_system",
isAdmin: true,
expectSuccess: true,
expectedStatus: http.StatusOK,
},
{
name: "member cannot update system storage",
operation: "update_system",
isAdmin: false,
expectSuccess: false,
expectedStatus: http.StatusForbidden,
},
{
name: "admin can delete system storage",
operation: "delete",
isAdmin: true,
expectSuccess: true,
expectedStatus: http.StatusOK,
},
{
name: "member cannot delete system storage",
operation: "delete",
isAdmin: false,
expectSuccess: false,
expectedStatus: http.StatusForbidden,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createRouter()
GetStorageService().SetStorageDatabaseCounter(&mockStorageDatabaseCounter{})
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
var testUserToken string
if tt.isAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUserToken = admin.Token
} else {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
users_enums.WorkspaceRoleMember,
owner.Token,
router,
)
testUserToken = member.Token
}
switch tt.operation {
case "create":
systemStorage := &Storage{
WorkspaceID: workspace.ID,
Type: StorageTypeLocal,
Name: "Test System Storage " + uuid.New().String(),
IsSystem: true,
LocalStorage: &local_storage.LocalStorage{},
}
if tt.expectSuccess {
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+testUserToken,
*systemStorage,
tt.expectedStatus,
&savedStorage,
)
assert.NotEmpty(t, savedStorage.ID)
assert.True(t, savedStorage.IsSystem)
deleteStorage(t, router, savedStorage.ID, testUserToken)
} else {
resp := test_utils.MakePostRequest(
t,
router,
"/api/v1/storages",
"Bearer "+testUserToken,
*systemStorage,
tt.expectedStatus,
)
assert.Contains(t, string(resp.Body), "insufficient permissions")
}
case "update_to_system":
// Owner creates private storage first
privateStorage := createNewStorage(workspace.ID)
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*privateStorage,
http.StatusOK,
&savedStorage,
)
// Test user attempts to make it system
savedStorage.IsSystem = true
if tt.expectSuccess {
var updatedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+testUserToken,
savedStorage,
tt.expectedStatus,
&updatedStorage,
)
assert.True(t, updatedStorage.IsSystem)
deleteStorage(t, router, savedStorage.ID, testUserToken)
} else {
resp := test_utils.MakePostRequest(
t,
router,
"/api/v1/storages",
"Bearer "+testUserToken,
savedStorage,
tt.expectedStatus,
)
assert.Contains(t, string(resp.Body), "insufficient permissions")
deleteStorage(t, router, savedStorage.ID, owner.Token)
}
case "update_system":
// Admin creates system storage first
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
systemStorage := &Storage{
WorkspaceID: workspace.ID,
Type: StorageTypeLocal,
Name: "Test System Storage " + uuid.New().String(),
IsSystem: true,
LocalStorage: &local_storage.LocalStorage{},
}
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+admin.Token,
*systemStorage,
http.StatusOK,
&savedStorage,
)
// Test user attempts to update system storage
savedStorage.Name = "Updated System Storage " + uuid.New().String()
if tt.expectSuccess {
var updatedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+testUserToken,
savedStorage,
tt.expectedStatus,
&updatedStorage,
)
assert.Equal(t, savedStorage.Name, updatedStorage.Name)
assert.True(t, updatedStorage.IsSystem)
deleteStorage(t, router, savedStorage.ID, testUserToken)
} else {
resp := test_utils.MakePostRequest(
t,
router,
"/api/v1/storages",
"Bearer "+testUserToken,
savedStorage,
tt.expectedStatus,
)
assert.Contains(t, string(resp.Body), "insufficient permissions")
deleteStorage(t, router, savedStorage.ID, admin.Token)
}
case "delete":
// Admin creates system storage first
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
systemStorage := &Storage{
WorkspaceID: workspace.ID,
Type: StorageTypeLocal,
Name: "Test System Storage " + uuid.New().String(),
IsSystem: true,
LocalStorage: &local_storage.LocalStorage{},
}
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+admin.Token,
*systemStorage,
http.StatusOK,
&savedStorage,
)
// Test user attempts to delete system storage
if tt.expectSuccess {
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
"Bearer "+testUserToken,
tt.expectedStatus,
)
} else {
resp := test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
"Bearer "+testUserToken,
tt.expectedStatus,
)
assert.Contains(t, string(resp.Body), "insufficient permissions")
deleteStorage(t, router, savedStorage.ID, admin.Token)
}
}
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
func Test_GetStorages_SystemStorageIncludedForAllUsers(t *testing.T) {
router := createRouter()
GetStorageService().SetStorageDatabaseCounter(&mockStorageDatabaseCounter{})
// Create two workspaces with different owners
ownerA := users_testing.CreateTestUser(users_enums.UserRoleMember)
ownerB := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaceA := workspaces_testing.CreateTestWorkspace("Workspace A", ownerA, router)
workspaceB := workspaces_testing.CreateTestWorkspace("Workspace B", ownerB, router)
// Create private storage in workspace A
privateStorageA := createNewStorage(workspaceA.ID)
var savedPrivateStorageA Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+ownerA.Token,
*privateStorageA,
http.StatusOK,
&savedPrivateStorageA,
)
// Admin creates system storage in workspace B
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
systemStorageB := &Storage{
WorkspaceID: workspaceB.ID,
Type: StorageTypeLocal,
Name: "Test System Storage B " + uuid.New().String(),
IsSystem: true,
LocalStorage: &local_storage.LocalStorage{},
}
var savedSystemStorageB Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+admin.Token,
*systemStorageB,
http.StatusOK,
&savedSystemStorageB,
)
// Test: User from workspace A should see both private storage A and system storage B
var storagesForWorkspaceA []Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/storages?workspace_id=%s", workspaceA.ID.String()),
"Bearer "+ownerA.Token,
http.StatusOK,
&storagesForWorkspaceA,
)
assert.GreaterOrEqual(t, len(storagesForWorkspaceA), 2)
foundPrivateA := false
foundSystemB := false
for _, s := range storagesForWorkspaceA {
if s.ID == savedPrivateStorageA.ID {
foundPrivateA = true
}
if s.ID == savedSystemStorageB.ID {
foundSystemB = true
}
}
assert.True(t, foundPrivateA, "User from workspace A should see private storage A")
assert.True(t, foundSystemB, "User from workspace A should see system storage B")
// Test: User from workspace B should see system storage B
var storagesForWorkspaceB []Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/storages?workspace_id=%s", workspaceB.ID.String()),
"Bearer "+ownerB.Token,
http.StatusOK,
&storagesForWorkspaceB,
)
assert.GreaterOrEqual(t, len(storagesForWorkspaceB), 1)
foundSystemBInWorkspaceB := false
for _, s := range storagesForWorkspaceB {
if s.ID == savedSystemStorageB.ID {
foundSystemBInWorkspaceB = true
}
// Should NOT see private storage from workspace A
assert.NotEqual(
t,
savedPrivateStorageA.ID,
s.ID,
"User from workspace B should not see private storage from workspace A",
)
}
assert.True(t, foundSystemBInWorkspaceB, "User from workspace B should see system storage B")
// Test: Outsider (not in any workspace) cannot access storages
outsider := users_testing.CreateTestUser(users_enums.UserRoleMember)
test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/storages?workspace_id=%s", workspaceA.ID.String()),
"Bearer "+outsider.Token,
http.StatusForbidden,
)
// Cleanup
deleteStorage(t, router, savedPrivateStorageA.ID, ownerA.Token)
deleteStorage(t, router, savedSystemStorageB.ID, admin.Token)
workspaces_testing.RemoveTestWorkspace(workspaceA, router)
workspaces_testing.RemoveTestWorkspace(workspaceB, router)
}
func Test_GetSystemStorage_SensitiveDataHiddenForNonAdmin(t *testing.T) {
router := createRouter()
GetStorageService().SetStorageDatabaseCounter(&mockStorageDatabaseCounter{})
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", member, router)
// Admin creates system S3 storage with credentials
systemS3Storage := &Storage{
WorkspaceID: workspace.ID,
Type: StorageTypeS3,
Name: "Test System S3 Storage " + uuid.New().String(),
IsSystem: true,
S3Storage: &s3_storage.S3Storage{
S3Bucket: "test-system-bucket",
S3Region: "us-east-1",
S3AccessKey: "test-access-key-123",
S3SecretKey: "test-secret-key-456",
S3Endpoint: "https://s3.amazonaws.com",
},
}
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+admin.Token,
*systemS3Storage,
http.StatusOK,
&savedStorage,
)
assert.NotEmpty(t, savedStorage.ID)
assert.True(t, savedStorage.IsSystem)
// Test: Admin retrieves system storage - should see S3Storage object with hidden sensitive fields
var adminView Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
"Bearer "+admin.Token,
http.StatusOK,
&adminView,
)
assert.NotNil(t, adminView.S3Storage, "Admin should see S3Storage object")
assert.Equal(t, "test-system-bucket", adminView.S3Storage.S3Bucket)
assert.Equal(t, "us-east-1", adminView.S3Storage.S3Region)
// Sensitive fields should be hidden (empty strings)
assert.Equal(
t,
"",
adminView.S3Storage.S3AccessKey,
"Admin should see hidden (empty) access key",
)
assert.Equal(
t,
"",
adminView.S3Storage.S3SecretKey,
"Admin should see hidden (empty) secret key",
)
// Test: Member retrieves system storage - should see storage but all specific data hidden
var memberView Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
"Bearer "+member.Token,
http.StatusOK,
&memberView,
)
assert.Equal(t, savedStorage.ID, memberView.ID)
assert.Equal(t, savedStorage.Name, memberView.Name)
assert.True(t, memberView.IsSystem)
// All storage type objects should be nil for non-admin viewing system storage
assert.Nil(t, memberView.S3Storage, "Non-admin should not see S3Storage object")
assert.Nil(t, memberView.LocalStorage, "Non-admin should not see LocalStorage object")
assert.Nil(
t,
memberView.GoogleDriveStorage,
"Non-admin should not see GoogleDriveStorage object",
)
assert.Nil(t, memberView.NASStorage, "Non-admin should not see NASStorage object")
assert.Nil(t, memberView.AzureBlobStorage, "Non-admin should not see AzureBlobStorage object")
assert.Nil(t, memberView.FTPStorage, "Non-admin should not see FTPStorage object")
assert.Nil(t, memberView.SFTPStorage, "Non-admin should not see SFTPStorage object")
assert.Nil(t, memberView.RcloneStorage, "Non-admin should not see RcloneStorage object")
// Test: Member can also see system storage in GetStorages list
var storages []Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/storages?workspace_id=%s", workspace.ID.String()),
"Bearer "+member.Token,
http.StatusOK,
&storages,
)
foundSystemStorage := false
for _, s := range storages {
if s.ID == savedStorage.ID {
foundSystemStorage = true
assert.True(t, s.IsSystem)
assert.Nil(t, s.S3Storage, "Non-admin should not see S3Storage in list")
}
}
assert.True(t, foundSystemStorage, "System storage should be in list")
// Cleanup
deleteStorage(t, router, savedStorage.ID, admin.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_UserNotInWorkspace_CannotAccessStorages(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
outsider := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -416,7 +1084,7 @@ func Test_UserNotInWorkspace_CannotAccessStorages(t *testing.T) {
http.StatusForbidden,
)
deleteStorage(t, router, savedStorage.ID, workspace.ID, owner.Token)
deleteStorage(t, router, savedStorage.ID, owner.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
@@ -449,7 +1117,7 @@ func Test_CrossWorkspaceSecurity_CannotAccessStorageFromAnotherWorkspace(t *test
)
assert.Contains(t, string(response.Body), "insufficient permissions")
deleteStorage(t, router, savedStorage.ID, workspace1.ID, owner1.Token)
deleteStorage(t, router, savedStorage.ID, owner1.Token)
workspaces_testing.RemoveTestWorkspace(workspace1, router)
workspaces_testing.RemoveTestWorkspace(workspace2, router)
}
@@ -902,6 +1570,12 @@ func Test_StorageSensitiveDataLifecycle_AllTypes(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Skip Google Drive tests if external resources tests are disabled
if tc.storageType == StorageTypeGoogleDrive &&
config.GetEnv().IsSkipExternalResourcesTests {
t.Skip("Skipping Google Drive storage test: IS_SKIP_EXTERNAL_RESOURCES_TESTS=true")
}
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
@@ -979,6 +1653,10 @@ func Test_StorageSensitiveDataLifecycle_AllTypes(t *testing.T) {
&finalRetrieved,
)
tc.verifyHiddenData(t, &finalRetrieved)
// Cleanup
deleteStorage(t, router, createdStorage.ID, owner.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -1115,10 +1793,10 @@ func Test_TransferStorage_PermissionsEnforced(t *testing.T) {
)
assert.Equal(t, targetWorkspace.ID, retrievedStorage.WorkspaceID)
deleteStorage(t, router, savedStorage.ID, targetWorkspace.ID, targetOwner.Token)
deleteStorage(t, router, savedStorage.ID, targetOwner.Token)
} else {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
deleteStorage(t, router, savedStorage.ID, sourceWorkspace.ID, sourceOwner.Token)
deleteStorage(t, router, savedStorage.ID, sourceOwner.Token)
}
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
@@ -1168,11 +1846,129 @@ func Test_TransferStorageNotManagableWorkspace_TransferFailed(t *testing.T) {
"insufficient permissions to manage storage in target workspace",
)
deleteStorage(t, router, savedStorage.ID, workspace1.ID, userA.Token)
deleteStorage(t, router, savedStorage.ID, userA.Token)
workspaces_testing.RemoveTestWorkspace(workspace1, router)
workspaces_testing.RemoveTestWorkspace(workspace2, router)
}
func Test_TransferSystemStorage_TransferBlocked(t *testing.T) {
router := createRouter()
GetStorageService().SetStorageDatabaseCounter(&mockStorageDatabaseCounter{})
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
workspaceA := workspaces_testing.CreateTestWorkspace("Workspace A", admin, router)
workspaceB := workspaces_testing.CreateTestWorkspace("Workspace B", admin, router)
// Admin creates system storage in workspace A
systemStorage := &Storage{
WorkspaceID: workspaceA.ID,
Type: StorageTypeLocal,
Name: "Test System Storage " + uuid.New().String(),
IsSystem: true,
LocalStorage: &local_storage.LocalStorage{},
}
var savedSystemStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+admin.Token,
*systemStorage,
http.StatusOK,
&savedSystemStorage,
)
// Admin attempts to transfer system storage to workspace B - should be blocked
transferRequest := TransferStorageRequest{
TargetWorkspaceID: workspaceB.ID,
}
testResp := test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/storages/%s/transfer", savedSystemStorage.ID.String()),
"Bearer "+admin.Token,
transferRequest,
http.StatusBadRequest,
)
assert.Contains(
t,
string(testResp.Body),
"system storage cannot be transferred",
"Transfer should fail with appropriate error message",
)
// Verify storage is still in workspace A
var retrievedStorage Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", savedSystemStorage.ID.String()),
"Bearer "+admin.Token,
http.StatusOK,
&retrievedStorage,
)
assert.Equal(
t,
workspaceA.ID,
retrievedStorage.WorkspaceID,
"Storage should remain in workspace A",
)
// Test regression: Non-system storage can still be transferred
privateStorage := createNewStorage(workspaceA.ID)
var savedPrivateStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+admin.Token,
*privateStorage,
http.StatusOK,
&savedPrivateStorage,
)
privateTransferResp := test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/storages/%s/transfer", savedPrivateStorage.ID.String()),
"Bearer "+admin.Token,
transferRequest,
http.StatusOK,
)
assert.Contains(
t,
string(privateTransferResp.Body),
"transferred successfully",
"Private storage should be transferable",
)
// Verify private storage was transferred to workspace B
var transferredStorage Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", savedPrivateStorage.ID.String()),
"Bearer "+admin.Token,
http.StatusOK,
&transferredStorage,
)
assert.Equal(
t,
workspaceB.ID,
transferredStorage.WorkspaceID,
"Private storage should be in workspace B",
)
// Cleanup
deleteStorage(t, router, savedSystemStorage.ID, admin.Token)
deleteStorage(t, router, savedPrivateStorage.ID, admin.Token)
workspaces_testing.RemoveTestWorkspace(workspaceA, router)
workspaces_testing.RemoveTestWorkspace(workspaceB, router)
}
func createRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
@@ -1205,12 +2001,13 @@ func verifyStorageData(t *testing.T, expected *Storage, actual *Storage) {
assert.Equal(t, expected.Name, actual.Name)
assert.Equal(t, expected.Type, actual.Type)
assert.Equal(t, expected.WorkspaceID, actual.WorkspaceID)
assert.Equal(t, expected.IsSystem, actual.IsSystem)
}
func deleteStorage(
t *testing.T,
router *gin.Engine,
storageID, workspaceID uuid.UUID,
storageID uuid.UUID,
token string,
) {
test_utils.MakeDeleteRequest(

View File

@@ -1,9 +1,13 @@
package storages
import (
"sync"
"sync/atomic"
audit_logs "databasus-backend/internal/features/audit_logs"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
)
var storageRepository = &StorageRepository{}
@@ -27,6 +31,21 @@ func GetStorageController() *StorageController {
return storageController
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(storageService)
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(storageService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}

View File

@@ -33,4 +33,13 @@ var (
ErrStorageHasOtherAttachedDatabasesCannotTransfer = errors.New(
"storage has other attached databases and cannot be transferred",
)
ErrSystemStorageCannotBeTransferred = errors.New(
"system storage cannot be transferred between workspaces",
)
ErrSystemStorageCannotBeMadePrivate = errors.New(
"system storage cannot be changed to non-system",
)
ErrLocalStorageNotAllowedInCloudMode = errors.New(
"local storage can only be managed by administrators in cloud mode",
)
)

View File

@@ -24,6 +24,7 @@ type Storage struct {
Type StorageType `json:"type" gorm:"column:type;not null;type:text"`
Name string `json:"name" gorm:"column:name;not null;type:text"`
LastSaveError *string `json:"lastSaveError" gorm:"column:last_save_error;type:text"`
IsSystem bool `json:"isSystem" gorm:"column:is_system;not null;default:false"`
// specific storage
LocalStorage *local_storage.LocalStorage `json:"localStorage" gorm:"foreignKey:StorageID"`
@@ -86,6 +87,17 @@ func (s *Storage) HideSensitiveData() {
s.getSpecificStorage().HideSensitiveData()
}
func (s *Storage) HideAllData() {
s.LocalStorage = nil
s.S3Storage = nil
s.GoogleDriveStorage = nil
s.NASStorage = nil
s.AzureBlobStorage = nil
s.FTPStorage = nil
s.SFTPStorage = nil
s.RcloneStorage = nil
}
func (s *Storage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
return s.getSpecificStorage().EncryptSensitiveData(encryptor)
}
@@ -93,6 +105,7 @@ func (s *Storage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) erro
func (s *Storage) Update(incoming *Storage) {
s.Name = incoming.Name
s.Type = incoming.Type
s.IsSystem = incoming.IsSystem
switch s.Type {
case StorageTypeLocal:

View File

@@ -113,7 +113,7 @@ func Test_Storage_BasicOperations(t *testing.T) {
name: "NASStorage",
storage: &nas_storage.NASStorage{
StorageID: uuid.New(),
Host: "localhost",
Host: config.GetEnv().TestLocalhost,
Port: nasPort,
Share: "backups",
Username: "testuser",
@@ -147,7 +147,7 @@ func Test_Storage_BasicOperations(t *testing.T) {
name: "FTPStorage",
storage: &ftp_storage.FTPStorage{
StorageID: uuid.New(),
Host: "localhost",
Host: config.GetEnv().TestLocalhost,
Port: ftpPort,
Username: "testuser",
Password: "testpassword",
@@ -159,7 +159,7 @@ func Test_Storage_BasicOperations(t *testing.T) {
name: "SFTPStorage",
storage: &sftp_storage.SFTPStorage{
StorageID: uuid.New(),
Host: "localhost",
Host: config.GetEnv().TestLocalhost,
Port: sftpPort,
Username: "testuser",
Password: "testpassword",
@@ -185,7 +185,9 @@ acl = private`, s3Container.accessKey, s3Container.secretKey, s3Container.endpoi
// Add Google Drive storage test only if environment variables are available
env := config.GetEnv()
if env.TestGoogleDriveClientID != "" && env.TestGoogleDriveClientSecret != "" &&
if env.IsSkipExternalResourcesTests {
t.Log("Skipping Google Drive storage test: IS_SKIP_EXTERNAL_RESOURCES_TESTS=true")
} else if env.TestGoogleDriveClientID != "" && env.TestGoogleDriveClientSecret != "" &&
env.TestGoogleDriveTokenJSON != "" {
testCases = append(testCases, struct {
name string
@@ -297,7 +299,7 @@ func setupS3Container(ctx context.Context) (*S3Container, error) {
secretKey := "testpassword"
bucketName := "test-bucket"
region := "us-east-1"
endpoint := fmt.Sprintf("127.0.0.1:%s", env.TestMinioPort)
endpoint := fmt.Sprintf("%s:%s", env.TestLocalhost, env.TestMinioPort)
// Create MinIO client and ensure bucket exists
minioClient, err := minio.New(endpoint, &minio.Options{
@@ -343,15 +345,21 @@ func setupAzuriteContainer(ctx context.Context) (*AzuriteContainer, error) {
accountName := "devstoreaccount1"
// this is real testing key for azurite, it's not a real key
accountKey := "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="
serviceURL := fmt.Sprintf("http://127.0.0.1:%s/%s", env.TestAzuriteBlobPort, accountName)
serviceURL := fmt.Sprintf(
"http://%s:%s/%s",
env.TestLocalhost,
env.TestAzuriteBlobPort,
accountName,
)
containerNameKey := "test-container-key"
containerNameStr := "test-container-connstr"
// Build explicit connection string for Azurite
connectionString := fmt.Sprintf(
"DefaultEndpointsProtocol=http;AccountName=%s;AccountKey=%s;BlobEndpoint=http://127.0.0.1:%s/%s",
"DefaultEndpointsProtocol=http;AccountName=%s;AccountKey=%s;BlobEndpoint=http://%s:%s/%s",
accountName,
accountKey,
env.TestLocalhost,
env.TestAzuriteBlobPort,
accountName,
)

View File

@@ -285,30 +285,30 @@ func (f *FTPStorage) ensureDirectory(conn *ftp.ServerConn, path string) error {
}
parts := strings.Split(path, "/")
currentPath := ""
currentDir, err := conn.CurrentDir()
if err != nil {
return fmt.Errorf("failed to get current directory: %w", err)
}
defer func() {
_ = conn.ChangeDir(currentDir)
}()
for _, part := range parts {
if part == "" || part == "." {
continue
}
if currentPath == "" {
currentPath = part
} else {
currentPath = currentPath + "/" + part
}
err := conn.ChangeDir(currentPath)
err := conn.ChangeDir(part)
if err != nil {
err = conn.MakeDir(currentPath)
err = conn.MakeDir(part)
if err != nil {
return fmt.Errorf("failed to create directory '%s': %w", currentPath, err)
return fmt.Errorf("failed to create directory '%s': %w", part, err)
}
err = conn.ChangeDir(part)
if err != nil {
return fmt.Errorf("failed to change into directory '%s': %w", part, err)
}
}
err = conn.ChangeDirToParent()
if err != nil {
return fmt.Errorf("failed to change to parent directory: %w", err)
}
}

View File

@@ -165,7 +165,7 @@ func (r *StorageRepository) FindByWorkspaceID(workspaceID uuid.UUID) ([]*Storage
Preload("FTPStorage").
Preload("SFTPStorage").
Preload("RcloneStorage").
Where("workspace_id = ?", workspaceID).
Where("workspace_id = ? OR is_system = TRUE", workspaceID).
Order("name ASC").
Find(&storages).Error; err != nil {
return nil, err

View File

@@ -3,7 +3,9 @@ package storages
import (
"fmt"
"databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
users_enums "databasus-backend/internal/features/users/enums"
users_models "databasus-backend/internal/features/users/models"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/encryption"
@@ -36,8 +38,18 @@ func (s *StorageService) SaveStorage(
return ErrInsufficientPermissionsToManageStorage
}
if config.GetEnv().IsCloud && storage.Type == StorageTypeLocal &&
user.Role != users_enums.UserRoleAdmin {
return ErrLocalStorageNotAllowedInCloudMode
}
isUpdate := storage.ID != uuid.Nil
if storage.IsSystem && user.Role != users_enums.UserRoleAdmin {
// only admin can manage system storage
return ErrInsufficientPermissionsToManageStorage
}
if isUpdate {
existingStorage, err := s.storageRepository.FindByID(storage.ID)
if err != nil {
@@ -48,6 +60,10 @@ func (s *StorageService) SaveStorage(
return ErrStorageDoesNotBelongToWorkspace
}
if existingStorage.IsSystem && !storage.IsSystem {
return ErrSystemStorageCannotBeMadePrivate
}
existingStorage.Update(storage)
if err := existingStorage.EncryptSensitiveData(s.fieldEncryptor); err != nil {
@@ -111,6 +127,11 @@ func (s *StorageService) DeleteStorage(
return ErrInsufficientPermissionsToManageStorage
}
if storage.IsSystem && user.Role != users_enums.UserRoleAdmin {
// only admin can manage system storage
return ErrInsufficientPermissionsToManageStorage
}
attachedDatabasesIDs, err := s.storageDatabaseCounter.GetStorageAttachedDatabasesIDs(storage.ID)
if err != nil {
return err
@@ -142,16 +163,22 @@ func (s *StorageService) GetStorage(
return nil, err
}
canView, _, err := s.workspaceService.CanUserAccessWorkspace(storage.WorkspaceID, user)
if err != nil {
return nil, err
}
if !canView {
return nil, ErrInsufficientPermissionsToViewStorage
if !storage.IsSystem {
canView, _, err := s.workspaceService.CanUserAccessWorkspace(storage.WorkspaceID, user)
if err != nil {
return nil, err
}
if !canView {
return nil, ErrInsufficientPermissionsToViewStorage
}
}
storage.HideSensitiveData()
if storage.IsSystem && user.Role != users_enums.UserRoleAdmin {
storage.HideAllData()
}
return storage, nil
}
@@ -174,6 +201,10 @@ func (s *StorageService) GetStorages(
for _, storage := range storages {
storage.HideSensitiveData()
if storage.IsSystem && user.Role != users_enums.UserRoleAdmin {
storage.HideAllData()
}
}
return storages, nil
@@ -213,8 +244,14 @@ func (s *StorageService) TestStorageConnection(
}
func (s *StorageService) TestStorageConnectionDirect(
user *users_models.User,
storage *Storage,
) error {
if config.GetEnv().IsCloud && storage.Type == StorageTypeLocal &&
user.Role != users_enums.UserRoleAdmin {
return ErrLocalStorageNotAllowedInCloudMode
}
var usingStorage *Storage
if storage.ID != uuid.Nil {
@@ -258,6 +295,10 @@ func (s *StorageService) TransferStorageToWorkspace(
return err
}
if existingStorage.IsSystem {
return ErrSystemStorageCannotBeTransferred
}
canManageSource, err := s.workspaceService.CanUserManageDBs(existingStorage.WorkspaceID, user)
if err != nil {
return err

View File

@@ -23,6 +23,18 @@ func (c *HealthcheckController) RegisterRoutes(router *gin.RouterGroup) {
// @Failure 503 {object} HealthcheckResponse
// @Router /system/health [get]
func (c *HealthcheckController) CheckHealth(ctx *gin.Context) {
// Allow unrestricted CORS for health check endpoint
// This enables monitoring tools from any origin to check system health
ctx.Header("Access-Control-Allow-Origin", "*")
ctx.Header("Access-Control-Allow-Methods", "GET, OPTIONS")
ctx.Header("Access-Control-Allow-Headers", "Content-Type")
// Handle preflight OPTIONS request
if ctx.Request.Method == "OPTIONS" {
ctx.AbortWithStatus(http.StatusNoContent)
return
}
err := c.healthcheckService.IsHealthy()
if err == nil {

View File

@@ -1,11 +1,14 @@
package system_healthcheck
import (
"context"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups/backuping"
"databasus-backend/internal/features/disk"
"databasus-backend/internal/storage"
cache_utils "databasus-backend/internal/util/cache"
"errors"
"time"
)
type HealthcheckService struct {
@@ -15,6 +18,20 @@ type HealthcheckService struct {
}
func (s *HealthcheckService) IsHealthy() error {
return s.performHealthCheck()
}
func (s *HealthcheckService) performHealthCheck() error {
// Check if cache is available with PING
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
client := cache_utils.GetValkeyClient()
pingResult := client.Do(ctx, client.B().Ping().Build())
if pingResult.Error() != nil {
return errors.New("cannot connect to valkey")
}
diskUsage, err := s.diskService.GetDiskUsage()
if err != nil {
return errors.New("cannot get disk usage")
@@ -40,6 +57,7 @@ func (s *HealthcheckService) IsHealthy() error {
if config.GetEnv().IsProcessingNode {
if !s.backuperNode.IsBackuperRunning() {
return errors.New("backuper node is not running for more than 5 minutes")
}
}

View File

@@ -2,9 +2,11 @@ package task_cancellation
import (
"context"
"sync"
"sync/atomic"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/logger"
"sync"
"github.com/google/uuid"
)
@@ -20,6 +22,21 @@ func GetTaskCancelManager() *TaskCancelManager {
return taskCancelManager
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
taskCancelManager.StartSubscription()
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
taskCancelManager.StartSubscription()
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}

View File

@@ -0,0 +1,159 @@
package features
import (
"context"
"fmt"
"sync"
"testing"
"time"
"databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/features/backups/backups/backuping"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/restores"
"databasus-backend/internal/features/restores/restoring"
"databasus-backend/internal/features/storages"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
)
// Test_SetupDependencies_CalledTwice_LogsWarning verifies SetupDependencies is idempotent
func Test_SetupDependencies_CalledTwice_LogsWarning(t *testing.T) {
// Call each SetupDependencies twice - should not panic, only log warnings
audit_logs.SetupDependencies()
audit_logs.SetupDependencies()
backups.SetupDependencies()
backups.SetupDependencies()
backups_config.SetupDependencies()
backups_config.SetupDependencies()
databases.SetupDependencies()
databases.SetupDependencies()
healthcheck_config.SetupDependencies()
healthcheck_config.SetupDependencies()
notifiers.SetupDependencies()
notifiers.SetupDependencies()
restores.SetupDependencies()
restores.SetupDependencies()
storages.SetupDependencies()
storages.SetupDependencies()
task_cancellation.SetupDependencies()
task_cancellation.SetupDependencies()
// If we reach here without panic, test passes
t.Log("All SetupDependencies calls completed successfully (idempotent)")
}
// Test_SetupDependencies_ConcurrentCalls_Safe verifies thread safety
func Test_SetupDependencies_ConcurrentCalls_Safe(t *testing.T) {
var wg sync.WaitGroup
// Call SetupDependencies concurrently from 10 goroutines
for range 10 {
wg.Add(1)
go func() {
defer wg.Done()
audit_logs.SetupDependencies()
}()
}
wg.Wait()
t.Log("Concurrent SetupDependencies calls completed successfully")
}
// Test_BackgroundService_Run_CalledTwice_Panics verifies Run() panics on duplicate calls
func Test_BackgroundService_Run_CalledTwice_Panics(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Create a test background service
backgroundService := audit_logs.GetAuditLogBackgroundService()
// Start first Run() in goroutine
go func() {
backgroundService.Run(ctx)
}()
// Give first call time to initialize
time.Sleep(100 * time.Millisecond)
// Second call should panic
defer func() {
if r := recover(); r != nil {
expectedMsg := "*audit_logs.AuditLogBackgroundService.Run() called multiple times"
panicMsg := fmt.Sprintf("%v", r)
if panicMsg == expectedMsg {
t.Logf("Successfully caught panic: %v", r)
} else {
t.Errorf("Expected panic message '%s', got '%s'", expectedMsg, panicMsg)
}
} else {
t.Error("Expected panic on second Run() call, but did not panic")
}
}()
backgroundService.Run(ctx)
}
// Test_BackupsScheduler_Run_CalledTwice_Panics verifies scheduler panics on duplicate calls
func Test_BackupsScheduler_Run_CalledTwice_Panics(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
scheduler := backuping.GetBackupsScheduler()
// Start first Run() in goroutine
go func() {
scheduler.Run(ctx)
}()
// Give first call time to initialize
time.Sleep(100 * time.Millisecond)
// Second call should panic
defer func() {
if r := recover(); r != nil {
t.Logf("Successfully caught panic: %v", r)
} else {
t.Error("Expected panic on second Run() call, but did not panic")
}
}()
scheduler.Run(ctx)
}
// Test_RestoresScheduler_Run_CalledTwice_Panics verifies restore scheduler panics on duplicate calls
func Test_RestoresScheduler_Run_CalledTwice_Panics(t *testing.T) {
ctx := t.Context()
scheduler := restoring.GetRestoresScheduler()
// Start first Run() in goroutine
go func() {
scheduler.Run(ctx)
}()
// Give first call time to initialize
time.Sleep(100 * time.Millisecond)
// Second call should panic
defer func() {
if r := recover(); r != nil {
t.Logf("Successfully caught panic: %v", r)
} else {
t.Error("Expected panic on second Run() call, but did not panic")
}
}()
scheduler.Run(ctx)
}

View File

@@ -605,7 +605,7 @@ func connectToMariadbContainer(
dbName := "testdb"
password := "rootpassword"
username := "root"
host := "127.0.0.1"
host := config.GetEnv().TestLocalhost
portInt, err := strconv.Atoi(port)
if err != nil {

View File

@@ -550,7 +550,7 @@ func connectToMongodbContainer(
password := "rootpassword"
username := "root"
authDatabase := "admin"
host := "127.0.0.1"
host := config.GetEnv().TestLocalhost
portInt, err := strconv.Atoi(port)
if err != nil {
@@ -558,8 +558,13 @@ func connectToMongodbContainer(
}
uri := fmt.Sprintf(
"mongodb://%s:%s@%s:%d/%s?authSource=%s",
username, password, host, portInt, dbName, authDatabase,
"mongodb://%s:%s@%s:%d/%s?authSource=%s&serverSelectionTimeoutMS=5000&connectTimeoutMS=5000",
username,
password,
host,
portInt,
dbName,
authDatabase,
)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)

Some files were not shown because too many files have changed in this diff Show More