Mirror of https://github.com/databasus/databasus.git (synced 2026-04-06 00:32:03 +02:00)

Compare commits: 122 commits
Commit SHAs in this comparison (author and date columns did not survive extraction):

8e3d1e5bff, 349e7f0ee8, 3a274e135b, 61e937bc2a, f67919fe1a, 91ee5966d8, d77d7d69a3, fc88b730d5, 1f1d80245f, 16a29cf458, 43e04500ac, cee3022f85, f46d92c480, 10677238d7, 2553203fcf, 7b05bd8000, 8d45728f73, c70ad82c95, e4bc34d319, 257ae85da7, b42c820bb2, da5c13fb11, 35180360e5, e4f6cd7a5d, d7b8e6d56a, 6016f23fb2, e7c4ee8f6f, a75702a01b, 81a21eb907, 33d6bf0147, 6eb53bb07b, 6ac04270b9, b0510d7c21, dc5f271882, 8f718771c9, d8eea05dca, b2a94274d7, 77c2712ebb, a9dc29f82c, c934a45dca, d4acdf2826, 49753c4fc0, c6aed6b36d, 3060b4266a, ebeb597f17, 4783784325, bd41433bdb, a9073787d2, 0890bf8f09, f8c11e8802, e798d82fc1, 81a01585ee, a8465c1a10, a9e5db70f6, 7a47be6ca6, 16be3db0c6, 744e51d1e1, b3af75d430, 6f7320abeb, a1655d35a6, 9b6e801184, 105777ab6f, 3a1a88d5cf, 699ca16814, 26f3cf233a, 3d8372e9f6, b46f11804d, 4676361688, de3679cadf, 8f03a30af2, 356529c58a, e7eed056f7, 6084cdc954, c50bcc57b1, ea76300ed7, 9b413e4076, f91cb260f2, 8f37a8082f, 5cf7614772, ae27f74c2e, 9457516bb9, a36fc5bf8c, 03ada5806d, a6675390e5, af2f978876, 04e7eba5c5, 520165541d, 5b556bc161, 0952a15ec5, 1afb3aa3ff, 19b92e5f74, d4763f26b2, 0e389ba16b, 594a3294c6, 4e4a323cf1, 7d9ecf697b, 755c420157, ff73627287, 9c9ab00ace, 7366e21a1a, a327d1aa57, f152b16ea3, 85dbe80d3d, edf4028fd1, 8d85c45a90, d9c176d19a, 7a6f72a456, 9a1471b88b, 386ea1d708, a4b23936ee, b36aa9d48b, 13cb8e5bd2, 2db4b6e075, f2b0b2bf1f, 7142ce295e, 04621b9b2d, bd329a68cf, f957abc9db, c0fd6be1a9, c39bd34d5e, 27bec15a29, d98baa0656
.github/workflows/ci-release.yml (vendored): 224 changed lines
@@ -9,15 +9,26 @@ on:
 jobs:
   lint-backend:
-    runs-on: ubuntu-latest
+    runs-on: self-hosted
+    container:
+      image: golang:1.24.9
+      volumes:
+        - /runner-cache/go-pkg:/go/pkg/mod
+        - /runner-cache/go-build:/root/.cache/go-build
+        - /runner-cache/golangci-lint:/root/.cache/golangci-lint
+        - /runner-cache/apt-archives:/var/cache/apt/archives
     steps:
       - name: Check out code
        uses: actions/checkout@v4

-      - name: Set up Go
-        uses: actions/setup-go@v5
-        with:
-          go-version: "1.24.9"
+      - name: Configure Git for container
+        run: |
+          git config --global --add safe.directory "$GITHUB_WORKSPACE"

       - name: Download Go modules
         run: |
           cd backend
           go mod download

       - name: Install golangci-lint
         run: |
@@ -70,6 +81,11 @@ jobs:
           cd frontend
           npm run lint

+      - name: Build frontend
+        run: |
+          cd frontend
+          npm run build
+
   test-frontend:
     runs-on: ubuntu-latest
     needs: [lint-frontend]
@@ -93,34 +109,32 @@ jobs:
           npm run test

   test-backend:
-    runs-on: ubuntu-latest
+    runs-on: self-hosted
     needs: [lint-backend]
+    container:
+      image: golang:1.24.9
+      options: --privileged -v /var/run/docker.sock:/var/run/docker.sock --add-host=host.docker.internal:host-gateway
+      volumes:
+        - /runner-cache/go-pkg:/go/pkg/mod
+        - /runner-cache/go-build:/root/.cache/go-build
+        - /runner-cache/apt-archives:/var/cache/apt/archives
     steps:
-      - name: Free up disk space
+      - name: Install Docker CLI
         run: |
-          echo "Disk space before cleanup:"
-          df -h
-          # Remove unnecessary pre-installed software
-          sudo rm -rf /usr/share/dotnet
-          sudo rm -rf /usr/local/lib/android
-          sudo rm -rf /opt/ghc
-          sudo rm -rf /opt/hostedtoolcache/CodeQL
-          sudo rm -rf /usr/local/share/boost
-          sudo rm -rf /usr/share/swift
-          # Clean apt cache
-          sudo apt-get clean
-          # Clean docker images (if any pre-installed)
-          docker system prune -af --volumes || true
-          echo "Disk space after cleanup:"
-          df -h
+          apt-get update -qq
+          apt-get install -y -qq docker.io docker-compose netcat-openbsd wget

       - name: Check out code
        uses: actions/checkout@v4

-      - name: Set up Go
-        uses: actions/setup-go@v5
-        with:
-          go-version: "1.24.9"
+      - name: Configure Git for container
+        run: |
+          git config --global --add safe.directory "$GITHUB_WORKSPACE"

       - name: Download Go modules
         run: |
           cd backend
           go mod download

       - name: Create .env file for testing
         run: |
@@ -132,14 +146,16 @@ jobs:
           DEV_DB_PASSWORD=Q1234567
           #app
           ENV_MODE=development
-          # db
-          DATABASE_DSN=host=localhost user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
-          DATABASE_URL=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disable
+          # db - using 172.17.0.1 to access host from container
+          DATABASE_DSN=host=172.17.0.1 user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
+          DATABASE_URL=postgres://postgres:Q1234567@172.17.0.1:5437/databasus?sslmode=disable
           # migrations
           GOOSE_DRIVER=postgres
-          GOOSE_DBSTRING=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disable
+          GOOSE_DBSTRING=postgres://postgres:Q1234567@172.17.0.1:5437/databasus?sslmode=disable
           GOOSE_MIGRATION_DIR=./migrations
           # testing
+          TEST_LOCALHOST=172.17.0.1
           IS_SKIP_EXTERNAL_RESOURCES_TESTS=true
           # to get Google Drive env variables: add storage in UI and copy data from added storage here
           TEST_GOOGLE_DRIVE_CLIENT_ID=${{ secrets.TEST_GOOGLE_DRIVE_CLIENT_ID }}
           TEST_GOOGLE_DRIVE_CLIENT_SECRET=${{ secrets.TEST_GOOGLE_DRIVE_CLIENT_SECRET }}
@@ -197,12 +213,14 @@ jobs:
           TEST_MONGODB_60_PORT=27060
           TEST_MONGODB_70_PORT=27070
           TEST_MONGODB_82_PORT=27082
-          # Valkey (cache)
-          VALKEY_HOST=localhost
+          # Valkey (cache) - using 172.17.0.1
+          VALKEY_HOST=172.17.0.1
           VALKEY_PORT=6379
           VALKEY_USERNAME=
           VALKEY_PASSWORD=
           VALKEY_IS_SSL=false
+          # Host for test databases (container -> host)
+          TEST_DB_HOST=172.17.0.1
           EOF

       - name: Start test containers
@@ -220,25 +238,25 @@ jobs:
           timeout 60 bash -c 'until docker exec dev-valkey valkey-cli ping 2>/dev/null | grep -q PONG; do sleep 2; done'
           echo "Valkey is ready!"

-          # Wait for test databases
-          timeout 60 bash -c 'until nc -z localhost 5000; do sleep 2; done'
-          timeout 60 bash -c 'until nc -z localhost 5001; do sleep 2; done'
-          timeout 60 bash -c 'until nc -z localhost 5002; do sleep 2; done'
-          timeout 60 bash -c 'until nc -z localhost 5003; do sleep 2; done'
-          timeout 60 bash -c 'until nc -z localhost 5004; do sleep 2; done'
-          timeout 60 bash -c 'until nc -z localhost 5005; do sleep 2; done'
+          # Wait for test databases (using 172.17.0.1 from container)
+          timeout 60 bash -c 'until nc -z 172.17.0.1 5000; do sleep 2; done'
+          timeout 60 bash -c 'until nc -z 172.17.0.1 5001; do sleep 2; done'
+          timeout 60 bash -c 'until nc -z 172.17.0.1 5002; do sleep 2; done'
+          timeout 60 bash -c 'until nc -z 172.17.0.1 5003; do sleep 2; done'
+          timeout 60 bash -c 'until nc -z 172.17.0.1 5004; do sleep 2; done'
+          timeout 60 bash -c 'until nc -z 172.17.0.1 5005; do sleep 2; done'

           # Wait for MinIO
-          timeout 60 bash -c 'until nc -z localhost 9000; do sleep 2; done'
+          timeout 60 bash -c 'until nc -z 172.17.0.1 9000; do sleep 2; done'

           # Wait for Azurite
-          timeout 60 bash -c 'until nc -z localhost 10000; do sleep 2; done'
+          timeout 60 bash -c 'until nc -z 172.17.0.1 10000; do sleep 2; done'

           # Wait for FTP
-          timeout 60 bash -c 'until nc -z localhost 7007; do sleep 2; done'
+          timeout 60 bash -c 'until nc -z 172.17.0.1 7007; do sleep 2; done'

           # Wait for SFTP
-          timeout 60 bash -c 'until nc -z localhost 7008; do sleep 2; done'
+          timeout 60 bash -c 'until nc -z 172.17.0.1 7008; do sleep 2; done'

           # Wait for MySQL containers
           echo "Waiting for MySQL 5.7..."
@@ -297,63 +315,63 @@ jobs:
           mkdir -p databasus-data/backups
           mkdir -p databasus-data/temp

-      - name: Install MySQL dependencies
+      - name: Install database client dependencies
         run: |
-          sudo apt-get update -qq
-          sudo apt-get install -y -qq libncurses6
-          sudo ln -sf /usr/lib/x86_64-linux-gnu/libncurses.so.6 /usr/lib/x86_64-linux-gnu/libncurses.so.5
-          sudo ln -sf /usr/lib/x86_64-linux-gnu/libtinfo.so.6 /usr/lib/x86_64-linux-gnu/libtinfo.so.5
+          apt-get update -qq
+          apt-get install -y -qq libncurses6 libpq5
+          ln -sf /usr/lib/x86_64-linux-gnu/libncurses.so.6 /usr/lib/x86_64-linux-gnu/libncurses.so.5 || true
+          ln -sf /usr/lib/x86_64-linux-gnu/libtinfo.so.6 /usr/lib/x86_64-linux-gnu/libtinfo.so.5 || true

       - name: Setup PostgreSQL, MySQL and MariaDB client tools from pre-built assets
         run: |
           cd backend/tools

           # Create directory structure
           mkdir -p postgresql mysql mariadb mongodb/bin

           # Copy PostgreSQL client tools (12-18) from pre-built assets
           for version in 12 13 14 15 16 17 18; do
             mkdir -p postgresql/postgresql-$version
             cp -r ../../assets/tools/x64/postgresql/postgresql-$version/bin postgresql/postgresql-$version/
           done

           # Copy MySQL client tools (5.7, 8.0, 8.4, 9) from pre-built assets
           for version in 5.7 8.0 8.4 9; do
             mkdir -p mysql/mysql-$version
             cp -r ../../assets/tools/x64/mysql/mysql-$version/bin mysql/mysql-$version/
           done

           # Copy MariaDB client tools (10.6, 12.1) from pre-built assets
           for version in 10.6 12.1; do
             mkdir -p mariadb/mariadb-$version
             cp -r ../../assets/tools/x64/mariadb/mariadb-$version/bin mariadb/mariadb-$version/
           done

           # Make all binaries executable
           chmod +x postgresql/*/bin/*
           chmod +x mysql/*/bin/*
           chmod +x mariadb/*/bin/*

           echo "Pre-built client tools setup complete"

       - name: Install MongoDB Database Tools
         run: |
           cd backend/tools

           # MongoDB Database Tools must be downloaded (not in pre-built assets)
           # They are backward compatible - single version supports all servers (4.0-8.0)
           MONGODB_TOOLS_URL="https://fastdl.mongodb.org/tools/db/mongodb-database-tools-debian12-x86_64-100.10.0.deb"

           echo "Downloading MongoDB Database Tools..."
           wget -q "$MONGODB_TOOLS_URL" -O /tmp/mongodb-database-tools.deb

           echo "Installing MongoDB Database Tools..."
-          sudo dpkg -i /tmp/mongodb-database-tools.deb || sudo apt-get install -f -y --no-install-recommends
+          dpkg -i /tmp/mongodb-database-tools.deb || apt-get install -f -y --no-install-recommends

           # Create symlinks to tools directory
           ln -sf /usr/bin/mongodump mongodb/bin/mongodump
           ln -sf /usr/bin/mongorestore mongodb/bin/mongorestore

           rm -f /tmp/mongodb-database-tools.deb
           echo "MongoDB Database Tools installed successfully"
@@ -401,10 +419,28 @@ jobs:
         if: always()
         run: |
           cd backend
           # Stop and remove containers (keeping images for next run)
           docker compose -f docker-compose.yml.example down -v

+          # Clean up all data directories created by docker-compose
+          echo "Cleaning up data directories..."
+          rm -rf pgdata || true
+          rm -rf valkey-data || true
+          rm -rf mysqldata || true
+          rm -rf mariadbdata || true
+          rm -rf temp/nas || true
+          rm -rf databasus-data || true
+
+          # Also clean root-level databasus-data if exists
+          cd ..
+          rm -rf databasus-data || true
+
+          echo "Cleanup complete"
+
   determine-version:
-    runs-on: ubuntu-latest
+    runs-on: self-hosted
+    container:
+      image: node:20
     needs: [test-backend, test-frontend]
     if: ${{ github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, '[skip-release]') }}
     outputs:
@@ -417,10 +453,9 @@ jobs:
         with:
           fetch-depth: 0

-      - name: Set up Node.js
-        uses: actions/setup-node@v4
-        with:
-          node-version: "20"
+      - name: Configure Git for container
+        run: |
+          git config --global --add safe.directory "$GITHUB_WORKSPACE"

       - name: Install semver
         run: npm install -g semver
@@ -434,6 +469,7 @@ jobs:

       - name: Analyze commits and determine version bump
         id: version_bump
+        shell: bash
         run: |
           CURRENT_VERSION="${{ steps.current_version.outputs.current_version }}"
           LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
@@ -453,7 +489,7 @@ jobs:
           HAS_FIX=false
           HAS_BREAKING=false

-          # Analyze each commit
+          # Analyze each commit - USE PROCESS SUBSTITUTION to avoid subshell variable scope issues
           while IFS= read -r commit; do
             if [[ "$commit" =~ ^FEATURE ]]; then
               HAS_FEATURE=true
@@ -471,7 +507,7 @@ jobs:
               HAS_BREAKING=true
               echo "Found BREAKING CHANGE: $commit"
             fi
-          done <<< "$COMMITS"
+          done < <(printf '%s\n' "$COMMITS")

           # Determine version bump
           if [ "$HAS_BREAKING" = true ]; then
@@ -497,10 +533,15 @@ jobs:
           fi

   build-only:
-    runs-on: ubuntu-latest
+    runs-on: self-hosted
     needs: [test-backend, test-frontend]
     if: ${{ github.ref == 'refs/heads/main' && contains(github.event.head_commit.message, '[skip-release]') }}
     steps:
+      - name: Clean workspace
+        run: |
+          sudo rm -rf "$GITHUB_WORKSPACE"/* || true
+          sudo rm -rf "$GITHUB_WORKSPACE"/.* || true
+
       - name: Check out code
        uses: actions/checkout@v4
@@ -529,12 +570,17 @@ jobs:
             databasus/databasus:${{ github.sha }}

   build-and-push:
-    runs-on: ubuntu-latest
+    runs-on: self-hosted
     needs: [determine-version]
     if: ${{ needs.determine-version.outputs.should_release == 'true' }}
     permissions:
       contents: write
     steps:
+      - name: Clean workspace
+        run: |
+          sudo rm -rf "$GITHUB_WORKSPACE"/* || true
+          sudo rm -rf "$GITHUB_WORKSPACE"/.* || true
+
       - name: Check out code
        uses: actions/checkout@v4
@@ -564,21 +610,33 @@ jobs:
             databasus/databasus:${{ github.sha }}

   release:
-    runs-on: ubuntu-latest
+    runs-on: self-hosted
+    container:
+      image: node:20
     needs: [determine-version, build-and-push]
     if: ${{ needs.determine-version.outputs.should_release == 'true' }}
     permissions:
       contents: write
       pull-requests: write
     steps:
+      - name: Clean workspace
+        run: |
+          rm -rf "$GITHUB_WORKSPACE"/* || true
+          rm -rf "$GITHUB_WORKSPACE"/.* || true
+
       - name: Check out code
         uses: actions/checkout@v4
         with:
           fetch-depth: 0
           token: ${{ secrets.GITHUB_TOKEN }}

+      - name: Configure Git for container
+        run: |
+          git config --global --add safe.directory "$GITHUB_WORKSPACE"
+
       - name: Generate changelog
         id: changelog
+        shell: bash
         run: |
           NEW_VERSION="${{ needs.determine-version.outputs.new_version }}"
           LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
@@ -598,6 +656,7 @@ jobs:
           FIXES=""
           REFACTORS=""

+          # USE PROCESS SUBSTITUTION to avoid subshell variable scope issues
           while IFS= read -r line; do
             if [ -n "$line" ]; then
               COMMIT_MSG=$(echo "$line" | cut -d'|' -f1)
@@ -631,7 +690,7 @@ jobs:
               fi
             fi
           fi
-          done <<< "$COMMITS"
+          done < <(printf '%s\n' "$COMMITS")

           # Build changelog sections
           if [ -n "$FEATURES" ]; then
@@ -670,16 +729,33 @@ jobs:
           prerelease: false

   publish-helm-chart:
-    runs-on: ubuntu-latest
+    runs-on: self-hosted
+    container:
+      image: alpine:3.19
+      volumes:
+        - /runner-cache/apk-cache:/etc/apk/cache
     needs: [determine-version, build-and-push]
     if: ${{ needs.determine-version.outputs.should_release == 'true' }}
     permissions:
       contents: read
       packages: write
     steps:
+      - name: Clean workspace
+        run: |
+          rm -rf "$GITHUB_WORKSPACE"/* || true
+          rm -rf "$GITHUB_WORKSPACE"/.* || true
+
+      - name: Install dependencies
+        run: |
+          apk add --no-cache git bash curl
+
       - name: Check out code
         uses: actions/checkout@v4

+      - name: Configure Git for container
+        run: |
+          git config --global --add safe.directory "$GITHUB_WORKSPACE"
+
       - name: Set up Helm
         uses: azure/setup-helm@v4
         with:
@@ -701,4 +777,4 @@ jobs:
       - name: Push Helm chart to GHCR
         run: |
           VERSION="${{ needs.determine-version.outputs.new_version }}"
-          helm push databasus-${VERSION}.tgz oci://ghcr.io/databasus/charts
+          helm push databasus-${VERSION}.tgz oci://ghcr.io/databasus/charts
.gitignore (vendored): 4 changed lines
@@ -1,3 +1,4 @@
 ansible/
 postgresus_data/
 postgresus-data/
+databasus-data/
@@ -9,4 +10,5 @@ node_modules/
 /articles

 .DS_Store
-/scripts
+/scripts
+.vscode/settings.json
@@ -18,6 +18,13 @@ repos:
       files: ^frontend/.*\.(ts|tsx|js|jsx)$
       pass_filenames: false

+    - id: frontend-build
+      name: Frontend Build
+      entry: bash -c "cd frontend && npm run build"
+      language: system
+      files: ^frontend/.*\.(ts|tsx|js|jsx|json|css)$
+      pass_filenames: false
+
   # Backend checks
   - repo: local
     hooks:
AGENTS.md: 509 changed lines
@@ -1,30 +1,94 @@
 # Agent Rules and Guidelines

-This document contains all coding standards, conventions and best practices recommended for the Databasus project.
+This document contains all coding standards, conventions and best practices recommended for the TgTaps project.
+This is NOT a strict set of rules, but a set of recommendations to help you write better code.

 ---

 ## Table of Contents

-- [Backend Guidelines](#backend-guidelines)
-  - [Code Style](#code-style)
+- [Engineering philosophy](#engineering-philosophy)
+- [Backend guidelines](#backend-guidelines)
+  - [Code style](#code-style)
+  - [Boolean naming](#boolean-naming)
+  - [Add reasonable new lines between logical statements](#add-reasonable-new-lines-between-logical-statements)
   - [Comments](#comments)
   - [Controllers](#controllers)
-  - [Dependency Injection (DI)](#dependency-injection-di)
+  - [Dependency injection (DI)](#dependency-injection-di)
   - [Migrations](#migrations)
   - [Refactoring](#refactoring)
   - [Testing](#testing)
-  - [Time Handling](#time-handling)
-  - [CRUD Examples](#crud-examples)
-- [Frontend Guidelines](#frontend-guidelines)
-  - [React Component Structure](#react-component-structure)
+  - [Time handling](#time-handling)
+  - [CRUD examples](#crud-examples)
+- [Frontend guidelines](#frontend-guidelines)
+  - [React component structure](#react-component-structure)

 ---

-## Backend Guidelines
+## Engineering philosophy

-### Code Style
+**Think like a skeptical senior engineer and code reviewer. Don't just do what was asked—also think about what should have been asked.**
+
+⚠️ **Balance vigilance with pragmatism:** Catch real issues, not theoretical ones. Don't let perfect be the enemy of good.
+
+### Task context assessment:
+
+**First, assess the task scope:**
+
+- **Trivial** (typos, formatting, simple field adds): Apply directly with minimal analysis
+- **Standard** (CRUD, typical features): Brief assumption check, proceed
+- **Complex** (architecture, security, performance-critical): Full analysis required
+- **Unclear** (ambiguous requirements): Always clarify assumptions first
+
+### For non-trivial tasks:
+
+1. **Restate the objective and list assumptions** (explicit + implicit)
+   - If any assumption is shaky, call it out clearly
+   - Distinguish between what's specified and what you're inferring
+
+2. **Propose appropriate solutions:**
+   - For complex tasks: 2–3 viable approaches (including a simpler baseline)
+   - Recommend one with clear tradeoffs
+   - Consider: complexity, maintainability, performance, future extensibility
+
+3. **Identify risks proactively:**
+   - Edge cases and boundary conditions
+   - Security/privacy pitfalls
+   - Performance risks and scalability concerns
+   - Operational concerns (deployment, observability, rollback, monitoring)
+
+4. **Handle ambiguity:**
+   - If requirements are ambiguous, make a reasonable default and proceed
+   - Clearly label your assumptions
+   - Document what would change under alternative assumptions
+
+5. **Deliver quality:**
+   - Provide a solution that is correct, testable, and maintainable
+   - Include minimal tests or validation steps
+   - Follow project testing philosophy: prefer controller tests over unit tests
+   - Follow all project guidelines from this document
+
+6. **Self-review before finalizing:**
+   - Ask: "What could go wrong?"
+   - Patch the answer accordingly
+   - Verify edge cases are handled
+
+### Application guidelines:
+
+**Scale your response to the task:**
+
+- **Trivial changes:** Steps 5-6 only (deliver quality + self-review)
+- **Standard features:** Steps 1, 5-6 (restate + deliver + review)
+- **Complex/risky changes:** All steps 1-6
+- **Ambiguous requests:** Steps 1, 4 mandatory
+
+**Be proportionally thorough—brief for simple tasks, comprehensive for risky ones. Avoid analysis paralysis.**
+
+---
+
+## Backend guidelines
+
+### Code style

 **Always place private methods to the bottom of file**
@@ -32,7 +96,7 @@ This rule applies to ALL Go files including tests, services, controllers, reposi

 In Go, exported (public) functions/methods start with uppercase letters, while unexported (private) ones start with lowercase letters.

-#### Structure Order:
+#### Structure order:

 1. Type definitions and constants
 2. Public methods/functions (uppercase)
@@ -165,7 +229,7 @@ func (c *ProjectController) extractProjectID(ctx *gin.Context) uuid.UUID {
 }
 ```

-#### Key Points:
+#### Key points:

 - **Exported/Public** = starts with uppercase letter (CreateUser, GetProject)
 - **Unexported/Private** = starts with lowercase letter (validateUser, handleError)
@@ -175,6 +239,227 @@ func (c *ProjectController) extractProjectID(ctx *gin.Context) uuid.UUID {

 ---

+### Boolean naming
+
+**Always prefix boolean variables with verbs like `is`, `has`, `was`, `should`, `can`, etc.**
+
+This makes the code more readable and clearly indicates that the variable represents a true/false state.
+
+#### Good examples:
+
+```go
+type User struct {
+    IsActive    bool
+    IsVerified  bool
+    HasAccess   bool
+    WasNotified bool
+}
+
+type BackupConfig struct {
+    IsEnabled      bool
+    ShouldCompress bool
+    CanRetry       bool
+}
+
+// Variables
+isInProgress := true
+wasCompleted := false
+hasPermission := checkPermissions()
+```
+
+#### Bad examples:
+
+```go
+type User struct {
+    Active   bool // Should be: IsActive
+    Verified bool // Should be: IsVerified
+    Access   bool // Should be: HasAccess
+}
+
+type BackupConfig struct {
+    Enabled  bool // Should be: IsEnabled
+    Compress bool // Should be: ShouldCompress
+    Retry    bool // Should be: CanRetry
+}
+
+// Variables
+inProgress := true // Should be: isInProgress
+completed := false // Should be: wasCompleted
+permission := true // Should be: hasPermission
+```
+
+#### Common boolean prefixes:
+
+- **is** - current state (IsActive, IsValid, IsEnabled)
+- **has** - possession or presence (HasAccess, HasPermission, HasError)
+- **was** - past state (WasCompleted, WasNotified, WasDeleted)
+- **should** - intention or recommendation (ShouldRetry, ShouldCompress)
+- **can** - capability or permission (CanRetry, CanDelete, CanEdit)
+- **will** - future state (WillExpire, WillRetry)
+
+---
+
+### Add reasonable new lines between logical statements
+
+**Add blank lines between logical blocks to improve code readability.**
+
+Separate different logical operations within a function with blank lines. This makes the code flow clearer and helps identify distinct steps in the logic.
+
+#### Guidelines:
+
+- Add blank line before final `return` statement
+- Add blank line after variable declarations before using them
+- Add blank line between error handling and subsequent logic
+- Add blank line between different logical operations
+
+#### Bad example (without spacing):
+
+```go
+func (t *Task) BeforeSave(tx *gorm.DB) error {
+    if len(t.Messages) > 0 {
+        messagesBytes, err := json.Marshal(t.Messages)
+        if err != nil {
+            return err
+        }
+        t.MessagesJSON = string(messagesBytes)
+    }
+    return nil
+}
+
+func (t *Task) AfterFind(tx *gorm.DB) error {
+    if t.MessagesJSON != "" {
+        var messages []onewin_dto.TaskCompletionMessage
+        if err := json.Unmarshal([]byte(t.MessagesJSON), &messages); err != nil {
+            return err
+        }
+        t.Messages = messages
+    }
+    return nil
+}
+```
+
+#### Good example (with proper spacing):
+
+```go
+func (t *Task) BeforeSave(tx *gorm.DB) error {
+    if len(t.Messages) > 0 {
+        messagesBytes, err := json.Marshal(t.Messages)
+        if err != nil {
+            return err
+        }
+
+        t.MessagesJSON = string(messagesBytes)
+    }
+
+    return nil
+}
+
+func (t *Task) AfterFind(tx *gorm.DB) error {
+    if t.MessagesJSON != "" {
+        var messages []onewin_dto.TaskCompletionMessage
+        if err := json.Unmarshal([]byte(t.MessagesJSON), &messages); err != nil {
+            return err
+        }
+
+        t.Messages = messages
+    }
+
+    return nil
+}
+```
+
+#### More examples:
+
+**Service method with multiple operations:**
+
+```go
+func (s *UserService) CreateUser(request *CreateUserRequest) (*User, error) {
+    // Validate input
+    if err := s.validateUserRequest(request); err != nil {
+        return nil, err
+    }
+
+    // Create user entity
+    user := &User{
+        ID:    uuid.New(),
+        Name:  request.Name,
+        Email: request.Email,
+    }
+
+    // Save to database
+    if err := s.repository.Create(user); err != nil {
+        return nil, err
+    }
+
+    // Send notification
+    s.notificationService.SendWelcomeEmail(user.Email)
+
+    return user, nil
+}
+```
+
+**Repository method with query building:**
+
+```go
+func (r *Repository) GetFiltered(filters *Filters) ([]*Entity, error) {
+    query := storage.GetDb().Model(&Entity{})
+
+    if filters.Status != "" {
+        query = query.Where("status = ?", filters.Status)
+    }
+
+    if filters.CreatedAfter != nil {
+        query = query.Where("created_at > ?", filters.CreatedAfter)
+    }
+
+    var entities []*Entity
+    if err := query.Find(&entities).Error; err != nil {
+        return nil, err
+    }
+
+    return entities, nil
+}
+```
+
+**Repository method with error handling:**
+
+Bad (without spacing):
+
+```go
+func (r *Repository) FindById(id uuid.UUID) (*models.Task, error) {
+    var task models.Task
+    result := storage.GetDb().Where("id = ?", id).First(&task)
+    if result.Error != nil {
+        if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+            return nil, errors.New("task not found")
+        }
+        return nil, result.Error
+    }
+    return &task, nil
+}
+```
+
+Good (with proper spacing):
+
+```go
+func (r *Repository) FindById(id uuid.UUID) (*models.Task, error) {
+    var task models.Task
+
+    result := storage.GetDb().Where("id = ?", id).First(&task)
+    if result.Error != nil {
+        if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+            return nil, errors.New("task not found")
+        }
+
+        return nil, result.Error
+    }
+
+    return &task, nil
+}
+```
+
+---
+
 ### Comments

 #### Guidelines
@@ -183,13 +468,14 @@ func (c *ProjectController) extractProjectID(ctx *gin.Context) uuid.UUID {
 2. **Functions and variables should have meaningful names** - Code should be self-documenting
 3. **Comments for unclear code only** - Only add comments when code logic isn't immediately clear

-#### Key Principles:
+#### Key principles:

 - **Code should tell a story** - Use descriptive variable and function names
 - **Comments explain WHY, not WHAT** - The code shows what happens, comments explain business logic or complex decisions
 - **Prefer refactoring over commenting** - If code needs explaining, consider making it clearer instead
 - **API documentation is required** - Swagger comments for all HTTP endpoints are mandatory
 - **Complex algorithms deserve comments** - Mathematical formulas, business rules, or non-obvious optimizations
+- **Do not write summary sections in .md files unless directly requested** - Avoid adding "Summary" or "Conclusion" sections at the end of documentation files unless the user explicitly asks for them

 #### Example of useless comments:
@@ -221,7 +507,7 @@ func CreateValidLogItems(count int, uniqueID string) []logs_receiving.LogItemReq

 ### Controllers

-#### Controller Guidelines:
+#### Controller guidelines:

 1. **When we write controller:**
    - We combine all routes to single controller
@@ -353,7 +639,7 @@ func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {

 ---

-### Dependency Injection (DI)
+### Dependency injection (DI)

 For DI files use **implicit fields declaration styles** (especially for controllers, services, repositories, use cases, etc., not simple data structures).
@@ -381,7 +667,7 @@ var orderController = &OrderController{

 **This is needed to avoid forgetting to update DI style when we add new dependency.**

-#### Force Such Usage
+#### Force such usage

 Please force such usage if the file looks like this (see some services/controllers/repos definitions and getters):
@@ -427,6 +713,134 @@ func GetOrderRepository() *repositories.OrderRepository {
 }
 ```

+#### SetupDependencies() pattern
+
+**All `SetupDependencies()` functions must use sync.Once to ensure idempotent execution.**
+
+This pattern allows `SetupDependencies()` to be safely called multiple times (especially in tests) while ensuring the actual setup logic executes only once.
+
+**Implementation pattern:**
+
+```go
+package feature
+
+import (
+    "sync"
+    "sync/atomic"
+
+    "databasus-backend/internal/util/logger"
+)
+
+var (
+    setupOnce sync.Once
+    isSetup   atomic.Bool
+)
+
+func SetupDependencies() {
+    wasAlreadySetup := isSetup.Load()
+
+    setupOnce.Do(func() {
+        // Initialize dependencies here
+        someService.SetDependency(otherService)
+        anotherService.AddListener(listener)
+
+        isSetup.Store(true)
+    })
+
+    if wasAlreadySetup {
+        logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
+    }
+}
+```
+
+**Why this pattern:**
+
+- **Tests can call multiple times**: Test setup often calls `SetupDependencies()` multiple times without issues
+- **Thread-safe**: Works correctly with concurrent calls (nanoseconds or seconds apart)
+- **Idempotent**: Subsequent calls are safe, only log warning
+- **No panics**: Does not break tests or production code on multiple calls
+
+**Key Points:**
+
+1. Check `isSetup.Load()` **before** calling `Do()` to detect previous executions
+2. Set `isSetup.Store(true)` **inside** the `Do()` closure after setup completes
+3. Log warning if already setup (helps identify unnecessary duplicate calls)
+4. All setup logic must be inside the `Do()` closure
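For illustration, a minimal sketch of how a test suite might lean on this idempotence. The `SetupDependencies` call matches the pattern above, but the package name, import path and `TestMain` placement are assumptions for the example, not code from the repository:

```go
package feature_test

import (
	"os"
	"testing"

	"databasus-backend/internal/features/feature" // hypothetical import path
)

func TestMain(m *testing.M) {
	feature.SetupDependencies() // first call performs the actual wiring
	feature.SetupDependencies() // second call is a safe no-op that only logs a warning

	os.Exit(m.Run())
}
```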
+---
+
+### Background services
+
+**All background service `Run()` methods must panic if called multiple times to prevent corrupted states.**
+
+Background services run infinite loops and must never be started twice on the same instance. Multiple calls indicate a serious bug that would cause duplicate goroutines, resource leaks, and data corruption.
+
+**Implementation pattern:**
+
+```go
+package feature
+
+import (
+    "context"
+    "fmt"
+    "sync"
+    "sync/atomic"
+    "time"
+)
+
+type BackgroundService struct {
+    // ... existing fields ...
+    runOnce sync.Once
+    hasRun  atomic.Bool
+}
+
+func (s *BackgroundService) Run(ctx context.Context) {
+    wasAlreadyRun := s.hasRun.Load()
+
+    s.runOnce.Do(func() {
+        s.hasRun.Store(true)
+
+        // Existing infinite loop logic
+        ticker := time.NewTicker(1 * time.Minute)
+        defer ticker.Stop()
+
+        for {
+            select {
+            case <-ctx.Done():
+                return
+            case <-ticker.C:
+                s.doWork()
+            }
+        }
+    })
+
+    if wasAlreadyRun {
+        panic(fmt.Sprintf("%T.Run() called multiple times", s))
+    }
+}
+```
+
+**Why panic instead of warning:**
+
+- **Prevents corruption**: Multiple `Run()` calls would create duplicate goroutines consuming resources
+- **Fails fast**: Catches critical bugs immediately in tests and production
+- **Clear indication**: Panic clearly indicates a serious programming error
+- **Applies everywhere**: Same protection in tests and production
+
+**When this applies:**
+
+- All background services with infinite loops
+- Registry services (BackupNodesRegistry, RestoreNodesRegistry)
+- Scheduler services (BackupsScheduler, RestoresScheduler)
+- Worker nodes (BackuperNode, RestorerNode)
+- Cleanup services (AuditLogBackgroundService, DownloadTokenBackgroundService)
+
+**Key Points:**
+
+1. Check `hasRun.Load()` **before** calling `Do()` to detect previous executions
+2. Set `hasRun.Store(true)` **inside** the `Do()` closure before starting work
+3. **Always panic** if already run (never just log warning)
+4. All run logic must be inside the `Do()` closure
+5. This pattern is **thread-safe** for any timing (concurrent or sequential calls)
+
+---
+
 ### Migrations
@@ -477,14 +891,14 @@ You can shortify, make more readable, improve code quality, etc. Common logic ca

 **After writing tests, always launch them and verify that they pass.**

-#### Test Naming Format
+#### Test naming format

 Use these naming patterns:

 - `Test_WhatWeDo_WhatWeExpect`
 - `Test_WhatWeDo_WhichConditions_WhatWeExpect`

-#### Examples from Real Codebase:
+#### Examples from real codebase:

 - `Test_CreateApiKey_WhenUserIsProjectOwner_ApiKeyCreated`
 - `Test_UpdateProject_WhenUserIsProjectAdmin_ProjectUpdated`
@@ -492,22 +906,22 @@
 - `Test_GetProjectAuditLogs_WithDifferentUserRoles_EnforcesPermissionsCorrectly`
 - `Test_ProjectLifecycleE2E_CompletesSuccessfully`

-#### Testing Philosophy
+#### Testing philosophy

-**Prefer Controllers Over Unit Tests:**
+**Prefer controllers over unit tests:**

 - Test through HTTP endpoints via controllers whenever possible (see the sketch after this list)
 - Avoid testing repositories, services in isolation - test via API instead
 - Only use unit tests for complex model logic when no API exists
 - Name test files `controller_test.go` or `service_test.go`, not `integration_test.go`
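As a concrete illustration of testing through the HTTP layer, here is a minimal, self-contained sketch using Gin's router with `httptest`. The route and the inline handler are hypothetical stand-ins for the project's real controllers and shared test utilities:

```go
package projects_test

import (
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/assert"
)

// newTestRouter stands in for the project's shared router-building helpers.
func newTestRouter() *gin.Engine {
	router := gin.New()
	router.POST("/api/v1/projects", func(ctx *gin.Context) {
		ctx.JSON(http.StatusCreated, gin.H{"name": "Test project"})
	})

	return router
}

func Test_CreateProject_WhenPayloadIsValid_ProjectCreated(t *testing.T) {
	router := newTestRouter()

	// Exercise the behavior through the HTTP endpoint, not the service directly
	body := strings.NewReader(`{"name": "Test project"}`)
	request := httptest.NewRequest(http.MethodPost, "/api/v1/projects", body)
	request.Header.Set("Content-Type", "application/json")

	recorder := httptest.NewRecorder()
	router.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusCreated, recorder.Code)
}
```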
-**Extract Common Logic to Testing Utilities:**
+**Extract common logic to testing utilities:**

 - Create `testing.go` or `testing/testing.go` files for shared test utilities
 - Extract router creation, user setup, models creation helpers (in API, not just structs creation)
 - Reuse common patterns across different test files

-**Refactor Existing Tests:**
+**Refactor existing tests:**

 - When working with existing tests, always look for opportunities to refactor and improve
 - Extract repetitive setup code to common utilities
@@ -516,7 +930,44 @@ Use these naming patterns:
 - Consolidate similar test patterns across different test files
 - Make tests more readable and maintainable for other developers

-#### Testing Utilities Structure
+**Clean up test data:**
+
+- If the feature supports cleanup operations (DELETE endpoints, cleanup methods), use them in tests
+- Clean up resources after test execution to avoid test data pollution
+- Use `defer` statements or explicit cleanup calls at the end of tests
+- Prioritize using API methods for cleanup (not direct database deletion)
+- Examples:
+  - CRUD features: delete created records via DELETE endpoint
+  - File uploads: remove uploaded files
+  - Background jobs: stop schedulers or cancel running tasks
+- Skip cleanup only when:
+  - Tests run in isolated transactions that auto-rollback
+  - Cleanup endpoint doesn't exist yet
+  - Test explicitly validates failure scenarios where cleanup isn't possible
+
+**Example:**
+
+```go
+func Test_BackupLifecycle_CreateAndDelete(t *testing.T) {
+    router := createTestRouter()
+    workspace := workspaces_testing.CreateTestWorkspace("Test", owner)
+
+    // Create backup config
+    config := createBackupConfig(t, router, workspace.ID, owner.Token)
+
+    // Cleanup at end of test
+    defer deleteBackupConfig(t, router, workspace.ID, config.ID, owner.Token)
+
+    // Test operations...
+    triggerBackup(t, router, workspace.ID, config.ID, owner.Token)
+
+    // Verify backup was created
+    backups := getBackups(t, router, workspace.ID, owner.Token)
+    assert.NotEmpty(t, backups)
+}
+```
+
+#### Testing utilities structure

 **Create `testing.go` or `testing/testing.go` files with common utilities:**
@@ -552,7 +1003,7 @@ func AddMemberToProject(project *projects_models.Project, member *users_dto.Sign
 }
 ```

-#### Controller Test Examples
+#### Controller test examples

 **Permission-based testing:**
@@ -619,7 +1070,7 @@ func Test_ProjectLifecycleE2E_CompletesSuccessfully(t *testing.T) {

 ---

-### Time Handling
+### Time handling

 **Always use `time.Now().UTC()` instead of `time.Now()`**
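A minimal sketch of the rule in practice; the `Backup` type here is illustrative, not the project's actual model:

```go
package backups

import "time"

type Backup struct {
	CreatedAt time.Time
	ExpiresAt time.Time
}

func NewBackup() *Backup {
	// Always normalize to UTC before storing or comparing;
	// plain time.Now() would carry the server's local zone.
	now := time.Now().UTC()

	return &Backup{
		CreatedAt: now,
		ExpiresAt: now.Add(24 * time.Hour),
	}
}
```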
@@ -627,7 +1078,7 @@ This ensures consistent timezone handling across the application.

 ---

-### CRUD Examples
+### CRUD examples

 This is an example of complete CRUD implementation structure:
@@ -1291,9 +1742,9 @@ func createTimedLog(db *gorm.DB, userID *uuid.UUID, message string, createdAt ti

 ---

-## Frontend Guidelines
+## Frontend guidelines

-### React Component Structure
+### React component structure

 Write React components with the following structure:
@@ -1327,7 +1778,7 @@ export const ReactComponent = ({ someValue }: Props): JSX.Element => {
 }
 ```

-#### Structure Order:
+#### Structure order:

 1. **Props interface** - Define component props
 2. **Helper functions** (outside component) - Pure utility functions
Dockerfile: 61 changed lines
@@ -251,6 +251,37 @@ fi
 # PostgreSQL 17 binary paths
 PG_BIN="/usr/lib/postgresql/17/bin"

+# Generate runtime configuration for frontend
+echo "Generating runtime configuration..."
+
+# Detect if email is configured (both SMTP_HOST and DATABASUS_URL must be set)
+if [ -n "\${SMTP_HOST:-}" ] && [ -n "\${DATABASUS_URL:-}" ]; then
+  IS_EMAIL_CONFIGURED="true"
+else
+  IS_EMAIL_CONFIGURED="false"
+fi
+
+cat > /app/ui/build/runtime-config.js <<JSEOF
+// Runtime configuration injected at container startup
+// This file is generated dynamically and should not be edited manually
+window.__RUNTIME_CONFIG__ = {
+  IS_CLOUD: '\${IS_CLOUD:-false}',
+  GITHUB_CLIENT_ID: '\${GITHUB_CLIENT_ID:-}',
+  GOOGLE_CLIENT_ID: '\${GOOGLE_CLIENT_ID:-}',
+  IS_EMAIL_CONFIGURED: '\$IS_EMAIL_CONFIGURED',
+  CLOUDFLARE_TURNSTILE_SITE_KEY: '\${CLOUDFLARE_TURNSTILE_SITE_KEY:-}'
+};
+JSEOF
+
+# Inject analytics script if provided (only if not already injected)
+if [ -n "\${ANALYTICS_SCRIPT:-}" ]; then
+  if ! grep -q "rybbit.databasus.com" /app/ui/build/index.html 2>/dev/null; then
+    echo "Injecting analytics script..."
+    sed -i "s#</head># \${ANALYTICS_SCRIPT}\\
+    </head>#" /app/ui/build/index.html
+  fi
+fi
+
 # Ensure proper ownership of data directory
 echo "Setting up data directory permissions..."
 mkdir -p /databasus-data/pgdata
@@ -372,9 +403,37 @@ SQL

 # Start the main application
 echo "Starting Databasus application..."

+# Check and warn about external database/Valkey usage
+if [ -n "\${DANGEROUS_EXTERNAL_DATABASE_DSN:-}" ]; then
+  echo ""
+  echo "=========================================="
+  echo "WARNING: Using external database"
+  echo "=========================================="
+  echo "DANGEROUS_EXTERNAL_DATABASE_DSN is set."
+  echo "Application will connect to external PostgreSQL instead of internal instance."
+  echo "Internal PostgreSQL is still running in the background."
+  echo "=========================================="
+  echo ""
+fi
+
+if [ -n "\${DANGEROUS_VALKEY_HOST:-}" ]; then
+  echo ""
+  echo "=========================================="
+  echo "WARNING: Using external Valkey"
+  echo "=========================================="
+  echo "DANGEROUS_VALKEY_HOST is set."
+  echo "Application will connect to external Valkey instead of internal instance."
+  echo "Internal Valkey is still running in the background."
+  echo "=========================================="
+  echo ""
+fi
+
 exec ./main
 EOF

+LABEL org.opencontainers.image.source="https://github.com/databasus/databasus"
+
 RUN chmod +x /app/start.sh

 EXPOSE 4005
@@ -383,4 +442,4 @@ EXPOSE 4005

 VOLUME ["/databasus-data"]

 ENTRYPOINT ["/app/start.sh"]
-CMD []
+CMD []
README.md: 60 changed lines
@@ -11,7 +11,7 @@
 [](https://www.mongodb.com/)
 <br />
 [](LICENSE)
-[](https://hub.docker.com/r/rostislavdugin/postgresus)
+[](https://hub.docker.com/r/databasus/databasus)
 [](https://github.com/databasus/databasus)
 [](https://github.com/databasus/databasus)
 [](https://github.com/databasus/databasus)
@@ -31,8 +31,6 @@
 <img src="assets/dashboard-dark.svg" alt="Databasus Dark Dashboard" width="800" style="margin-bottom: 10px;"/>

 <img src="assets/dashboard.svg" alt="Databasus Dashboard" width="800"/>
-
-
 </div>

 ---
@@ -52,6 +50,13 @@
 - **Precise timing**: run backups at specific times (e.g., 4 AM during low traffic)
 - **Smart compression**: 4-8x space savings with balanced compression (~20% overhead)

+### 🗑️ **Retention policies**
+
+- **Time period**: Keep backups for a fixed duration (e.g., 7 days, 3 months, 1 year)
+- **Count**: Keep a fixed number of the most recent backups (e.g., last 30)
+- **GFS (Grandfather-Father-Son)**: Layered retention — keep hourly, daily, weekly, monthly and yearly backups independently for fine-grained long-term history (an enterprise requirement)
+- **Size limits**: Set per-backup and total storage size caps to control storage usage
+
 ### 🗄️ **Multiple storage destinations** <a href="https://databasus.com/storages">(view supported)</a>

 - **Local storage**: Keep backups on your VPS/server
@@ -71,6 +76,8 @@
 - **Encryption for secrets**: Any sensitive data is encrypted and never exposed, even in logs or error messages
 - **Read-only user**: Databasus uses a read-only user by default for backups and never stores anything that can modify your data

+It is also important for Databasus that you are able to decrypt and restore backups from storages (local, S3, etc.) without Databasus itself. To do so, read our guide on [how to recover directly from storage](https://databasus.com/how-to-recover-without-databasus). We avoid "vendor lock-in" even for an open source tool!
+
 ### 👥 **Suitable for teams** <a href="https://databasus.com/access-management">(docs)</a>

 - **Workspaces**: Group databases, notifiers and storages for different projects or teams
@@ -220,8 +227,9 @@ For more options (NodePort, TLS, HTTPRoute for Gateway API), see the [Helm chart
 3. **Configure schedule**: Choose from hourly, daily, weekly, monthly or cron intervals
 4. **Set database connection**: Enter your database credentials and connection details
 5. **Choose storage**: Select where to store your backups (local, S3, Google Drive, etc.)
-6. **Add notifications** (optional): Configure email, Telegram, Slack, or webhook notifications
-7. **Save and start**: Databasus will validate settings and begin the backup schedule
+6. **Configure retention policy**: Choose time period, count or GFS to control how long backups are kept
+7. **Add notifications** (optional): Configure email, Telegram, Slack, or webhook notifications
+8. **Save and start**: Databasus will validate settings and begin the backup schedule

 ### 🔑 Resetting password <a href="https://databasus.com/password">(docs)</a>
@@ -233,56 +241,22 @@ docker exec -it databasus ./main --new-password="YourNewSecurePassword123" --ema

 Replace `admin` with the actual email address of the user whose password you want to reset.

+### 💾 Backing up Databasus itself
+
+After installation, it is also recommended to <a href="https://databasus.com/faq/#backup-databasus">back up Databasus itself</a> or, at least, to copy the secret key used for encryption (this takes about 30 seconds). That way you can restore from your encrypted backups even if you lose access to the server running Databasus, or it becomes corrupted.
+
 ---

 ## 📝 License

 This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details

 ---

 ## 🤝 Contributing

 Contributions are welcome! Read the <a href="https://databasus.com/contribute">contributing guide</a> for more details, priorities and rules. If you want to contribute but don't know where to start, message me on Telegram [@rostislav_dugin](https://t.me/rostislav_dugin)

 Also you can join our large community of developers, DBAs and DevOps engineers on Telegram [@databasus_community](https://t.me/databasus_community).

---
-
-## 📖 Migration guide
-
-Databasus is the new name for Postgresus. You can stay with latest version of Postgresus if you wish. If you want to migrate - follow installation steps for Databasus itself.
-
-Just renaming an image is not enough as Postgresus and Databasus use different data folders and internal database naming.
-
-You can put a new Databasus image with updated volume near the old Postgresus and run it (stop Postgresus before):
-
-```
-services:
-  databasus:
-    container_name: databasus
-    image: databasus/databasus:latest
-    ports:
-      - "4005:4005"
-    volumes:
-      - ./databasus-data:/databasus-data
-    restart: unless-stopped
-```
-
-Then manually move databases from Postgresus to Databasus.
-
-### Why was Postgresus renamed to Databasus?
-
-Databasus has been developed since 2023. It was internal tool to backup production and home projects databases. In start of 2025 it was released as open source project on GitHub. By the end of 2025 it became popular and the time for renaming has come in December 2025.
-
-It was an important step for the project to grow. Actually, there are a couple of reasons:
-
-1. Postgresus is no longer a little tool that just adds UI for pg_dump for little projects. It became a tool both for individual users, DevOps, DBAs, teams, companies and even large enterprises. Tens of thousands of users use Postgresus every day. Postgresus grew into a reliable backup management tool. Initial positioning is no longer suitable: the project is not just a UI wrapper, it's a solid backup management system now (despite it's still easy to use).
-
-2. New databases are supported: although the primary focus is PostgreSQL (with 100% support in the most efficient way) and always will be, Databasus added support for MySQL, MariaDB and MongoDB. Later more databases will be supported.
-
-3. Trademark issue: "postgres" is a trademark of PostgreSQL Inc. and cannot be used in the project name. So for safety and legal reasons, we had to rename the project.
-
 ## AI disclaimer

 There have been questions about AI usage in project development in issues and discussions. As the project focuses on security, reliability and production usage, it's important to explain how AI is used in the development process.
@@ -6,6 +6,14 @@ DEV_DB_PASSWORD=Q1234567
 ENV_MODE=development
 # logging
 SHOW_DB_INSTALLATION_VERIFICATION_LOGS=true
+VICTORIA_LOGS_URL=http://localhost:9428
+VICTORIA_LOGS_PASSWORD=devpassword
+# tests
+TEST_LOCALHOST=localhost
+IS_SKIP_EXTERNAL_RESOURCES_TESTS=false
 # cloudflare turnstile
 CLOUDFLARE_TURNSTILE_SITE_KEY=
 CLOUDFLARE_TURNSTILE_SECRET_KEY=
+# db
+DATABASE_DSN=host=dev-db user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
+DATABASE_URL=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable
backend/.gitignore (vendored): 3 changed lines
@@ -18,4 +18,5 @@ pgdata-for-restore/
 temp/
 cmd.exe
 temp/
-valkey-data/
+valkey-data/
+victoria-logs-data/
@@ -25,10 +25,10 @@ import (
 	healthcheck_config "databasus-backend/internal/features/healthcheck/config"
 	"databasus-backend/internal/features/notifiers"
 	"databasus-backend/internal/features/restores"
+	"databasus-backend/internal/features/restores/restoring"
 	"databasus-backend/internal/features/storages"
 	system_healthcheck "databasus-backend/internal/features/system/healthcheck"
 	task_cancellation "databasus-backend/internal/features/tasks/cancellation"
-	task_registry "databasus-backend/internal/features/tasks/registry"
 	users_controllers "databasus-backend/internal/features/users/controllers"
 	users_middleware "databasus-backend/internal/features/users/middleware"
 	users_services "databasus-backend/internal/features/users/services"
@@ -185,6 +185,9 @@ func startServerWithGracefulShutdown(log *slog.Logger, app *gin.Engine) {
 	<-quit
 	log.Info("Shutdown signal received")

+	// Gracefully shutdown VictoriaLogs writer
+	logger.ShutdownVictoriaLogs(5 * time.Second)
+
 	// The context is used to inform the server it has 10 seconds to finish
 	// the request it is currently handling
 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -272,8 +275,12 @@ func runBackgroundTasks(log *slog.Logger) {
 		backuping.GetBackupsScheduler().Run(ctx)
 	})

+	go runWithPanicLogging(log, "backup cleaner background service", func() {
+		backuping.GetBackupCleaner().Run(ctx)
+	})
+
 	go runWithPanicLogging(log, "restore background service", func() {
-		restores.GetRestoreBackgroundService().Run(ctx)
+		restoring.GetRestoresScheduler().Run(ctx)
 	})

 	go runWithPanicLogging(log, "healthcheck attempt background service", func() {
@@ -288,21 +295,29 @@ func runBackgroundTasks(log *slog.Logger) {
        backups_download.GetDownloadTokenBackgroundService().Run(ctx)
    })

    go runWithPanicLogging(log, "task nodes registry background service", func() {
        task_registry.GetTaskNodesRegistry().Run(ctx)
    go runWithPanicLogging(log, "backup nodes registry background service", func() {
        backuping.GetBackupNodesRegistry().Run(ctx)
    })

    go runWithPanicLogging(log, "restore nodes registry background service", func() {
        restoring.GetRestoreNodesRegistry().Run(ctx)
    })
} else {
    log.Info("Skipping primary node tasks as not primary node")
}

if config.GetEnv().IsBackupNode {
if config.GetEnv().IsProcessingNode {
    log.Info("Starting backup node background tasks...")

    go runWithPanicLogging(log, "backup node", func() {
        backuping.GetBackuperNode().Run(ctx)
    })

    go runWithPanicLogging(log, "restore node", func() {
        restoring.GetRestorerNode().Run(ctx)
    })
} else {
    log.Info("Skipping backup node tasks as not backup node")
    log.Info("Skipping backup/restore node tasks as not backup node")
}
}
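runWithPanicLogging is called throughout this hunk but its body is not part of the diff. A minimal sketch of what such a helper typically looks like — only the name and signature are taken from the calls above, the body is an assumption:

package main

import (
    "log/slog"
    "os"
)

// Hypothetical reconstruction: run fn and, if it panics, log the panic
// together with the service name instead of crashing the whole process.
func runWithPanicLogging(log *slog.Logger, name string, fn func()) {
    defer func() {
        if r := recover(); r != nil {
            log.Error("background service panicked", "service", name, "panic", r)
        }
    }()
    fn()
}

func main() {
    log := slog.New(slog.NewTextHandler(os.Stderr, nil))
    runWithPanicLogging(log, "demo service", func() {
        panic("boom")
    })
    log.Info("process still alive after panic")
}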
@@ -34,6 +34,20 @@ services:
      retries: 5
      start_period: 20s

  # VictoriaLogs for external logging
  victoria-logs:
    image: victoriametrics/victoria-logs:latest
    container_name: victoria-logs
    ports:
      - "9428:9428"
    command:
      - -storageDataPath=/victoria-logs-data
      - -retentionPeriod=7d
      - -httpAuth.password=devpassword
    volumes:
      - ./victoria-logs-data:/victoria-logs-data
    restart: unless-stopped

  # Test MinIO container
  test-minio:
    image: minio/minio:latest
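For poking at this container locally, a log line can be pushed over VictoriaLogs' JSON-line ingestion endpoint. This is a minimal sketch assuming the stock /insert/jsonline API and the port mapping above; the _msg and _time field names are VictoriaLogs conventions, not something defined by this repo:

package main

import (
    "bytes"
    "fmt"
    "net/http"
    "time"
)

func main() {
    // One JSON log entry per line; _msg is the message field VictoriaLogs expects.
    entry := fmt.Sprintf(`{"_time":%q,"_msg":"hello from databasus dev","level":"info"}`+"\n",
        time.Now().UTC().Format(time.RFC3339))

    resp, err := http.Post(
        "http://localhost:9428/insert/jsonline", // port published by the compose service above
        "application/stream+json",
        bytes.NewBufferString(entry),
    )
    if err != nil {
        panic(err)
    }
    defer resp.Body.Close()
    fmt.Println("ingest status:", resp.Status)
}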
@@ -9,7 +9,6 @@ import (
    "strings"
    "sync"

    "github.com/google/uuid"
    "github.com/ilyakaznacheev/cleanenv"
    "github.com/joho/godotenv"
)
@@ -23,19 +22,30 @@ const (

type EnvVariables struct {
    IsTesting bool
    DatabaseDsn string `env:"DATABASE_DSN" required:"true"`
    EnvMode env_utils.EnvMode `env:"ENV_MODE" required:"true"`
    PostgresesInstallDir string `env:"POSTGRES_INSTALL_DIR"`
    MysqlInstallDir string `env:"MYSQL_INSTALL_DIR"`
    MariadbInstallDir string `env:"MARIADB_INSTALL_DIR"`
    MongodbInstallDir string `env:"MONGODB_INSTALL_DIR"`

    ShowDbInstallationVerificationLogs bool `env:"SHOW_DB_INSTALLATION_VERIFICATION_LOGS"`
    // Internal database
    DatabaseDsn string `env:"DATABASE_DSN" required:"true"`
    // Internal Valkey
    ValkeyHost string `env:"VALKEY_HOST" required:"true"`
    ValkeyPort string `env:"VALKEY_PORT" required:"true"`
    ValkeyUsername string `env:"VALKEY_USERNAME" required:"true"`
    ValkeyPassword string `env:"VALKEY_PASSWORD" required:"true"`
    ValkeyIsSsl bool `env:"VALKEY_IS_SSL" required:"true"`

    IsCloud bool `env:"IS_CLOUD"`
    TestLocalhost string `env:"TEST_LOCALHOST"`

    ShowDbInstallationVerificationLogs bool `env:"SHOW_DB_INSTALLATION_VERIFICATION_LOGS"`
    IsSkipExternalResourcesTests bool `env:"IS_SKIP_EXTERNAL_RESOURCES_TESTS"`

    NodeID string
    IsManyNodesMode bool `env:"IS_MANY_NODES_MODE"`
    IsPrimaryNode bool `env:"IS_PRIMARY_NODE"`
    IsBackupNode bool `env:"IS_BACKUP_NODE"`
    IsProcessingNode bool `env:"IS_PROCESSING_NODE"`
    NodeNetworkThroughputMBs int `env:"NODE_NETWORK_THROUGHPUT_MBPS"`

    DataFolder string
@@ -88,19 +98,16 @@ type EnvVariables struct {
    TestMongodb70Port string `env:"TEST_MONGODB_70_PORT"`
    TestMongodb82Port string `env:"TEST_MONGODB_82_PORT"`

    // Valkey
    ValkeyHost string `env:"VALKEY_HOST" required:"true"`
    ValkeyPort string `env:"VALKEY_PORT" required:"true"`
    ValkeyUsername string `env:"VALKEY_USERNAME"`
    ValkeyPassword string `env:"VALKEY_PASSWORD"`
    ValkeyIsSsl bool `env:"VALKEY_IS_SSL" required:"true"`

    // oauth
    GitHubClientID string `env:"GITHUB_CLIENT_ID"`
    GitHubClientSecret string `env:"GITHUB_CLIENT_SECRET"`
    GoogleClientID string `env:"GOOGLE_CLIENT_ID"`
    GoogleClientSecret string `env:"GOOGLE_CLIENT_SECRET"`

    // Cloudflare Turnstile
    CloudflareTurnstileSecretKey string `env:"CLOUDFLARE_TURNSTILE_SECRET_KEY"`
    CloudflareTurnstileSiteKey string `env:"CLOUDFLARE_TURNSTILE_SITE_KEY"`

    // testing Telegram
    TestTelegramBotToken string `env:"TEST_TELEGRAM_BOT_TOKEN"`
    TestTelegramChatID string `env:"TEST_TELEGRAM_CHAT_ID"`
@@ -111,6 +118,15 @@ type EnvVariables struct {
    TestSupabaseUsername string `env:"TEST_SUPABASE_USERNAME"`
    TestSupabasePassword string `env:"TEST_SUPABASE_PASSWORD"`
    TestSupabaseDatabase string `env:"TEST_SUPABASE_DATABASE"`

    // SMTP configuration (optional)
    SMTPHost string `env:"SMTP_HOST"`
    SMTPPort int `env:"SMTP_PORT"`
    SMTPUser string `env:"SMTP_USER"`
    SMTPPassword string `env:"SMTP_PASSWORD"`

    // Application URL (optional) - used for email links
    DatabasusURL string `env:"DATABASUS_URL"`
}

var (
@@ -176,6 +192,16 @@ func loadEnvVariables() {
        env.ShowDbInstallationVerificationLogs = true
    }

    // Set default value for IsSkipExternalResourcesTests if not defined
    if os.Getenv("IS_SKIP_EXTERNAL_RESOURCES_TESTS") == "" {
        env.IsSkipExternalResourcesTests = false
    }

    // Set default value for IsCloud if not defined
    if os.Getenv("IS_CLOUD") == "" {
        env.IsCloud = false
    }

    for _, arg := range os.Args {
        if strings.Contains(arg, "test") {
            env.IsTesting = true
@@ -183,6 +209,14 @@ func loadEnvVariables() {
        }
    }

    // Check for external database override
    if externalDsn := os.Getenv("DANGEROUS_EXTERNAL_DATABASE_DSN"); externalDsn != "" {
        log.Warn(
            "Using DANGEROUS_EXTERNAL_DATABASE_DSN - connecting to external database instead of internal PostgreSQL",
        )
        env.DatabaseDsn = externalDsn
    }

    if env.DatabaseDsn == "" {
        log.Error("DATABASE_DSN is empty")
        os.Exit(1)
@@ -230,14 +264,17 @@ func loadEnvVariables() {
        env.ShowDbInstallationVerificationLogs,
    )

    env.NodeID = uuid.New().String()
    if env.NodeNetworkThroughputMBs == 0 {
        env.NodeNetworkThroughputMBs = 125 // 1 Gbit/s
    }

    if !env.IsManyNodesMode {
        env.IsPrimaryNode = true
        env.IsBackupNode = true
        env.IsProcessingNode = true
    }

    if env.TestLocalhost == "" {
        env.TestLocalhost = "localhost"
    }

    // Valkey
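The 125 MB/s default corresponds to a 1 Gbit/s link: 1000 megabits per second divided by 8 bits per byte. A tiny illustration of the conversion the default encodes:

package main

import "fmt"

func main() {
    const linkMbps = 1000         // 1 Gbit/s
    throughputMBs := linkMbps / 8 // megabits -> megabytes per second
    fmt.Println(throughputMBs)    // 125, matching the default above
}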
@@ -250,6 +287,27 @@ func loadEnvVariables() {
        os.Exit(1)
    }

    // Check for external Valkey override
    if externalValkeyHost := os.Getenv("DANGEROUS_VALKEY_HOST"); externalValkeyHost != "" {
        log.Warn(
            "Using DANGEROUS_VALKEY_* variables - connecting to external Valkey instead of internal instance",
        )
        env.ValkeyHost = externalValkeyHost

        if externalValkeyPort := os.Getenv("DANGEROUS_VALKEY_PORT"); externalValkeyPort != "" {
            env.ValkeyPort = externalValkeyPort
        }
        if externalValkeyUsername := os.Getenv("DANGEROUS_VALKEY_USERNAME"); externalValkeyUsername != "" {
            env.ValkeyUsername = externalValkeyUsername
        }
        if externalValkeyPassword := os.Getenv("DANGEROUS_VALKEY_PASSWORD"); externalValkeyPassword != "" {
            env.ValkeyPassword = externalValkeyPassword
        }
        if externalValkeyIsSsl := os.Getenv("DANGEROUS_VALKEY_IS_SSL"); externalValkeyIsSsl != "" {
            env.ValkeyIsSsl = externalValkeyIsSsl == "true"
        }
    }

    // Store the data and temp folders one level below the root
    // (projectRoot/databasus-data -> /databasus-data)
    env.DataFolder = filepath.Join(filepath.Dir(backendRoot), "databasus-data", "backups")
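The per-field fallback above repeats the same three lines for every variable; a generic helper like the following sketch (hypothetical, not part of the diff) expresses the same "only override when set" pattern once:

package main

import (
    "fmt"
    "os"
)

// overrideFromEnv is a hypothetical helper illustrating the pattern used
// for the DANGEROUS_VALKEY_* overrides: replace *dst only when the
// environment variable is set and non-empty.
func overrideFromEnv(key string, dst *string) {
    if v := os.Getenv(key); v != "" {
        *dst = v
    }
}

func main() {
    host, port := "valkey", "6379" // internal defaults
    os.Setenv("DANGEROUS_VALKEY_HOST", "valkey.example.internal")

    overrideFromEnv("DANGEROUS_VALKEY_HOST", &host)
    overrideFromEnv("DANGEROUS_VALKEY_PORT", &port) // unset, so port keeps its default

    fmt.Println(host, port) // valkey.example.internal 6379
}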
@@ -2,34 +2,50 @@ package audit_logs

import (
    "context"
    "fmt"
    "log/slog"
    "sync"
    "sync/atomic"
    "time"
)

type AuditLogBackgroundService struct {
    auditLogService *AuditLogService
    logger *slog.Logger

    runOnce sync.Once
    hasRun atomic.Bool
}

func (s *AuditLogBackgroundService) Run(ctx context.Context) {
    s.logger.Info("Starting audit log cleanup background service")
    wasAlreadyRun := s.hasRun.Load()

    if ctx.Err() != nil {
        return
    }
    s.runOnce.Do(func() {
        s.hasRun.Store(true)

        ticker := time.NewTicker(1 * time.Hour)
        defer ticker.Stop()
        s.logger.Info("Starting audit log cleanup background service")

        for {
            select {
            case <-ctx.Done():
                if ctx.Err() != nil {
                    return
            case <-ticker.C:
                if err := s.cleanOldAuditLogs(); err != nil {
                    s.logger.Error("Failed to clean old audit logs", "error", err)
                }

        ticker := time.NewTicker(1 * time.Hour)
        defer ticker.Stop()

        for {
            select {
            case <-ctx.Done():
                return
            case <-ticker.C:
                if err := s.cleanOldAuditLogs(); err != nil {
                    s.logger.Error("Failed to clean old audit logs", "error", err)
                }
            }
        }
    })

    if wasAlreadyRun {
        panic(fmt.Sprintf("%T.Run() called multiple times", s))
    }
}
@@ -1,6 +1,9 @@
package audit_logs

import (
    "sync"
    "sync/atomic"

    users_services "databasus-backend/internal/features/users/services"
    "databasus-backend/internal/util/logger"
)
@@ -14,8 +17,10 @@ var auditLogController = &AuditLogController{
    auditLogService,
}
var auditLogBackgroundService = &AuditLogBackgroundService{
    auditLogService,
    logger.GetLogger(),
    auditLogService: auditLogService,
    logger: logger.GetLogger(),
    runOnce: sync.Once{},
    hasRun: atomic.Bool{},
}

func GetAuditLogService() *AuditLogService {
@@ -30,8 +35,23 @@ func GetAuditLogBackgroundService() *AuditLogBackgroundService {
    return auditLogBackgroundService
}

var (
    setupOnce sync.Once
    isSetup atomic.Bool
)

func SetupDependencies() {
    users_services.GetUserService().SetAuditLogWriter(auditLogService)
    users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
    users_services.GetManagementService().SetAuditLogWriter(auditLogService)
    wasAlreadySetup := isSetup.Load()

    setupOnce.Do(func() {
        users_services.GetUserService().SetAuditLogWriter(auditLogService)
        users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
        users_services.GetManagementService().SetAuditLogWriter(auditLogService)

        isSetup.Store(true)
    })

    if wasAlreadySetup {
        logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
    }
}
@@ -1,24 +1,28 @@
package backuping

import (
    "bytes"
    "context"
    "encoding/json"
    "errors"
    "fmt"
    "log/slog"
    "slices"
    "strings"
    "sync"
    "sync/atomic"
    "time"

    "github.com/google/uuid"

    "databasus-backend/internal/config"
    backups_core "databasus-backend/internal/features/backups/backups/core"
    backups_config "databasus-backend/internal/features/backups/config"
    "databasus-backend/internal/features/databases"
    "databasus-backend/internal/features/storages"
    tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
    task_registry "databasus-backend/internal/features/tasks/registry"
    workspaces_services "databasus-backend/internal/features/workspaces/services"
    util_encryption "databasus-backend/internal/util/encryption"
    "errors"
    "fmt"
    "log/slog"
    "slices"
    "strings"
    "time"

    "github.com/google/uuid"
)

const (
@@ -35,70 +39,87 @@ type BackuperNode struct {
    storageService *storages.StorageService
    notificationSender backups_core.NotificationSender
    backupCancelManager *tasks_cancellation.TaskCancelManager
    tasksRegistry *task_registry.TaskNodesRegistry
    backupNodesRegistry *BackupNodesRegistry
    logger *slog.Logger
    createBackupUseCase backups_core.CreateBackupUsecase
    nodeID uuid.UUID

    lastHeartbeat time.Time

    runOnce sync.Once
    hasRun atomic.Bool
}

func (n *BackuperNode) Run(ctx context.Context) {
    n.lastHeartbeat = time.Now().UTC()
    wasAlreadyRun := n.hasRun.Load()

    throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
    n.runOnce.Do(func() {
        n.hasRun.Store(true)

        backupNode := task_registry.TaskNode{
            ID: n.nodeID,
            ThroughputMBs: throughputMBs,
        }
        n.lastHeartbeat = time.Now().UTC()

        if err := n.tasksRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
            n.logger.Error("Failed to register node in registry", "error", err)
            panic(err)
        }
        throughputMBs := config.GetEnv().NodeNetworkThroughputMBs

        backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
            n.MakeBackup(backupID, isCallNotifier)
            if err := n.tasksRegistry.PublishTaskCompletion(n.nodeID.String(), backupID); err != nil {
                n.logger.Error(
                    "Failed to publish backup completion",
                    "error", err,
                    "backupID", backupID,
                )
        backupNode := BackupNode{
            ID: n.nodeID,
            ThroughputMBs: throughputMBs,
            LastHeartbeat: time.Now().UTC(),
        }
        }

        if err := n.tasksRegistry.SubscribeNodeForTasksAssignment(n.nodeID.String(), backupHandler); err != nil {
            n.logger.Error("Failed to subscribe to backup assignments", "error", err)
            panic(err)
        }
        defer func() {
            if err := n.tasksRegistry.UnsubscribeNodeForTasksAssignments(); err != nil {
                n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
        if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
            n.logger.Error("Failed to register node in registry", "error", err)
            panic(err)
        }
        }()

        ticker := time.NewTicker(heartbeatTickerInterval)
        defer ticker.Stop()
        backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
            go func() {
                n.MakeBackup(backupID, isCallNotifier)
                if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil {
                    n.logger.Error(
                        "Failed to publish backup completion",
                        "error", err,
                        "backupID", backupID,
                    )
                }
            }()
        }

        n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)

        for {
            select {
            case <-ctx.Done():
                n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)

                if err := n.tasksRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
                    n.logger.Error("Failed to unregister node from registry", "error", err)
        err := n.backupNodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID, backupHandler)
        if err != nil {
            n.logger.Error("Failed to subscribe to backup assignments", "error", err)
            panic(err)
        }
        defer func() {
            if err := n.backupNodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
                n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
            }
        }()

                return
            case <-ticker.C:
                n.sendHeartbeat(&backupNode)
        ticker := time.NewTicker(heartbeatTickerInterval)
        defer ticker.Stop()

        n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)

        for {
            select {
            case <-ctx.Done():
                n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)

                if err := n.backupNodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
                    n.logger.Error("Failed to unregister node from registry", "error", err)
                }

                return
            case <-ticker.C:
                n.sendHeartbeat(&backupNode)
            }
        }
    })

    if wasAlreadyRun {
        panic(fmt.Sprintf("%T.Run() called multiple times", n))
    }
}
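One behavioural change in this hunk: the new backupHandler launches MakeBackup in its own goroutine, so the subscriber that delivers assignments is never blocked by a long-running backup. A stripped-down sketch of why that matters; channels stand in for the Valkey pub/sub, and everything here is illustrative rather than the repo's actual API:

package main

import (
    "fmt"
    "sync"
    "time"
)

func main() {
    assignments := make(chan int) // stands in for the pub/sub subscription
    var wg sync.WaitGroup

    // Mirrors the new backupHandler: each assignment is processed in its
    // own goroutine, so the dispatch loop below keeps draining the channel
    // while "backups" are still running.
    handler := func(backupID int) {
        wg.Add(1)
        go func() {
            defer wg.Done()
            time.Sleep(100 * time.Millisecond) // simulate MakeBackup
            fmt.Println("completed backup", backupID)
        }()
    }

    go func() {
        for id := 1; id <= 3; id++ {
            assignments <- id
        }
        close(assignments)
    }()

    start := time.Now()
    for id := range assignments {
        handler(id) // returns immediately; work runs concurrently
    }
    wg.Wait()
    fmt.Println("all three done in roughly one backup's duration:",
        time.Since(start).Round(10*time.Millisecond))
}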
@@ -140,30 +161,73 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {

    start := time.Now().UTC()

    ctx, cancel := context.WithCancel(context.Background())
    n.backupCancelManager.RegisterTask(backup.ID, cancel)
    defer n.backupCancelManager.UnregisterTask(backup.ID)

    backupProgressListener := func(
        completedMBs float64,
    ) {
        backup.BackupSizeMb = completedMBs
        backup.BackupDurationMs = time.Since(start).Milliseconds()

        // Check size limit (0 = unlimited)
        if backupConfig.MaxBackupSizeMB > 0 &&
            completedMBs > float64(backupConfig.MaxBackupSizeMB) {
            errMsg := fmt.Sprintf(
                "backup size (%.2f MB) exceeded maximum allowed size (%d MB)",
                completedMBs,
                backupConfig.MaxBackupSizeMB,
            )

            backup.Status = backups_core.BackupStatusFailed
            backup.IsSkipRetry = true
            backup.FailMessage = &errMsg
            if err := n.backupRepository.Save(backup); err != nil {
                n.logger.Error("Failed to save backup with size exceeded error", "error", err)
            }
            cancel() // Cancel the backup context

            return
        }

        if err := n.backupRepository.Save(backup); err != nil {
            n.logger.Error("Failed to update backup progress", "error", err)
        }
    }

    ctx, cancel := context.WithCancel(context.Background())
    n.backupCancelManager.RegisterTask(backup.ID, cancel)
    defer n.backupCancelManager.UnregisterTask(backup.ID)

    backupMetadata, err := n.createBackupUseCase.Execute(
        ctx,
        backup.ID,
        backup,
        backupConfig,
        database,
        storage,
        backupProgressListener,
    )
    if err != nil {
        // Check if backup was already marked as failed by progress listener (e.g., size limit exceeded)
        // If so, skip error handling to avoid overwriting the status
        currentBackup, fetchErr := n.backupRepository.FindByID(backup.ID)
        if fetchErr == nil && currentBackup.Status == backups_core.BackupStatusFailed {
            n.logger.Warn(
                "Backup already marked as failed by progress listener, skipping error handling",
                "backupId", backup.ID,
                "failMessage", *currentBackup.FailMessage,
            )

            // Still call notification for size limit failures
            n.SendBackupNotification(
                backupConfig,
                currentBackup,
                backups_config.NotificationBackupFailed,
                currentBackup.FailMessage,
            )

            return
        }

        errMsg := err.Error()

        // Log detailed error information for debugging
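The progress listener doubles as the size-limit enforcer: every progress tick compares the running size against MaxBackupSizeMB and cancels the backup's context when the cap is crossed. A self-contained sketch of that mechanism (simplified: plain locals instead of the repository and Backup struct):

package main

import (
    "context"
    "fmt"
)

func main() {
    const maxBackupSizeMB = 5 // 0 would mean unlimited

    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()

    var failMessage string
    progress := func(completedMBs float64) {
        if maxBackupSizeMB > 0 && completedMBs > float64(maxBackupSizeMB) {
            failMessage = fmt.Sprintf(
                "backup size (%.2f MB) exceeded maximum allowed size (%d MB)",
                completedMBs, maxBackupSizeMB,
            )
            cancel() // stop the in-flight backup
        }
    }

    // Simulated backup loop: checks ctx between progress reports,
    // like the CreateProgressiveBackupUsecase test double does.
    for _, size := range []float64{1, 3, 5, 10, 20} {
        if ctx.Err() != nil {
            break
        }
        progress(size)
    }

    fmt.Println("cancelled:", ctx.Err() != nil) // true
    fmt.Println(failMessage)                    // 10.00 MB vs 5 MB
}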
@@ -201,7 +265,7 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
    // Delete partial backup from storage
    storage, storageErr := n.storageService.GetStorageByID(backup.StorageID)
    if storageErr == nil {
        if deleteErr := storage.DeleteFile(n.fieldEncryptor, backup.ID); deleteErr != nil {
        if deleteErr := storage.DeleteFile(n.fieldEncryptor, backup.FileName); deleteErr != nil {
            n.logger.Error(
                "Failed to delete partial backup file",
                "backupId",
@@ -249,6 +313,13 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {

    // Update backup with encryption metadata if provided
    if backupMetadata != nil {
        backupMetadata.BackupID = backup.ID

        if err := backupMetadata.Validate(); err != nil {
            n.logger.Error("Failed to validate backup metadata", "error", err)
            return
        }

        backup.EncryptionSalt = backupMetadata.EncryptionSalt
        backup.EncryptionIV = backupMetadata.EncryptionIV
        backup.Encryption = backupMetadata.Encryption
@@ -259,6 +330,39 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
        return
    }

    // Save metadata file to storage
    if backupMetadata != nil {
        metadataJSON, err := json.Marshal(backupMetadata)
        if err != nil {
            n.logger.Error("Failed to marshal backup metadata to JSON",
                "backupId", backup.ID,
                "error", err,
            )
        } else {
            metadataReader := bytes.NewReader(metadataJSON)
            metadataFileName := backup.FileName + ".metadata"

            if err := storage.SaveFile(
                context.Background(),
                n.fieldEncryptor,
                n.logger,
                metadataFileName,
                metadataReader,
            ); err != nil {
                n.logger.Error("Failed to save backup metadata file to storage",
                    "backupId", backup.ID,
                    "fileName", metadataFileName,
                    "error", err,
                )
            } else {
                n.logger.Info("Backup metadata file saved successfully",
                    "backupId", backup.ID,
                    "fileName", metadataFileName,
                )
            }
        }
    }

    // Update database last backup time
    now := time.Now().UTC()
    if updateErr := n.databaseService.SetLastBackupTime(databaseID, now); updateErr != nil {
@@ -357,9 +461,9 @@ func (n *BackuperNode) SendBackupNotification(
        }
    }

func (n *BackuperNode) sendHeartbeat(backupNode *task_registry.TaskNode) {
func (n *BackuperNode) sendHeartbeat(backupNode *BackupNode) {
    n.lastHeartbeat = time.Now().UTC()
    if err := n.tasksRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *backupNode); err != nil {
    if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *backupNode); err != nil {
        n.logger.Error("Failed to send heartbeat", "error", err)
    }
}
@@ -1,13 +1,10 @@
package backuping

import (
    "context"
    "errors"
    "strings"
    "testing"
    "time"

    common "databasus-backend/internal/features/backups/backups/common"
    backups_core "databasus-backend/internal/features/backups/backups/core"
    backups_config "databasus-backend/internal/features/backups/config"
    "databasus-backend/internal/features/databases"
@@ -18,7 +15,6 @@ import (
    workspaces_testing "databasus-backend/internal/features/workspaces/testing"
    cache_utils "databasus-backend/internal/util/cache"

    "github.com/google/uuid"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/mock"
)
@@ -158,35 +154,120 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
    })
}

type CreateFailedBackupUsecase struct {
}
func Test_BackupSizeLimits(t *testing.T) {
    cache_utils.ClearAllCache()
    user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
    router := CreateTestRouter()
    workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
    storage := storages.CreateTestStorage(workspace.ID)
    notifier := notifiers.CreateTestNotifier(workspace.ID)
    database := databases.CreateTestDatabase(workspace.ID, storage, notifier)

func (uc *CreateFailedBackupUsecase) Execute(
    ctx context.Context,
    backupID uuid.UUID,
    backupConfig *backups_config.BackupConfig,
    database *databases.Database,
    storage *storages.Storage,
    backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
    backupProgressListener(10)
    return nil, errors.New("backup failed")
}
    defer func() {
        // cleanup backups first
        backups, _ := backupRepository.FindByDatabaseID(database.ID)
        for _, backup := range backups {
            backupRepository.DeleteByID(backup.ID)
        }

type CreateSuccessBackupUsecase struct{}
        databases.RemoveTestDatabase(database)
        time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
        notifiers.RemoveTestNotifier(notifier)
        storages.RemoveTestStorage(storage.ID)
        workspaces_testing.RemoveTestWorkspace(workspace, router)
    }()

func (uc *CreateSuccessBackupUsecase) Execute(
    ctx context.Context,
    backupID uuid.UUID,
    backupConfig *backups_config.BackupConfig,
    database *databases.Database,
    storage *storages.Storage,
    backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
    backupProgressListener(10)
    return &common.BackupMetadata{
        EncryptionSalt: nil,
        EncryptionIV: nil,
        Encryption: backups_config.BackupEncryptionNone,
    }, nil
    t.Run("UnlimitedSize_MaxBackupSizeMBIsZero_BackupCompletes", func(t *testing.T) {
        // Enable backups with unlimited size (0)
        backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
        backupConfig.MaxBackupSizeMB = 0 // unlimited
        backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
        assert.NoError(t, err)

        backuperNode := CreateTestBackuperNode()
        backuperNode.createBackupUseCase = &CreateLargeBackupUsecase{}

        // Create a backup record
        backup := &backups_core.Backup{
            DatabaseID: database.ID,
            StorageID: storage.ID,
            Status: backups_core.BackupStatusInProgress,
            CreatedAt: time.Now().UTC(),
        }
        err = backupRepository.Save(backup)
        assert.NoError(t, err)

        backuperNode.MakeBackup(backup.ID, false)

        // Verify backup completed successfully even with large size
        updatedBackup, err := backupRepository.FindByID(backup.ID)
        assert.NoError(t, err)
        assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
        assert.Equal(t, float64(10000), updatedBackup.BackupSizeMb)
        assert.Nil(t, updatedBackup.FailMessage)
    })

    t.Run("SizeExceeded_BackupFailedWithIsSkipRetry", func(t *testing.T) {
        // Enable backups with 5 MB limit
        backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
        backupConfig.MaxBackupSizeMB = 5
        backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
        assert.NoError(t, err)

        backuperNode := CreateTestBackuperNode()
        backuperNode.createBackupUseCase = &CreateProgressiveBackupUsecase{}

        // Create a backup record
        backup := &backups_core.Backup{
            DatabaseID: database.ID,
            StorageID: storage.ID,
            Status: backups_core.BackupStatusInProgress,
            CreatedAt: time.Now().UTC(),
        }
        err = backupRepository.Save(backup)
        assert.NoError(t, err)

        backuperNode.MakeBackup(backup.ID, false)

        // Verify backup was marked as failed with IsSkipRetry=true
        updatedBackup, err := backupRepository.FindByID(backup.ID)
        assert.NoError(t, err)
        assert.Equal(t, backups_core.BackupStatusFailed, updatedBackup.Status)
        assert.True(t, updatedBackup.IsSkipRetry)
        assert.NotNil(t, updatedBackup.FailMessage)
        assert.Contains(t, *updatedBackup.FailMessage, "exceeded maximum allowed size")
        assert.Contains(t, *updatedBackup.FailMessage, "10.00 MB")
        assert.Contains(t, *updatedBackup.FailMessage, "5 MB")
        assert.Greater(t, updatedBackup.BackupSizeMb, float64(5))
    })

    t.Run("SizeWithinLimit_BackupCompletes", func(t *testing.T) {
        // Enable backups with 100 MB limit
        backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
        backupConfig.MaxBackupSizeMB = 100
        backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
        assert.NoError(t, err)

        backuperNode := CreateTestBackuperNode()
        backuperNode.createBackupUseCase = &CreateMediumBackupUsecase{}

        // Create a backup record
        backup := &backups_core.Backup{
            DatabaseID: database.ID,
            StorageID: storage.ID,
            Status: backups_core.BackupStatusInProgress,
            CreatedAt: time.Now().UTC(),
        }
        err = backupRepository.Save(backup)
        assert.NoError(t, err)

        backuperNode.MakeBackup(backup.ID, false)

        // Verify backup completed successfully
        updatedBackup, err := backupRepository.FindByID(backup.ID)
        assert.NoError(t, err)
        assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
        assert.Equal(t, float64(50), updatedBackup.BackupSizeMb)
        assert.Nil(t, updatedBackup.FailMessage)
    })
}
backend/internal/features/backups/backups/backuping/cleaner.go (new file, 461 lines)
@@ -0,0 +1,461 @@
package backuping

import (
    "context"
    "fmt"
    "log/slog"
    "sync"
    "sync/atomic"
    "time"

    "github.com/google/uuid"

    backups_core "databasus-backend/internal/features/backups/backups/core"
    backups_config "databasus-backend/internal/features/backups/config"
    "databasus-backend/internal/features/storages"
    util_encryption "databasus-backend/internal/util/encryption"
    "databasus-backend/internal/util/period"
)

const (
    cleanerTickerInterval = 1 * time.Minute
    recentBackupGracePeriod = 60 * time.Minute
)

type BackupCleaner struct {
    backupRepository *backups_core.BackupRepository
    storageService *storages.StorageService
    backupConfigService *backups_config.BackupConfigService
    fieldEncryptor util_encryption.FieldEncryptor
    logger *slog.Logger
    backupRemoveListeners []backups_core.BackupRemoveListener

    runOnce sync.Once
    hasRun atomic.Bool
}

func (c *BackupCleaner) Run(ctx context.Context) {
    wasAlreadyRun := c.hasRun.Load()

    c.runOnce.Do(func() {
        c.hasRun.Store(true)

        if ctx.Err() != nil {
            return
        }

        ticker := time.NewTicker(cleanerTickerInterval)
        defer ticker.Stop()

        for {
            select {
            case <-ctx.Done():
                return
            case <-ticker.C:
                if err := c.cleanByRetentionPolicy(); err != nil {
                    c.logger.Error("Failed to clean backups by retention policy", "error", err)
                }

                if err := c.cleanExceededBackups(); err != nil {
                    c.logger.Error("Failed to clean exceeded backups", "error", err)
                }
            }
        }
    })

    if wasAlreadyRun {
        panic(fmt.Sprintf("%T.Run() called multiple times", c))
    }
}
func (c *BackupCleaner) DeleteBackup(backup *backups_core.Backup) error {
    for _, listener := range c.backupRemoveListeners {
        if err := listener.OnBeforeBackupRemove(backup); err != nil {
            return err
        }
    }

    storage, err := c.storageService.GetStorageByID(backup.StorageID)
    if err != nil {
        return err
    }

    err = storage.DeleteFile(c.fieldEncryptor, backup.FileName)
    if err != nil {
        // We do not return the error here: cleanup sometimes runs right
        // before an unavailable storage is removed or changed, so we must
        // proceed even on failure. An S3 bucket or other storage being
        // temporarily unreachable should not block the cleanup.
        c.logger.Error("Failed to delete backup file", "error", err)
    }

    metadataFileName := backup.FileName + ".metadata"
    if err := storage.DeleteFile(c.fieldEncryptor, metadataFileName); err != nil {
        c.logger.Error("Failed to delete backup metadata file", "error", err)
    }

    return c.backupRepository.DeleteByID(backup.ID)
}

func (c *BackupCleaner) AddBackupRemoveListener(listener backups_core.BackupRemoveListener) {
    c.backupRemoveListeners = append(c.backupRemoveListeners, listener)
}
func (c *BackupCleaner) cleanByRetentionPolicy() error {
    enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
    if err != nil {
        return err
    }

    for _, backupConfig := range enabledBackupConfigs {
        var cleanErr error

        switch backupConfig.RetentionPolicyType {
        case backups_config.RetentionPolicyTypeCount:
            cleanErr = c.cleanByCount(backupConfig)
        case backups_config.RetentionPolicyTypeGFS:
            cleanErr = c.cleanByGFS(backupConfig)
        default:
            cleanErr = c.cleanByTimePeriod(backupConfig)
        }

        if cleanErr != nil {
            c.logger.Error(
                "Failed to clean backups by retention policy",
                "databaseId", backupConfig.DatabaseID,
                "policy", backupConfig.RetentionPolicyType,
                "error", cleanErr,
            )
        }
    }

    return nil
}

func (c *BackupCleaner) cleanExceededBackups() error {
    enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
    if err != nil {
        return err
    }

    for _, backupConfig := range enabledBackupConfigs {
        if backupConfig.MaxBackupsTotalSizeMB <= 0 {
            continue
        }

        if err := c.cleanExceededBackupsForDatabase(
            backupConfig.DatabaseID,
            backupConfig.MaxBackupsTotalSizeMB,
        ); err != nil {
            c.logger.Error(
                "Failed to clean exceeded backups for database",
                "databaseId", backupConfig.DatabaseID,
                "error", err,
            )
            continue
        }
    }

    return nil
}
func (c *BackupCleaner) cleanByTimePeriod(backupConfig *backups_config.BackupConfig) error {
    if backupConfig.RetentionTimePeriod == "" {
        return nil
    }

    if backupConfig.RetentionTimePeriod == period.PeriodForever {
        return nil
    }

    storeDuration := backupConfig.RetentionTimePeriod.ToDuration()
    dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)

    oldBackups, err := c.backupRepository.FindBackupsBeforeDate(
        backupConfig.DatabaseID,
        dateBeforeBackupsShouldBeDeleted,
    )
    if err != nil {
        return fmt.Errorf(
            "failed to find old backups for database %s: %w",
            backupConfig.DatabaseID,
            err,
        )
    }

    for _, backup := range oldBackups {
        if isRecentBackup(backup) {
            continue
        }

        if err := c.DeleteBackup(backup); err != nil {
            c.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
            continue
        }

        c.logger.Info(
            "Deleted old backup",
            "backupId", backup.ID,
            "databaseId", backupConfig.DatabaseID,
        )
    }

    return nil
}
func (c *BackupCleaner) cleanByCount(backupConfig *backups_config.BackupConfig) error {
    if backupConfig.RetentionCount <= 0 {
        return nil
    }

    completedBackups, err := c.backupRepository.FindByDatabaseIdAndStatus(
        backupConfig.DatabaseID,
        backups_core.BackupStatusCompleted,
    )
    if err != nil {
        return fmt.Errorf(
            "failed to find completed backups for database %s: %w",
            backupConfig.DatabaseID,
            err,
        )
    }

    // completedBackups are ordered newest first; delete everything beyond position RetentionCount
    if len(completedBackups) <= backupConfig.RetentionCount {
        return nil
    }

    toDelete := completedBackups[backupConfig.RetentionCount:]
    for _, backup := range toDelete {
        if isRecentBackup(backup) {
            continue
        }

        if err := c.DeleteBackup(backup); err != nil {
            c.logger.Error(
                "Failed to delete backup by count policy",
                "backupId", backup.ID,
                "error", err,
            )
            continue
        }

        c.logger.Info(
            "Deleted backup by count policy",
            "backupId", backup.ID,
            "databaseId", backupConfig.DatabaseID,
            "retentionCount", backupConfig.RetentionCount,
        )
    }

    return nil
}
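The count policy relies on the repository returning backups newest-first, so slicing at RetentionCount leaves exactly the N newest. A tiny runnable illustration of that slicing:

package main

import "fmt"

func main() {
    // Newest-first list of backups, as FindByDatabaseIdAndStatus is
    // documented to return them in the cleaner above.
    completed := []string{"mon", "sun", "sat", "fri", "thu"}
    retentionCount := 3

    if len(completed) > retentionCount {
        toDelete := completed[retentionCount:]
        fmt.Println("keep:", completed[:retentionCount]) // [mon sun sat]
        fmt.Println("delete:", toDelete)                 // [fri thu]
    }
}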
func (c *BackupCleaner) cleanByGFS(backupConfig *backups_config.BackupConfig) error {
    if backupConfig.RetentionGfsHours <= 0 && backupConfig.RetentionGfsDays <= 0 &&
        backupConfig.RetentionGfsWeeks <= 0 && backupConfig.RetentionGfsMonths <= 0 &&
        backupConfig.RetentionGfsYears <= 0 {
        return nil
    }

    completedBackups, err := c.backupRepository.FindByDatabaseIdAndStatus(
        backupConfig.DatabaseID,
        backups_core.BackupStatusCompleted,
    )
    if err != nil {
        return fmt.Errorf(
            "failed to find completed backups for database %s: %w",
            backupConfig.DatabaseID,
            err,
        )
    }

    keepSet := buildGFSKeepSet(
        completedBackups,
        backupConfig.RetentionGfsHours,
        backupConfig.RetentionGfsDays,
        backupConfig.RetentionGfsWeeks,
        backupConfig.RetentionGfsMonths,
        backupConfig.RetentionGfsYears,
    )

    for _, backup := range completedBackups {
        if keepSet[backup.ID] {
            continue
        }

        if isRecentBackup(backup) {
            continue
        }

        if err := c.DeleteBackup(backup); err != nil {
            c.logger.Error(
                "Failed to delete backup by GFS policy",
                "backupId", backup.ID,
                "error", err,
            )
            continue
        }

        c.logger.Info(
            "Deleted backup by GFS policy",
            "backupId", backup.ID,
            "databaseId", backupConfig.DatabaseID,
        )
    }

    return nil
}
func (c *BackupCleaner) cleanExceededBackupsForDatabase(
    databaseID uuid.UUID,
    limitperDbMB int64,
) error {
    for {
        backupsTotalSizeMB, err := c.backupRepository.GetTotalSizeByDatabase(databaseID)
        if err != nil {
            return err
        }

        if backupsTotalSizeMB <= float64(limitperDbMB) {
            break
        }

        oldestBackups, err := c.backupRepository.FindOldestByDatabaseExcludingInProgress(
            databaseID,
            1,
        )
        if err != nil {
            return err
        }

        if len(oldestBackups) == 0 {
            c.logger.Warn(
                "No backups to delete but still over limit",
                "databaseId", databaseID,
                "totalSizeMB", backupsTotalSizeMB,
                "limitMB", limitperDbMB,
            )
            break
        }

        backup := oldestBackups[0]
        if isRecentBackup(backup) {
            c.logger.Warn(
                "Oldest backup is too recent to delete, stopping size cleanup",
                "databaseId", databaseID,
                "backupId", backup.ID,
                "totalSizeMB", backupsTotalSizeMB,
                "limitMB", limitperDbMB,
            )
            break
        }

        if err := c.DeleteBackup(backup); err != nil {
            c.logger.Error(
                "Failed to delete exceeded backup",
                "backupId", backup.ID,
                "databaseId", databaseID,
                "error", err,
            )
            return err
        }

        c.logger.Info(
            "Deleted exceeded backup",
            "backupId", backup.ID,
            "databaseId", databaseID,
            "backupSizeMB", backup.BackupSizeMb,
            "totalSizeMB", backupsTotalSizeMB,
            "limitMB", limitperDbMB,
        )
    }

    return nil
}

func isRecentBackup(backup *backups_core.Backup) bool {
    return time.Since(backup.CreatedAt) < recentBackupGracePeriod
}
// buildGFSKeepSet determines which backups to retain under the GFS rotation scheme.
// Backups must be sorted newest-first. A backup can fill multiple slots simultaneously
// (e.g. the newest backup of a year also fills the monthly, weekly, daily, and hourly slot).
func buildGFSKeepSet(
    backups []*backups_core.Backup,
    hours, days, weeks, months, years int,
) map[uuid.UUID]bool {
    keep := make(map[uuid.UUID]bool)

    hoursSeen := make(map[string]bool)
    daysSeen := make(map[string]bool)
    weeksSeen := make(map[string]bool)
    monthsSeen := make(map[string]bool)
    yearsSeen := make(map[string]bool)

    hoursKept, daysKept, weeksKept, monthsKept, yearsKept := 0, 0, 0, 0, 0

    for _, backup := range backups {
        t := backup.CreatedAt

        hourKey := t.Format("2006-01-02-15")
        dayKey := t.Format("2006-01-02")
        weekYear, week := t.ISOWeek()
        weekKey := fmt.Sprintf("%d-%02d", weekYear, week)
        monthKey := t.Format("2006-01")
        yearKey := t.Format("2006")

        if hours > 0 && hoursKept < hours && !hoursSeen[hourKey] {
            keep[backup.ID] = true
            hoursSeen[hourKey] = true
            hoursKept++
        }

        if days > 0 && daysKept < days && !daysSeen[dayKey] {
            keep[backup.ID] = true
            daysSeen[dayKey] = true
            daysKept++
        }

        if weeks > 0 && weeksKept < weeks && !weeksSeen[weekKey] {
            keep[backup.ID] = true
            weeksSeen[weekKey] = true
            weeksKept++
        }

        if months > 0 && monthsKept < months && !monthsSeen[monthKey] {
            keep[backup.ID] = true
            monthsSeen[monthKey] = true
            monthsKept++
        }

        if years > 0 && yearsKept < years && !yearsSeen[yearKey] {
            keep[backup.ID] = true
            yearsSeen[yearKey] = true
            yearsKept++
        }
    }

    return keep
}
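To see the keep-set logic in action, here is a runnable demo that applies the same slot-filling rules to synthetic timestamps. It is a simplified re-implementation for illustration only (days and weeks, bare time values instead of Backup structs and UUIDs):

package main

import (
    "fmt"
    "time"
)

// keepByGFS mirrors buildGFSKeepSet above for the daily/weekly slots,
// operating on bare timestamps, purely for illustration.
func keepByGFS(times []time.Time, days, weeks int) map[int]bool {
    keep := map[int]bool{}
    daysSeen, weeksSeen := map[string]bool{}, map[string]bool{}
    daysKept, weeksKept := 0, 0

    for i, t := range times { // newest first
        dayKey := t.Format("2006-01-02")
        wy, w := t.ISOWeek()
        weekKey := fmt.Sprintf("%d-%02d", wy, w)

        if days > 0 && daysKept < days && !daysSeen[dayKey] {
            keep[i], daysSeen[dayKey] = true, true
            daysKept++
        }
        if weeks > 0 && weeksKept < weeks && !weeksSeen[weekKey] {
            keep[i], weeksSeen[weekKey] = true, true
            weeksKept++
        }
    }
    return keep
}

func main() {
    // 14 daily backups, newest first.
    newest := time.Date(2025, 6, 20, 3, 0, 0, 0, time.UTC)
    var times []time.Time
    for i := 0; i < 14; i++ {
        times = append(times, newest.AddDate(0, 0, -i))
    }

    keep := keepByGFS(times, 3, 2) // keep 3 dailies + 2 weeklies
    for i, t := range times {
        if keep[i] {
            fmt.Println("keep", t.Format("2006-01-02"))
        }
    }
    // The 3 newest days are kept, plus the newest backup of each of the
    // 2 most recent ISO weeks (the current week's slot overlaps a daily).
}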
backend/internal/features/backups/backups/backuping/cleaner_test.go (new file, 1568 lines; diff suppressed because it is too large)
@@ -1,7 +1,12 @@
package backuping

import (
    "databasus-backend/internal/config"
    "sync"
    "sync/atomic"
    "time"

    "github.com/google/uuid"

    backups_core "databasus-backend/internal/features/backups/backups/core"
    "databasus-backend/internal/features/backups/backups/usecases"
    backups_config "databasus-backend/internal/features/backups/config"
@@ -9,29 +14,39 @@ import (
    "databasus-backend/internal/features/notifiers"
    "databasus-backend/internal/features/storages"
    tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
    task_registry "databasus-backend/internal/features/tasks/registry"
    workspaces_services "databasus-backend/internal/features/workspaces/services"
    cache_utils "databasus-backend/internal/util/cache"
    "databasus-backend/internal/util/encryption"
    "databasus-backend/internal/util/logger"
    "time"

    "github.com/google/uuid"
)

var backupRepository = &backups_core.BackupRepository{}

var taskCancelManager = tasks_cancellation.GetTaskCancelManager()

var nodesRegistry = task_registry.GetTaskNodesRegistry()
var backupCleaner = &BackupCleaner{
    backupRepository,
    storages.GetStorageService(),
    backups_config.GetBackupConfigService(),
    encryption.GetFieldEncryptor(),
    logger.GetLogger(),
    []backups_core.BackupRemoveListener{},
    sync.Once{},
    atomic.Bool{},
}

var backupNodesRegistry = &BackupNodesRegistry{
    cache_utils.GetValkeyClient(),
    logger.GetLogger(),
    cache_utils.DefaultCacheTimeout,
    cache_utils.NewPubSubManager(),
    cache_utils.NewPubSubManager(),
    sync.Once{},
    atomic.Bool{},
}

func getNodeID() uuid.UUID {
    nodeIDStr := config.GetEnv().NodeID
    nodeID, err := uuid.Parse(nodeIDStr)
    if err != nil {
        logger.GetLogger().Error("Failed to parse node ID from config", "error", err)
        panic(err)
    }
    return nodeID
    return uuid.New()
}

var backuperNode = &BackuperNode{
@@ -43,23 +58,27 @@ var backuperNode = &BackuperNode{
    storages.GetStorageService(),
    notifiers.GetNotifierService(),
    taskCancelManager,
    nodesRegistry,
    backupNodesRegistry,
    logger.GetLogger(),
    usecases.GetCreateBackupUsecase(),
    getNodeID(),
    time.Time{},
    sync.Once{},
    atomic.Bool{},
}

var backupsScheduler = &BackupsScheduler{
    backupRepository,
    backups_config.GetBackupConfigService(),
    storages.GetStorageService(),
    taskCancelManager,
    nodesRegistry,
    backupNodesRegistry,
    databases.GetDatabaseService(),
    time.Now().UTC(),
    logger.GetLogger(),
    make(map[uuid.UUID]BackupToNodeRelation),
    backuperNode,
    sync.Once{},
    atomic.Bool{},
}

func GetBackupsScheduler() *BackupsScheduler {
@@ -69,3 +88,11 @@ func GetBackupsScheduler() *BackupsScheduler {

func GetBackuperNode() *BackuperNode {
    return backuperNode
}

func GetBackupNodesRegistry() *BackupNodesRegistry {
    return backupNodesRegistry
}

func GetBackupCleaner() *BackupCleaner {
    return backupCleaner
}
@@ -1,8 +1,34 @@
package backuping

import "github.com/google/uuid"
import (
    "time"

    "github.com/google/uuid"
)

type BackupToNodeRelation struct {
    NodeID uuid.UUID `json:"nodeId"`
    BackupsIDs []uuid.UUID `json:"backupsIds"`
}

type BackupNode struct {
    ID uuid.UUID `json:"id"`
    ThroughputMBs int `json:"throughputMBs"`
    LastHeartbeat time.Time `json:"lastHeartbeat"`
}

type BackupNodeStats struct {
    ID uuid.UUID `json:"id"`
    ActiveBackups int `json:"activeBackups"`
}

type BackupSubmitMessage struct {
    NodeID uuid.UUID `json:"nodeId"`
    BackupID uuid.UUID `json:"backupId"`
    IsCallNotifier bool `json:"isCallNotifier"`
}

type BackupCompletionMessage struct {
    NodeID uuid.UUID `json:"nodeId"`
    BackupID uuid.UUID `json:"backupId"`
}
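These structs are what the registry serializes into Valkey, and the json tags fix the wire format. A quick check of what a BackupNode entry looks like on the wire; the struct is copied locally so the snippet stays self-contained:

package main

import (
    "encoding/json"
    "fmt"
    "time"

    "github.com/google/uuid"
)

type BackupNode struct {
    ID            uuid.UUID `json:"id"`
    ThroughputMBs int       `json:"throughputMBs"`
    LastHeartbeat time.Time `json:"lastHeartbeat"`
}

func main() {
    node := BackupNode{
        ID:            uuid.New(),
        ThroughputMBs: 125,
        LastHeartbeat: time.Now().UTC(),
    }
    data, _ := json.Marshal(node)
    fmt.Println(string(data))
    // {"id":"...","throughputMBs":125,"lastHeartbeat":"...Z"}
}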
@@ -1,8 +1,19 @@
package backuping

import (
    "databasus-backend/internal/features/notifiers"
    "context"
    "errors"
    "sync/atomic"
    "time"

    common "databasus-backend/internal/features/backups/backups/common"
    backups_core "databasus-backend/internal/features/backups/backups/core"
    backups_config "databasus-backend/internal/features/backups/config"
    "databasus-backend/internal/features/databases"
    "databasus-backend/internal/features/notifiers"
    "databasus-backend/internal/features/storages"

    "github.com/google/uuid"
    "github.com/stretchr/testify/mock"
)
@@ -17,3 +28,168 @@ func (m *MockNotificationSender) SendNotification(
) {
    m.Called(notifier, title, message)
}

type CreateFailedBackupUsecase struct{}

func (uc *CreateFailedBackupUsecase) Execute(
    ctx context.Context,
    backup *backups_core.Backup,
    backupConfig *backups_config.BackupConfig,
    database *databases.Database,
    storage *storages.Storage,
    backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
    backupProgressListener(10)
    return nil, errors.New("backup failed")
}

type CreateSuccessBackupUsecase struct{}

func (uc *CreateSuccessBackupUsecase) Execute(
    ctx context.Context,
    backup *backups_core.Backup,
    backupConfig *backups_config.BackupConfig,
    database *databases.Database,
    storage *storages.Storage,
    backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
    backupProgressListener(10)
    return &common.BackupMetadata{
        EncryptionSalt: nil,
        EncryptionIV: nil,
        Encryption: backups_config.BackupEncryptionNone,
    }, nil
}

// CreateLargeBackupUsecase simulates a large backup (10000 MB)
type CreateLargeBackupUsecase struct{}

func (uc *CreateLargeBackupUsecase) Execute(
    ctx context.Context,
    backup *backups_core.Backup,
    backupConfig *backups_config.BackupConfig,
    database *databases.Database,
    storage *storages.Storage,
    backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
    backupProgressListener(10000)
    return &common.BackupMetadata{
        EncryptionSalt: nil,
        EncryptionIV: nil,
        Encryption: backups_config.BackupEncryptionNone,
    }, nil
}

// CreateProgressiveBackupUsecase simulates progressive size updates that exceed the limit
type CreateProgressiveBackupUsecase struct{}

func (uc *CreateProgressiveBackupUsecase) Execute(
    ctx context.Context,
    backup *backups_core.Backup,
    backupConfig *backups_config.BackupConfig,
    database *databases.Database,
    storage *storages.Storage,
    backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
    // Simulate a progressive backup that grows beyond the limit
    backupProgressListener(1)
    if ctx.Err() != nil {
        return nil, ctx.Err()
    }

    backupProgressListener(3)
    if ctx.Err() != nil {
        return nil, ctx.Err()
    }

    backupProgressListener(5)
    if ctx.Err() != nil {
        return nil, ctx.Err()
    }

    backupProgressListener(10) // This exceeds the 5 MB limit
    if ctx.Err() != nil {
        return nil, ctx.Err()
    }

    // Should not reach here due to cancellation
    return &common.BackupMetadata{
        EncryptionSalt: nil,
        EncryptionIV: nil,
        Encryption: backups_config.BackupEncryptionNone,
    }, nil
}

// CreateMediumBackupUsecase simulates a 50 MB backup
type CreateMediumBackupUsecase struct{}

func (uc *CreateMediumBackupUsecase) Execute(
    ctx context.Context,
    backup *backups_core.Backup,
    backupConfig *backups_config.BackupConfig,
    database *databases.Database,
    storage *storages.Storage,
    backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
    backupProgressListener(50)
    return &common.BackupMetadata{
        EncryptionSalt: nil,
        EncryptionIV: nil,
        Encryption: backups_config.BackupEncryptionNone,
    }, nil
}

// MockTrackingBackupUsecase tracks backup use case calls for testing parallel execution
type MockTrackingBackupUsecase struct {
    callCount atomic.Int32
    calledBackupIDs chan uuid.UUID
}

func NewMockTrackingBackupUsecase() *MockTrackingBackupUsecase {
    return &MockTrackingBackupUsecase{
        calledBackupIDs: make(chan uuid.UUID, 10),
    }
}

func (m *MockTrackingBackupUsecase) Execute(
    ctx context.Context,
    backup *backups_core.Backup,
    backupConfig *backups_config.BackupConfig,
    database *databases.Database,
    storage *storages.Storage,
    backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
    m.callCount.Add(1)

    // Send backup ID to channel (non-blocking)
    select {
    case m.calledBackupIDs <- backup.ID:
    default:
    }

    // Simulate backup work
    time.Sleep(100 * time.Millisecond)
    backupProgressListener(10)

    return &common.BackupMetadata{
        EncryptionSalt: nil,
        EncryptionIV: nil,
        Encryption: backups_config.BackupEncryptionNone,
    }, nil
}

func (m *MockTrackingBackupUsecase) GetCallCount() int32 {
    return m.callCount.Load()
}

func (m *MockTrackingBackupUsecase) GetCalledBackupIDs() []uuid.UUID {
    ids := []uuid.UUID{}
    for {
        select {
        case id := <-m.calledBackupIDs:
            ids = append(ids, id)
        default:
            return ids
        }
    }
}
@@ -1,4 +1,4 @@
|
||||
package task_registry
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
@@ -15,64 +17,73 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
nodeInfoKeyPrefix = "node:"
|
||||
nodeInfoKeySuffix = ":info"
|
||||
nodeActiveTasksPrefix = "node:"
|
||||
nodeActiveTasksSuffix = ":active_tasks"
|
||||
taskSubmitChannel = "task:submit"
|
||||
taskCompletionChannel = "task:completion"
|
||||
nodeInfoKeyPrefix = "backup:node:"
|
||||
nodeInfoKeySuffix = ":info"
|
||||
nodeActiveBackupsPrefix = "backup:node:"
|
||||
nodeActiveBackupsSuffix = ":active_backups"
|
||||
backupSubmitChannel = "backup:submit"
|
||||
backupCompletionChannel = "backup:completion"
|
||||
|
||||
deadNodeThreshold = 2 * time.Minute
|
||||
cleanupTickerInterval = 1 * time.Second
|
||||
)
|
||||
|
||||
// TaskNodesRegistry helps to sync tasks scheduler (backuping or restoring)
|
||||
// and task nodes which are used for network-intensive tasks processing
|
||||
// BackupNodesRegistry helps to sync backups scheduler and backup nodes.
|
||||
//
|
||||
// Features:
|
||||
// - Track node availability and load level
|
||||
// - Assign from scheduler to node tasks needed to be processed
|
||||
// - Notify scheduler from node about task completion
|
||||
// - Assign from scheduler to node backups needed to be processed
|
||||
// - Notify scheduler from node about backup completion
|
||||
//
|
||||
// Important things to remember:
|
||||
// - Node can contain different tasks types so when task is assigned
|
||||
// or node's tasks cleaned - should be performed DB check in DB
|
||||
// that task with this ID exists for this task type at all
|
||||
// - Nodes without heathbeat for more than 2 minutes are not included
|
||||
// - Nodes without heartbeat for more than 2 minutes are not included
|
||||
// in available nodes list and stats
|
||||
//
|
||||
// Cleanup dead nodes performed on 2 levels:
|
||||
// - List and stats functions do not return dead nodes
|
||||
// - Periodically dead nodes are cleaned up in cache (to not
|
||||
// accumulate too many dead nodes in cache)
|
||||
type TaskNodesRegistry struct {
|
||||
type BackupNodesRegistry struct {
|
||||
client valkey.Client
|
||||
logger *slog.Logger
|
||||
timeout time.Duration
|
||||
pubsubTasks *cache_utils.PubSubManager
|
||||
pubsubBackups *cache_utils.PubSubManager
|
||||
pubsubCompletions *cache_utils.PubSubManager
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (r *TaskNodesRegistry) Run(ctx context.Context) {
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
|
||||
}
|
||||
func (r *BackupNodesRegistry) Run(ctx context.Context) {
|
||||
wasAlreadyRun := r.hasRun.Load()
|
||||
|
||||
ticker := time.NewTicker(cleanupTickerInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes", "error", err)
|
||||
r.runOnce.Do(func() {
|
||||
r.hasRun.Store(true)
|
||||
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(cleanupTickerInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", r))
|
||||
}
|
||||
}
|
||||
|
||||
-func (r *TaskNodesRegistry) GetAvailableNodes() ([]TaskNode, error) {
+func (r *BackupNodesRegistry) GetAvailableNodes() ([]BackupNode, error) {
 	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
 	defer cancel()

@@ -104,7 +115,7 @@ func (r *TaskNodesRegistry) GetAvailableNodes() ([]TaskNode, error) {
 	}

 	if len(allKeys) == 0 {
-		return []TaskNode{}, nil
+		return []BackupNode{}, nil
 	}

 	keyDataMap, err := r.pipelineGetKeys(allKeys)
@@ -113,14 +124,15 @@ func (r *TaskNodesRegistry) GetAvailableNodes() ([]TaskNode, error) {
 	}

 	threshold := time.Now().UTC().Add(-deadNodeThreshold)
-	var nodes []TaskNode
+	var nodes []BackupNode

 	for key, data := range keyDataMap {
 		// Skip if the key doesn't exist (data is empty)
 		if len(data) == 0 {
 			continue
 		}

-		var node TaskNode
+		var node BackupNode
 		if err := json.Unmarshal(data, &node); err != nil {
 			r.logger.Warn("Failed to unmarshal node data", "key", key, "error", err)
 			continue
@@ -141,13 +153,13 @@ func (r *TaskNodesRegistry) GetAvailableNodes() ([]TaskNode, error) {
 	return nodes, nil
 }
-func (r *TaskNodesRegistry) GetNodesStats() ([]TaskNodeStats, error) {
+func (r *BackupNodesRegistry) GetBackupNodesStats() ([]BackupNodeStats, error) {
 	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
 	defer cancel()

 	var allKeys []string
 	cursor := uint64(0)
-	pattern := nodeActiveTasksPrefix + "*" + nodeActiveTasksSuffix
+	pattern := nodeActiveBackupsPrefix + "*" + nodeActiveBackupsSuffix

 	for {
 		result := r.client.Do(
@@ -156,7 +168,7 @@ func (r *TaskNodesRegistry) GetNodesStats() ([]TaskNodeStats, error) {
 		)

 		if result.Error() != nil {
-			return nil, fmt.Errorf("failed to scan active tasks keys: %w", result.Error())
+			return nil, fmt.Errorf("failed to scan active backups keys: %w", result.Error())
 		}

 		scanResult, err := result.AsScanEntry()
@@ -173,18 +185,18 @@ func (r *TaskNodesRegistry) GetNodesStats() ([]TaskNodeStats, error) {
 	}

 	if len(allKeys) == 0 {
-		return []TaskNodeStats{}, nil
+		return []BackupNodeStats{}, nil
 	}

 	keyDataMap, err := r.pipelineGetKeys(allKeys)
 	if err != nil {
-		return nil, fmt.Errorf("failed to pipeline get active tasks keys: %w", err)
+		return nil, fmt.Errorf("failed to pipeline get active backups keys: %w", err)
 	}

 	var nodeInfoKeys []string
 	nodeIDToStatsKey := make(map[string]string)
 	for key := range keyDataMap {
-		nodeID := r.extractNodeIDFromKey(key, nodeActiveTasksPrefix, nodeActiveTasksSuffix)
+		nodeID := r.extractNodeIDFromKey(key, nodeActiveBackupsPrefix, nodeActiveBackupsSuffix)
 		nodeIDStr := nodeID.String()
 		infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeIDStr, nodeInfoKeySuffix)
 		nodeInfoKeys = append(nodeInfoKeys, infoKey)
@@ -197,14 +209,14 @@ func (r *TaskNodesRegistry) GetNodesStats() ([]TaskNodeStats, error) {
 	}

 	threshold := time.Now().UTC().Add(-deadNodeThreshold)
-	var stats []TaskNodeStats
+	var stats []BackupNodeStats
 	for infoKey, nodeData := range nodeInfoMap {
 		// Skip if the info key doesn't exist (nodeData is empty)
 		if len(nodeData) == 0 {
 			continue
 		}

-		var node TaskNode
+		var node BackupNode
 		if err := json.Unmarshal(nodeData, &node); err != nil {
 			r.logger.Warn("Failed to unmarshal node data", "key", infoKey, "error", err)
 			continue
@@ -223,13 +235,13 @@ func (r *TaskNodesRegistry) GetNodesStats() ([]TaskNodeStats, error) {
 		tasksData := keyDataMap[statsKey]
 		count, err := r.parseIntFromBytes(tasksData)
 		if err != nil {
-			r.logger.Warn("Failed to parse active tasks count", "key", statsKey, "error", err)
+			r.logger.Warn("Failed to parse active backups count", "key", statsKey, "error", err)
 			continue
 		}

-		stat := TaskNodeStats{
-			ID:          node.ID,
-			ActiveTasks: int(count),
+		stat := BackupNodeStats{
+			ID:            node.ID,
+			ActiveBackups: int(count),
 		}
 		stats = append(stats, stat)
 	}
@@ -237,16 +249,16 @@ func (r *TaskNodesRegistry) GetNodesStats() ([]TaskNodeStats, error) {
 	return stats, nil
 }
-func (r *TaskNodesRegistry) IncrementTasksInProgress(nodeID string) error {
+func (r *BackupNodesRegistry) IncrementBackupsInProgress(nodeID uuid.UUID) error {
 	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
 	defer cancel()

-	key := fmt.Sprintf("%s%s%s", nodeActiveTasksPrefix, nodeID, nodeActiveTasksSuffix)
+	key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID.String(), nodeActiveBackupsSuffix)
 	result := r.client.Do(ctx, r.client.B().Incr().Key(key).Build())

 	if result.Error() != nil {
 		return fmt.Errorf(
-			"failed to increment tasks in progress for node %s: %w",
+			"failed to increment backups in progress for node %s: %w",
 			nodeID,
 			result.Error(),
 		)
@@ -255,16 +267,16 @@ func (r *TaskNodesRegistry) IncrementTasksInProgress(nodeID string) error {
 	return nil
 }

-func (r *TaskNodesRegistry) DecrementTasksInProgress(nodeID string) error {
+func (r *BackupNodesRegistry) DecrementBackupsInProgress(nodeID uuid.UUID) error {
 	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
 	defer cancel()

-	key := fmt.Sprintf("%s%s%s", nodeActiveTasksPrefix, nodeID, nodeActiveTasksSuffix)
+	key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID.String(), nodeActiveBackupsSuffix)
 	result := r.client.Do(ctx, r.client.B().Decr().Key(key).Build())

 	if result.Error() != nil {
 		return fmt.Errorf(
-			"failed to decrement tasks in progress for node %s: %w",
+			"failed to decrement backups in progress for node %s: %w",
 			nodeID,
 			result.Error(),
 		)
@@ -279,13 +291,13 @@ func (r *TaskNodesRegistry) DecrementTasksInProgress(nodeID string) error {
 		setCtx, setCancel := context.WithTimeout(context.Background(), r.timeout)
 		r.client.Do(setCtx, r.client.B().Set().Key(key).Value("0").Build())
 		setCancel()
-		r.logger.Warn("Active tasks counter went below 0, reset to 0", "nodeID", nodeID)
+		r.logger.Warn("Active backups counter went below 0, reset to 0", "nodeID", nodeID)
 	}

 	return nil
 }
-func (r *TaskNodesRegistry) HearthbeatNodeInRegistry(now time.Time, node TaskNode) error {
+func (r *BackupNodesRegistry) HearthbeatNodeInRegistry(now time.Time, backupNode BackupNode) error {
 	if now.IsZero() {
 		return fmt.Errorf("cannot register node with zero heartbeat timestamp")
 	}
@@ -293,36 +305,36 @@ func (r *TaskNodesRegistry) HearthbeatNodeInRegistry(now time.Time, node TaskNod
 	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
 	defer cancel()

-	node.LastHeartbeat = now
+	backupNode.LastHeartbeat = now

-	data, err := json.Marshal(node)
+	data, err := json.Marshal(backupNode)
 	if err != nil {
-		return fmt.Errorf("failed to marshal node: %w", err)
+		return fmt.Errorf("failed to marshal backup node: %w", err)
 	}

-	key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node.ID.String(), nodeInfoKeySuffix)
+	key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, backupNode.ID.String(), nodeInfoKeySuffix)
 	result := r.client.Do(
 		ctx,
 		r.client.B().Set().Key(key).Value(string(data)).Build(),
 	)

 	if result.Error() != nil {
-		return fmt.Errorf("failed to register node %s: %w", node.ID, result.Error())
+		return fmt.Errorf("failed to register node %s: %w", backupNode.ID, result.Error())
 	}

 	return nil
 }

-func (r *TaskNodesRegistry) UnregisterNodeFromRegistry(node TaskNode) error {
+func (r *BackupNodesRegistry) UnregisterNodeFromRegistry(backupNode BackupNode) error {
 	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
 	defer cancel()

-	infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node.ID.String(), nodeInfoKeySuffix)
+	infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, backupNode.ID.String(), nodeInfoKeySuffix)
 	counterKey := fmt.Sprintf(
 		"%s%s%s",
-		nodeActiveTasksPrefix,
-		node.ID.String(),
-		nodeActiveTasksSuffix,
+		nodeActiveBackupsPrefix,
+		backupNode.ID.String(),
+		nodeActiveBackupsSuffix,
 	)

 	result := r.client.Do(
@@ -331,49 +343,49 @@ func (r *TaskNodesRegistry) UnregisterNodeFromRegistry(node TaskNode) error {
 	)

 	if result.Error() != nil {
-		return fmt.Errorf("failed to unregister node %s: %w", node.ID, result.Error())
+		return fmt.Errorf("failed to unregister node %s: %w", backupNode.ID, result.Error())
 	}

-	r.logger.Info("Unregistered node from registry", "nodeID", node.ID)
+	r.logger.Info("Unregistered node from registry", "nodeID", backupNode.ID)
 	return nil
 }
-func (r *TaskNodesRegistry) AssignTaskToNode(
-	targetNodeID string,
-	taskID uuid.UUID,
+func (r *BackupNodesRegistry) AssignBackupToNode(
+	targetNodeID uuid.UUID,
+	backupID uuid.UUID,
 	isCallNotifier bool,
 ) error {
 	ctx := context.Background()

-	message := TaskSubmitMessage{
+	message := BackupSubmitMessage{
 		NodeID:         targetNodeID,
-		TaskID:         taskID.String(),
+		BackupID:       backupID,
 		IsCallNotifier: isCallNotifier,
 	}

 	messageJSON, err := json.Marshal(message)
 	if err != nil {
-		return fmt.Errorf("failed to marshal task submit message: %w", err)
+		return fmt.Errorf("failed to marshal backup submit message: %w", err)
 	}

-	err = r.pubsubTasks.Publish(ctx, taskSubmitChannel, string(messageJSON))
+	err = r.pubsubBackups.Publish(ctx, backupSubmitChannel, string(messageJSON))
 	if err != nil {
-		return fmt.Errorf("failed to publish task submit message: %w", err)
+		return fmt.Errorf("failed to publish backup submit message: %w", err)
 	}

 	return nil
 }

-func (r *TaskNodesRegistry) SubscribeNodeForTasksAssignment(
-	nodeID string,
-	handler func(taskID uuid.UUID, isCallNotifier bool),
+func (r *BackupNodesRegistry) SubscribeNodeForBackupsAssignment(
+	nodeID uuid.UUID,
+	handler func(backupID uuid.UUID, isCallNotifier bool),
 ) error {
 	ctx := context.Background()

 	wrappedHandler := func(message string) {
-		var msg TaskSubmitMessage
+		var msg BackupSubmitMessage
 		if err := json.Unmarshal([]byte(message), &msg); err != nil {
-			r.logger.Warn("Failed to unmarshal task submit message", "error", err)
+			r.logger.Warn("Failed to unmarshal backup submit message", "error", err)
 			return
 		}

@@ -381,108 +393,84 @@ func (r *TaskNodesRegistry) SubscribeNodeForTasksAssignment(
 			return
 		}

-		taskID, err := uuid.Parse(msg.TaskID)
-		if err != nil {
-			r.logger.Warn(
-				"Failed to parse task ID from message",
-				"taskId",
-				msg.TaskID,
-				"error",
-				err,
-			)
-			return
-		}
-
-		handler(taskID, msg.IsCallNotifier)
+		handler(msg.BackupID, msg.IsCallNotifier)
 	}

-	err := r.pubsubTasks.Subscribe(ctx, taskSubmitChannel, wrappedHandler)
+	err := r.pubsubBackups.Subscribe(ctx, backupSubmitChannel, wrappedHandler)
 	if err != nil {
-		return fmt.Errorf("failed to subscribe to task submit channel: %w", err)
+		return fmt.Errorf("failed to subscribe to backup submit channel: %w", err)
 	}

-	r.logger.Info("Subscribed to task submit channel", "nodeID", nodeID)
+	r.logger.Info("Subscribed to backup submit channel", "nodeID", nodeID)
 	return nil
 }

-func (r *TaskNodesRegistry) UnsubscribeNodeForTasksAssignments() error {
-	err := r.pubsubTasks.Close()
+func (r *BackupNodesRegistry) UnsubscribeNodeForBackupsAssignments() error {
+	err := r.pubsubBackups.Close()
 	if err != nil {
-		return fmt.Errorf("failed to unsubscribe from task submit channel: %w", err)
+		return fmt.Errorf("failed to unsubscribe from backup submit channel: %w", err)
 	}

-	r.logger.Info("Unsubscribed from task submit channel")
+	r.logger.Info("Unsubscribed from backup submit channel")
 	return nil
 }

-func (r *TaskNodesRegistry) PublishTaskCompletion(nodeID string, taskID uuid.UUID) error {
+func (r *BackupNodesRegistry) PublishBackupCompletion(nodeID uuid.UUID, backupID uuid.UUID) error {
 	ctx := context.Background()

-	message := TaskCompletionMessage{
-		NodeID: nodeID,
-		TaskID: taskID.String(),
+	message := BackupCompletionMessage{
+		NodeID:   nodeID,
+		BackupID: backupID,
 	}

 	messageJSON, err := json.Marshal(message)
 	if err != nil {
-		return fmt.Errorf("failed to marshal task completion message: %w", err)
+		return fmt.Errorf("failed to marshal backup completion message: %w", err)
 	}

-	err = r.pubsubCompletions.Publish(ctx, taskCompletionChannel, string(messageJSON))
+	err = r.pubsubCompletions.Publish(ctx, backupCompletionChannel, string(messageJSON))
 	if err != nil {
-		return fmt.Errorf("failed to publish task completion message: %w", err)
+		return fmt.Errorf("failed to publish backup completion message: %w", err)
 	}

 	return nil
 }

-func (r *TaskNodesRegistry) SubscribeForTasksCompletions(
-	handler func(nodeID string, taskID uuid.UUID),
+func (r *BackupNodesRegistry) SubscribeForBackupsCompletions(
+	handler func(nodeID uuid.UUID, backupID uuid.UUID),
 ) error {
 	ctx := context.Background()

 	wrappedHandler := func(message string) {
-		var msg TaskCompletionMessage
+		var msg BackupCompletionMessage
 		if err := json.Unmarshal([]byte(message), &msg); err != nil {
-			r.logger.Warn("Failed to unmarshal task completion message", "error", err)
+			r.logger.Warn("Failed to unmarshal backup completion message", "error", err)
 			return
 		}

-		taskID, err := uuid.Parse(msg.TaskID)
-		if err != nil {
-			r.logger.Warn(
-				"Failed to parse task ID from completion message",
-				"taskId",
-				msg.TaskID,
-				"error",
-				err,
-			)
-			return
-		}
-
-		handler(msg.NodeID, taskID)
+		handler(msg.NodeID, msg.BackupID)
 	}

-	err := r.pubsubCompletions.Subscribe(ctx, taskCompletionChannel, wrappedHandler)
+	err := r.pubsubCompletions.Subscribe(ctx, backupCompletionChannel, wrappedHandler)
 	if err != nil {
-		return fmt.Errorf("failed to subscribe to task completion channel: %w", err)
+		return fmt.Errorf("failed to subscribe to backup completion channel: %w", err)
 	}

-	r.logger.Info("Subscribed to task completion channel")
+	r.logger.Info("Subscribed to backup completion channel")
 	return nil
 }

-func (r *TaskNodesRegistry) UnsubscribeForTasksCompletions() error {
+func (r *BackupNodesRegistry) UnsubscribeForBackupsCompletions() error {
 	err := r.pubsubCompletions.Close()
 	if err != nil {
-		return fmt.Errorf("failed to unsubscribe from task completion channel: %w", err)
+		return fmt.Errorf("failed to unsubscribe from backup completion channel: %w", err)
 	}

-	r.logger.Info("Unsubscribed from task completion channel")
+	r.logger.Info("Unsubscribed from backup completion channel")
 	return nil
 }
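What travels over backup:submit is simply the marshalled BackupSubmitMessage. A hedged illustration of the payload shape (the actual JSON keys depend on struct tags that are not shown in this diff):

	// Illustration only; key names are an assumption.
	msg := BackupSubmitMessage{
		NodeID:         targetNodeID,
		BackupID:       backupID,
		IsCallNotifier: false,
	}
	b, _ := json.Marshal(msg)
	// e.g. {"NodeID":"1b9e...","BackupID":"a2b4...","IsCallNotifier":false}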
-func (r *TaskNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uuid.UUID {
+func (r *BackupNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uuid.UUID {
 	nodeIDStr := strings.TrimPrefix(key, prefix)
 	nodeIDStr = strings.TrimSuffix(nodeIDStr, suffix)

@@ -495,7 +483,7 @@ func (r *TaskNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uui
 	return nodeID
 }

-func (r *TaskNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, error) {
+func (r *BackupNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, error) {
 	if len(keys) == 0 {
 		return make(map[string][]byte), nil
 	}
@@ -529,7 +517,7 @@ func (r *TaskNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, e
 	return keyDataMap, nil
 }

-func (r *TaskNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
+func (r *BackupNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
 	str := string(data)
 	var count int64
 	_, err := fmt.Sscanf(str, "%d", &count)
@@ -539,7 +527,7 @@ func (r *TaskNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
 	return count, nil
 }

-func (r *TaskNodesRegistry) cleanupDeadNodes() error {
+func (r *BackupNodesRegistry) cleanupDeadNodes() error {
 	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
 	defer cancel()

@@ -583,13 +571,12 @@ func (r *TaskNodesRegistry) cleanupDeadNodes() error {
 	var deadNodeKeys []string

 	for key, data := range keyDataMap {
-
 		// Skip if the key doesn't exist (data is empty)
 		if len(data) == 0 {
 			continue
 		}

-		var node TaskNode
+		var node BackupNode
 		if err := json.Unmarshal(data, &node); err != nil {
 			r.logger.Warn("Failed to unmarshal node data during cleanup", "key", key, "error", err)
 			continue
@@ -603,7 +590,12 @@ func (r *TaskNodesRegistry) cleanupDeadNodes() error {
 		if node.LastHeartbeat.Before(threshold) {
 			nodeID := node.ID.String()
 			infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeID, nodeInfoKeySuffix)
-			statsKey := fmt.Sprintf("%s%s%s", nodeActiveTasksPrefix, nodeID, nodeActiveTasksSuffix)
+			statsKey := fmt.Sprintf(
+				"%s%s%s",
+				nodeActiveBackupsPrefix,
+				nodeID,
+				nodeActiveBackupsSuffix,
+			)

 			deadNodeKeys = append(deadNodeKeys, infoKey, statsKey)
 			r.logger.Info(
backend/internal/features/backups/backups/backuping/registry_test.go (new file, 1134 lines; file diff suppressed because it is too large)
@@ -2,19 +2,20 @@ package backuping

 import (
 	"context"
-	"databasus-backend/internal/config"
-	backups_core "databasus-backend/internal/features/backups/backups/core"
-	backups_config "databasus-backend/internal/features/backups/config"
-	"databasus-backend/internal/features/storages"
-	task_cancellation "databasus-backend/internal/features/tasks/cancellation"
-	task_registry "databasus-backend/internal/features/tasks/registry"
-	"databasus-backend/internal/util/encryption"
-	"databasus-backend/internal/util/period"
 	"fmt"
 	"log/slog"
 	"sync"
 	"sync/atomic"
 	"time"

 	"github.com/google/uuid"

+	"databasus-backend/internal/config"
+	backups_core "databasus-backend/internal/features/backups/backups/core"
+	backups_config "databasus-backend/internal/features/backups/config"
+	"databasus-backend/internal/features/databases"
+	task_cancellation "databasus-backend/internal/features/tasks/cancellation"
+	files_utils "databasus-backend/internal/util/files"
 )
 const (
@@ -26,66 +27,77 @@ const (
 type BackupsScheduler struct {
 	backupRepository    *backups_core.BackupRepository
 	backupConfigService *backups_config.BackupConfigService
-	storageService      *storages.StorageService
 	taskCancelManager   *task_cancellation.TaskCancelManager
-	tasksRegistry       *task_registry.TaskNodesRegistry
+	backupNodesRegistry *BackupNodesRegistry
+	databaseService     *databases.DatabaseService

 	lastBackupTime time.Time
 	logger         *slog.Logger

 	backupToNodeRelations map[uuid.UUID]BackupToNodeRelation
+	backuperNode          *BackuperNode

+	runOnce sync.Once
+	hasRun  atomic.Bool
 }

 func (s *BackupsScheduler) Run(ctx context.Context) {
-	s.lastBackupTime = time.Now().UTC()
+	wasAlreadyRun := s.hasRun.Load()

-	if config.GetEnv().IsManyNodesMode {
-		// wait other nodes to start
-		time.Sleep(schedulerStartupDelay)
-	}
+	s.runOnce.Do(func() {
+		s.hasRun.Store(true)

-	if err := s.failBackupsInProgress(); err != nil {
-		s.logger.Error("Failed to fail backups in progress", "error", err)
-		panic(err)
-	}
+		s.lastBackupTime = time.Now().UTC()

-	if err := s.tasksRegistry.SubscribeForTasksCompletions(s.onBackupCompleted); err != nil {
-		s.logger.Error("Failed to subscribe to backup completions", "error", err)
-		panic(err)
-	}
-	defer func() {
-		if err := s.tasksRegistry.UnsubscribeForTasksCompletions(); err != nil {
-			s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
-		}
-	}()
+		if config.GetEnv().IsManyNodesMode {
+			// wait other nodes to start
+			time.Sleep(schedulerStartupDelay)
+		}

-	if ctx.Err() != nil {
-		return
-	}
+		if err := s.failBackupsInProgress(); err != nil {
+			s.logger.Error("Failed to fail backups in progress", "error", err)
+			panic(err)
+		}

-	ticker := time.NewTicker(schedulerTickerInterval)
-	defer ticker.Stop()
+		err := s.backupNodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted)
+		if err != nil {
+			s.logger.Error("Failed to subscribe to backup completions", "error", err)
+			panic(err)
+		}

-	for {
-		select {
-		case <-ctx.Done():
-			return
-		case <-ticker.C:
-			if err := s.cleanOldBackups(); err != nil {
-				s.logger.Error("Failed to clean old backups", "error", err)
-			}
+		defer func() {
+			if err := s.backupNodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
+				s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
+			}
+		}()

-			if err := s.checkDeadNodesAndFailBackups(); err != nil {
-				s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
-			}
+		if ctx.Err() != nil {
+			return
+		}

-			if err := s.runPendingBackups(); err != nil {
-				s.logger.Error("Failed to run pending backups", "error", err)
-			}
+		ticker := time.NewTicker(schedulerTickerInterval)
+		defer ticker.Stop()

-			s.lastBackupTime = time.Now().UTC()
-		}
-	}
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case <-ticker.C:
+				if err := s.checkDeadNodesAndFailBackups(); err != nil {
+					s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
+				}
+
+				if err := s.runPendingBackups(); err != nil {
+					s.logger.Error("Failed to run pending backups", "error", err)
+				}
+
+				s.lastBackupTime = time.Now().UTC()
+			}
+		}
+	})
+
+	if wasAlreadyRun {
+		panic(fmt.Sprintf("%T.Run() called multiple times", s))
+	}
 }
@@ -94,6 +106,255 @@ func (s *BackupsScheduler) IsSchedulerRunning() bool {
 	return s.lastBackupTime.After(time.Now().UTC().Add(-schedulerHealthcheckThreshold))
 }

+func (s *BackupsScheduler) IsBackupNodesAvailable() bool {
+	nodes, err := s.backupNodesRegistry.GetAvailableNodes()
+	if err != nil {
+		s.logger.Error("Failed to get available nodes for health check", "error", err)
+		return false
+	}
+
+	return len(nodes) > 0
+}
+
+func (s *BackupsScheduler) StartBackup(database *databases.Database, isCallNotifier bool) {
+	backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
+	if err != nil {
+		s.logger.Error("Failed to get backup config by database ID", "error", err)
+		return
+	}
+
+	if backupConfig.StorageID == nil {
+		s.logger.Error("Backup config storage ID is nil", "databaseId", database.ID)
+		return
+	}
+
+	// Check for existing in-progress backups
+	inProgressBackups, err := s.backupRepository.FindByDatabaseIdAndStatus(
+		database.ID,
+		backups_core.BackupStatusInProgress,
+	)
+	if err != nil {
+		s.logger.Error(
+			"Failed to check for in-progress backups",
+			"databaseId",
+			database.ID,
+			"error",
+			err,
+		)
+		return
+	}
+
+	if len(inProgressBackups) > 0 {
+		s.logger.Warn(
+			"Backup already in progress for database, skipping new backup",
+			"databaseId",
+			database.ID,
+			"existingBackupId",
+			inProgressBackups[0].ID,
+		)
+		return
+	}
+
+	leastBusyNodeID, err := s.calculateLeastBusyNode()
+	if err != nil {
+		s.logger.Error(
+			"Failed to calculate least busy node",
+			"databaseId",
+			backupConfig.DatabaseID,
+			"error",
+			err,
+		)
+		return
+	}
+
+	backupID := uuid.New()
+	timestamp := time.Now().UTC()
+
+	backup := &backups_core.Backup{
+		ID: backupID,
+		FileName: fmt.Sprintf(
+			"%s-%s-%s",
+			files_utils.SanitizeFilename(database.Name),
+			timestamp.Format("20060102-150405"),
+			backupID.String(),
+		),
+		DatabaseID:   backupConfig.DatabaseID,
+		StorageID:    *backupConfig.StorageID,
+		Status:       backups_core.BackupStatusInProgress,
+		BackupSizeMb: 0,
+		CreatedAt:    timestamp,
+	}
+
+	if err := s.backupRepository.Save(backup); err != nil {
+		s.logger.Error(
+			"Failed to save backup",
+			"databaseId",
+			backupConfig.DatabaseID,
+			"error",
+			err,
+		)
+		return
+	}
+
+	if err := s.backupNodesRegistry.IncrementBackupsInProgress(*leastBusyNodeID); err != nil {
+		s.logger.Error(
+			"Failed to increment backups in progress",
+			"nodeId",
+			leastBusyNodeID,
+			"backupId",
+			backup.ID,
+			"error",
+			err,
+		)
+		return
+	}
+
+	if err := s.backupNodesRegistry.AssignBackupToNode(*leastBusyNodeID, backup.ID, isCallNotifier); err != nil {
+		s.logger.Error(
+			"Failed to submit backup",
+			"nodeId",
+			leastBusyNodeID,
+			"backupId",
+			backup.ID,
+			"error",
+			err,
+		)
+		if decrementErr := s.backupNodesRegistry.DecrementBackupsInProgress(*leastBusyNodeID); decrementErr != nil {
+			s.logger.Error(
+				"Failed to decrement backups in progress after submit failure",
+				"nodeId",
+				leastBusyNodeID,
+				"error",
+				decrementErr,
+			)
+		}
+		return
+	}
+
+	if relation, exists := s.backupToNodeRelations[*leastBusyNodeID]; exists {
+		relation.BackupsIDs = append(relation.BackupsIDs, backup.ID)
+		s.backupToNodeRelations[*leastBusyNodeID] = relation
+	} else {
+		s.backupToNodeRelations[*leastBusyNodeID] = BackupToNodeRelation{
+			*leastBusyNodeID,
+			[]uuid.UUID{backup.ID},
+		}
+	}
+
+	s.logger.Info(
+		"Successfully triggered scheduled backup",
+		"databaseId",
+		backupConfig.DatabaseID,
+		"backupId",
+		backup.ID,
+		"nodeId",
+		leastBusyNodeID,
+	)
+}
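The FileName built above concatenates the sanitized database name, a second-resolution UTC timestamp, and the backup UUID. A hedged illustration with made-up values (the exact behavior of files_utils.SanitizeFilename is not shown in this diff):

	fmt.Sprintf("%s-%s-%s", "ordersdb", "20250102-030405", "1b9e4c62-0d1f-4f7a-9d7e-2c9f3a6b5e10")
	// => "ordersdb-20250102-030405-1b9e4c62-0d1f-4f7a-9d7e-2c9f3a6b5e10"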
+
+// GetRemainedBackupTryCount returns the number of remaining backup tries for a given backup.
+// If the backup is not failed, is marked to skip retries, or the backup config does not
+// allow retries, it returns 0; otherwise it returns the number of remaining tries.
+func (s *BackupsScheduler) GetRemainedBackupTryCount(lastBackup *backups_core.Backup) int {
+	if lastBackup == nil {
+		return 0
+	}
+
+	if lastBackup.Status != backups_core.BackupStatusFailed {
+		return 0
+	}
+
+	if lastBackup.IsSkipRetry {
+		return 0
+	}
+
+	backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(lastBackup.DatabaseID)
+	if err != nil {
+		s.logger.Error("Failed to get backup config by database ID", "error", err)
+		return 0
+	}
+
+	if !backupConfig.IsRetryIfFailed {
+		return 0
+	}
+
+	maxFailedTriesCount := backupConfig.MaxFailedTriesCount
+
+	lastBackups, err := s.backupRepository.FindByDatabaseIDWithLimit(
+		lastBackup.DatabaseID,
+		maxFailedTriesCount,
+	)
+	if err != nil {
+		s.logger.Error("Failed to find last backups by database ID", "error", err)
+		return 0
+	}
+
+	lastFailedBackups := make([]*backups_core.Backup, 0)
+
+	for _, backup := range lastBackups {
+		if backup.Status == backups_core.BackupStatusFailed {
+			lastFailedBackups = append(lastFailedBackups, backup)
+		}
+	}
+
+	return maxFailedTriesCount - len(lastFailedBackups)
+}
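A quick worked example of the arithmetic above: with IsRetryIfFailed enabled and MaxFailedTriesCount = 3, if two of the last three backups failed, the function returns 3 - 2 = 1; runPendingBackups below then retries once more and passes isCallNotifier = true on that final attempt (remainedBackupTryCount == 1).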
+
+func (s *BackupsScheduler) runPendingBackups() error {
+	enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
+	if err != nil {
+		return err
+	}
+
+	for _, backupConfig := range enabledBackupConfigs {
+		if backupConfig.BackupInterval == nil {
+			continue
+		}
+
+		lastBackup, err := s.backupRepository.FindLastByDatabaseID(backupConfig.DatabaseID)
+		if err != nil {
+			s.logger.Error(
+				"Failed to get last backup for database",
+				"databaseId",
+				backupConfig.DatabaseID,
+				"error",
+				err,
+			)
+			continue
+		}
+
+		var lastBackupTime *time.Time
+		if lastBackup != nil {
+			lastBackupTime = &lastBackup.CreatedAt
+		}
+
+		remainedBackupTryCount := s.GetRemainedBackupTryCount(lastBackup)
+
+		if backupConfig.BackupInterval.ShouldTriggerBackup(time.Now().UTC(), lastBackupTime) ||
+			remainedBackupTryCount > 0 {
+			s.logger.Info(
+				"Triggering scheduled backup",
+				"databaseId",
+				backupConfig.DatabaseID,
+				"intervalType",
+				backupConfig.BackupInterval.Interval,
+			)
+
+			database, err := s.databaseService.GetDatabaseByID(backupConfig.DatabaseID)
+			if err != nil {
+				s.logger.Error("Failed to get database by ID", "error", err)
+				continue
+			}
+
+			s.StartBackup(database, remainedBackupTryCount == 1)
+			continue
+		}
+	}
+
+	return nil
+}

 func (s *BackupsScheduler) failBackupsInProgress() error {
 	backupsInProgress, err := s.backupRepository.FindByStatus(backups_core.BackupStatusInProgress)
 	if err != nil {
@@ -137,268 +398,8 @@ func (s *BackupsScheduler) failBackupsInProgress() error {
 	return nil
 }
-func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool) {
-	backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(databaseID)
-	if err != nil {
-		s.logger.Error("Failed to get backup config by database ID", "error", err)
-		return
-	}
-
-	if backupConfig.StorageID == nil {
-		s.logger.Error("Backup config storage ID is nil", "databaseId", databaseID)
-		return
-	}
-
-	leastBusyNodeID, err := s.calculateLeastBusyNode()
-	if err != nil {
-		s.logger.Error(
-			"Failed to calculate least busy node",
-			"databaseId",
-			backupConfig.DatabaseID,
-			"error",
-			err,
-		)
-		return
-	}
-
-	backup := &backups_core.Backup{
-		DatabaseID:   backupConfig.DatabaseID,
-		StorageID:    *backupConfig.StorageID,
-		Status:       backups_core.BackupStatusInProgress,
-		BackupSizeMb: 0,
-		CreatedAt:    time.Now().UTC(),
-	}
-
-	if err := s.backupRepository.Save(backup); err != nil {
-		s.logger.Error(
-			"Failed to save backup",
-			"databaseId",
-			backupConfig.DatabaseID,
-			"error",
-			err,
-		)
-		return
-	}
-
-	if err := s.tasksRegistry.IncrementTasksInProgress(leastBusyNodeID.String()); err != nil {
-		s.logger.Error(
-			"Failed to increment backups in progress",
-			"nodeId",
-			leastBusyNodeID,
-			"backupId",
-			backup.ID,
-			"error",
-			err,
-		)
-		return
-	}
-
-	if err := s.tasksRegistry.AssignTaskToNode(leastBusyNodeID.String(), backup.ID, isCallNotifier); err != nil {
-		s.logger.Error(
-			"Failed to submit backup",
-			"nodeId",
-			leastBusyNodeID,
-			"backupId",
-			backup.ID,
-			"error",
-			err,
-		)
-		if decrementErr := s.tasksRegistry.DecrementTasksInProgress(leastBusyNodeID.String()); decrementErr != nil {
-			s.logger.Error(
-				"Failed to decrement backups in progress after submit failure",
-				"nodeId",
-				leastBusyNodeID,
-				"error",
-				decrementErr,
-			)
-		}
-		return
-	}
-
-	if relation, exists := s.backupToNodeRelations[*leastBusyNodeID]; exists {
-		relation.BackupsIDs = append(relation.BackupsIDs, backup.ID)
-		s.backupToNodeRelations[*leastBusyNodeID] = relation
-	} else {
-		s.backupToNodeRelations[*leastBusyNodeID] = BackupToNodeRelation{
-			NodeID:     *leastBusyNodeID,
-			BackupsIDs: []uuid.UUID{backup.ID},
-		}
-	}
-
-	s.logger.Info(
-		"Successfully triggered scheduled backup",
-		"databaseId",
-		backupConfig.DatabaseID,
-		"backupId",
-		backup.ID,
-		"nodeId",
-		leastBusyNodeID,
-	)
-}
-
-// GetRemainedBackupTryCount returns the number of remaining backup tries for a given backup.
-// If the backup is not failed or the backup config does not allow retries, it returns 0.
-// If the backup is failed and the backup config allows retries, it returns the number of remaining tries.
-// If the backup is failed and the backup config does not allow retries, it returns 0.
-func (s *BackupsScheduler) GetRemainedBackupTryCount(lastBackup *backups_core.Backup) int {
-	if lastBackup == nil {
-		return 0
-	}
-
-	if lastBackup.Status != backups_core.BackupStatusFailed {
-		return 0
-	}
-
-	backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(lastBackup.DatabaseID)
-	if err != nil {
-		s.logger.Error("Failed to get backup config by database ID", "error", err)
-		return 0
-	}
-
-	if !backupConfig.IsRetryIfFailed {
-		return 0
-	}
-
-	maxFailedTriesCount := backupConfig.MaxFailedTriesCount
-
-	lastBackups, err := s.backupRepository.FindByDatabaseIDWithLimit(
-		lastBackup.DatabaseID,
-		maxFailedTriesCount,
-	)
-	if err != nil {
-		s.logger.Error("Failed to find last backups by database ID", "error", err)
-		return 0
-	}
-
-	lastFailedBackups := make([]*backups_core.Backup, 0)
-
-	for _, backup := range lastBackups {
-		if backup.Status == backups_core.BackupStatusFailed {
-			lastFailedBackups = append(lastFailedBackups, backup)
-		}
-	}
-
-	return maxFailedTriesCount - len(lastFailedBackups)
-}
-
-func (s *BackupsScheduler) cleanOldBackups() error {
-	enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
-	if err != nil {
-		return err
-	}
-
-	for _, backupConfig := range enabledBackupConfigs {
-		backupStorePeriod := backupConfig.StorePeriod
-
-		if backupStorePeriod == period.PeriodForever {
-			continue
-		}
-
-		storeDuration := backupStorePeriod.ToDuration()
-		dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
-
-		oldBackups, err := s.backupRepository.FindBackupsBeforeDate(
-			backupConfig.DatabaseID,
-			dateBeforeBackupsShouldBeDeleted,
-		)
-		if err != nil {
-			s.logger.Error(
-				"Failed to find old backups for database",
-				"databaseId",
-				backupConfig.DatabaseID,
-				"error",
-				err,
-			)
-			continue
-		}
-
-		for _, backup := range oldBackups {
-			storage, err := s.storageService.GetStorageByID(backup.StorageID)
-			if err != nil {
-				s.logger.Error(
-					"Failed to get storage by ID",
-					"storageId",
-					backup.StorageID,
-					"error",
-					err,
-				)
-				continue
-			}
-
-			encryptor := encryption.GetFieldEncryptor()
-			err = storage.DeleteFile(encryptor, backup.ID)
-			if err != nil {
-				s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
-			}
-
-			if err := s.backupRepository.DeleteByID(backup.ID); err != nil {
-				s.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
-				continue
-			}
-
-			s.logger.Info(
-				"Deleted old backup",
-				"backupId",
-				backup.ID,
-				"databaseId",
-				backupConfig.DatabaseID,
-			)
-		}
-	}
-
-	return nil
-}
-
-func (s *BackupsScheduler) runPendingBackups() error {
-	enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
-	if err != nil {
-		return err
-	}
-
-	for _, backupConfig := range enabledBackupConfigs {
-		if backupConfig.BackupInterval == nil {
-			continue
-		}
-
-		lastBackup, err := s.backupRepository.FindLastByDatabaseID(backupConfig.DatabaseID)
-		if err != nil {
-			s.logger.Error(
-				"Failed to get last backup for database",
-				"databaseId",
-				backupConfig.DatabaseID,
-				"error",
-				err,
-			)
-			continue
-		}
-
-		var lastBackupTime *time.Time
-		if lastBackup != nil {
-			lastBackupTime = &lastBackup.CreatedAt
-		}
-
-		remainedBackupTryCount := s.GetRemainedBackupTryCount(lastBackup)
-
-		if backupConfig.BackupInterval.ShouldTriggerBackup(time.Now().UTC(), lastBackupTime) ||
-			remainedBackupTryCount > 0 {
-			s.logger.Info(
-				"Triggering scheduled backup",
-				"databaseId",
-				backupConfig.DatabaseID,
-				"intervalType",
-				backupConfig.BackupInterval.Interval,
-			)
-
-			s.StartBackup(backupConfig.DatabaseID, remainedBackupTryCount == 1)
-			continue
-		}
-	}
-
-	return nil
-}

 func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
-	nodes, err := s.tasksRegistry.GetAvailableNodes()
+	nodes, err := s.backupNodesRegistry.GetAvailableNodes()
 	if err != nil {
 		return nil, fmt.Errorf("failed to get available nodes: %w", err)
 	}
@@ -407,17 +408,17 @@ func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
 		return nil, fmt.Errorf("no nodes available")
 	}

-	stats, err := s.tasksRegistry.GetNodesStats()
+	stats, err := s.backupNodesRegistry.GetBackupNodesStats()
 	if err != nil {
 		return nil, fmt.Errorf("failed to get backup nodes stats: %w", err)
 	}

 	statsMap := make(map[uuid.UUID]int)
 	for _, stat := range stats {
-		statsMap[stat.ID] = stat.ActiveTasks
+		statsMap[stat.ID] = stat.ActiveBackups
 	}

-	var bestNode *task_registry.TaskNode
+	var bestNode *BackupNode
 	var bestScore float64 = -1

 	for i := range nodes {
@@ -445,21 +446,9 @@ func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
 	return &bestNode.ID, nil
 }
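The body of the selection loop is elided by this hunk, so the exact scoring formula is not visible here; the tests below register nodes distinguished by network speed (node100MBs, node50MBs) as well as by active-backup count, which suggests the score combines both. A purely hypothetical sketch, with the field name NetworkSpeedMBs invented for illustration:

	// Assumed shape only; the real formula is not shown in this diff.
	for i := range nodes {
		node := &nodes[i]
		score := node.NetworkSpeedMBs / float64(statsMap[node.ID]+1)
		if score > bestScore {
			bestScore = score
			bestNode = node
		}
	}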

-func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUID) {
-	nodeID, err := uuid.Parse(nodeIDStr)
-	if err != nil {
-		s.logger.Error(
-			"Failed to parse node ID from completion message",
-			"nodeId",
-			nodeIDStr,
-			"error",
-			err,
-		)
-		return
-	}
-
+func (s *BackupsScheduler) onBackupCompleted(nodeID uuid.UUID, backupID uuid.UUID) {
 	// Verify this task is actually a backup (registry contains multiple task types)
-	_, err = s.backupRepository.FindByID(backupID)
+	_, err := s.backupRepository.FindByID(backupID)
 	if err != nil {
 		// Not a backup task, ignore it
 		return
@@ -505,7 +494,7 @@ func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUI
 		s.backupToNodeRelations[nodeID] = relation
 	}

-	if err := s.tasksRegistry.DecrementTasksInProgress(nodeIDStr); err != nil {
+	if err := s.backupNodesRegistry.DecrementBackupsInProgress(nodeID); err != nil {
 		s.logger.Error(
 			"Failed to decrement backups in progress",
 			"nodeId",
@@ -519,7 +508,7 @@ func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUI
 }

 func (s *BackupsScheduler) checkDeadNodesAndFailBackups() error {
-	nodes, err := s.tasksRegistry.GetAvailableNodes()
+	nodes, err := s.backupNodesRegistry.GetAvailableNodes()
 	if err != nil {
 		return fmt.Errorf("failed to get available nodes: %w", err)
 	}
@@ -575,7 +564,7 @@ func (s *BackupsScheduler) checkDeadNodesAndFailBackups() error {
 			continue
 		}

-		if err := s.tasksRegistry.DecrementTasksInProgress(nodeID.String()); err != nil {
+		if err := s.backupNodesRegistry.DecrementBackupsInProgress(nodeID); err != nil {
 			s.logger.Error(
 				"Failed to decrement backups in progress for dead node",
 				"nodeId",

@@ -7,7 +7,6 @@ import (
 	"databasus-backend/internal/features/intervals"
 	"databasus-backend/internal/features/notifiers"
 	"databasus-backend/internal/features/storages"
-	task_registry "databasus-backend/internal/features/tasks/registry"
 	users_enums "databasus-backend/internal/features/users/enums"
 	users_testing "databasus-backend/internal/features/users/testing"
 	workspaces_testing "databasus-backend/internal/features/workspaces/testing"
@@ -58,7 +57,8 @@ func Test_RunPendingBackups_WhenLastBackupWasYesterday_CreatesNewBackup(t *testi
 		TimeOfDay: &timeOfDay,
 	}
 	backupConfig.IsBackupsEnabled = true
-	backupConfig.StorePeriod = period.PeriodWeek
+	backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
+	backupConfig.RetentionTimePeriod = period.PeriodWeek
 	backupConfig.Storage = storage
 	backupConfig.StorageID = &storage.ID

@@ -127,7 +127,8 @@ func Test_RunPendingBackups_WhenLastBackupWasRecentlyCompleted_SkipsBackup(t *te
 		TimeOfDay: &timeOfDay,
 	}
 	backupConfig.IsBackupsEnabled = true
-	backupConfig.StorePeriod = period.PeriodWeek
+	backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
+	backupConfig.RetentionTimePeriod = period.PeriodWeek
 	backupConfig.Storage = storage
 	backupConfig.StorageID = &storage.ID

@@ -195,7 +196,8 @@ func Test_RunPendingBackups_WhenLastBackupFailedAndRetriesDisabled_SkipsBackup(t
 		TimeOfDay: &timeOfDay,
 	}
 	backupConfig.IsBackupsEnabled = true
-	backupConfig.StorePeriod = period.PeriodWeek
+	backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
+	backupConfig.RetentionTimePeriod = period.PeriodWeek
 	backupConfig.Storage = storage
 	backupConfig.StorageID = &storage.ID
 	backupConfig.IsRetryIfFailed = false
@@ -267,7 +269,8 @@ func Test_RunPendingBackups_WhenLastBackupFailedAndRetriesEnabled_CreatesNewBack
 		TimeOfDay: &timeOfDay,
 	}
 	backupConfig.IsBackupsEnabled = true
-	backupConfig.StorePeriod = period.PeriodWeek
+	backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
+	backupConfig.RetentionTimePeriod = period.PeriodWeek
 	backupConfig.Storage = storage
 	backupConfig.StorageID = &storage.ID
 	backupConfig.IsRetryIfFailed = true
@@ -340,7 +343,8 @@ func Test_RunPendingBackups_WhenFailedBackupsExceedMaxRetries_SkipsBackup(t *tes
 		TimeOfDay: &timeOfDay,
 	}
 	backupConfig.IsBackupsEnabled = true
-	backupConfig.StorePeriod = period.PeriodWeek
+	backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
+	backupConfig.RetentionTimePeriod = period.PeriodWeek
 	backupConfig.Storage = storage
 	backupConfig.StorageID = &storage.ID
 	backupConfig.IsRetryIfFailed = true
@@ -411,7 +415,8 @@ func Test_RunPendingBackups_WhenBackupsDisabled_SkipsBackup(t *testing.T) {
 		TimeOfDay: &timeOfDay,
 	}
 	backupConfig.IsBackupsEnabled = false
-	backupConfig.StorePeriod = period.PeriodWeek
+	backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
+	backupConfig.RetentionTimePeriod = period.PeriodWeek
 	backupConfig.Storage = storage
 	backupConfig.StorageID = &storage.ID

@@ -466,7 +471,7 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist

 	// Clean up mock node
 	if mockNodeID != uuid.Nil {
-		nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: mockNodeID})
+		backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: mockNodeID})
 	}
 	cache_utils.ClearAllCache()
 }()
@@ -480,7 +485,8 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
 		TimeOfDay: &timeOfDay,
 	}
 	backupConfig.IsBackupsEnabled = true
-	backupConfig.StorePeriod = period.PeriodWeek
+	backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
+	backupConfig.RetentionTimePeriod = period.PeriodWeek
 	backupConfig.Storage = storage
 	backupConfig.StorageID = &storage.ID

@@ -493,7 +499,7 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
 	assert.NoError(t, err)

 	// Scheduler assigns backup to mock node
-	GetBackupsScheduler().StartBackup(database.ID, false)
+	GetBackupsScheduler().StartBackup(database, false)
 	time.Sleep(100 * time.Millisecond)

 	backups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -502,12 +508,12 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
 	assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status)

 	// Verify Valkey counter was incremented when backup was assigned
-	stats, err := nodesRegistry.GetNodesStats()
+	stats, err := backupNodesRegistry.GetBackupNodesStats()
 	assert.NoError(t, err)
 	foundStat := false
 	for _, stat := range stats {
 		if stat.ID == mockNodeID {
-			assert.Equal(t, 1, stat.ActiveTasks)
+			assert.Equal(t, 1, stat.ActiveBackups)
 			foundStat = true
 			break
 		}
@@ -532,11 +538,11 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
 	assert.Contains(t, *backups[0].FailMessage, "node unavailability")

 	// Verify Valkey counter was decremented after backup failed
-	stats, err = nodesRegistry.GetNodesStats()
+	stats, err = backupNodesRegistry.GetBackupNodesStats()
 	assert.NoError(t, err)
 	for _, stat := range stats {
 		if stat.ID == mockNodeID {
-			assert.Equal(t, 0, stat.ActiveTasks)
+			assert.Equal(t, 0, stat.ActiveBackups)
 		}
 	}

@@ -569,7 +575,7 @@ func Test_OnBackupCompleted_WhenTaskIsNotBackup_SkipsProcessing(t *testing.T) {

 	// Clean up mock node
 	if mockNodeID != uuid.Nil {
-		nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: mockNodeID})
+		backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: mockNodeID})
 	}
 	cache_utils.ClearAllCache()
 }()
@@ -583,7 +589,8 @@ func Test_OnBackupCompleted_WhenTaskIsNotBackup_SkipsProcessing(t *testing.T) {
 		TimeOfDay: &timeOfDay,
 	}
 	backupConfig.IsBackupsEnabled = true
-	backupConfig.StorePeriod = period.PeriodWeek
+	backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
+	backupConfig.RetentionTimePeriod = period.PeriodWeek
 	backupConfig.Storage = storage
 	backupConfig.StorageID = &storage.ID

@@ -596,7 +603,7 @@ func Test_OnBackupCompleted_WhenTaskIsNotBackup_SkipsProcessing(t *testing.T) {
 	assert.NoError(t, err)

 	// Start a backup and assign it to the node
-	GetBackupsScheduler().StartBackup(database.ID, false)
+	GetBackupsScheduler().StartBackup(database, false)
 	time.Sleep(100 * time.Millisecond)

 	backups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -605,12 +612,12 @@ func Test_OnBackupCompleted_WhenTaskIsNotBackup_SkipsProcessing(t *testing.T) {
 	assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status)

 	// Get initial state of the registry
-	initialStats, err := nodesRegistry.GetNodesStats()
+	initialStats, err := backupNodesRegistry.GetBackupNodesStats()
 	assert.NoError(t, err)
 	var initialActiveTasks int
 	for _, stat := range initialStats {
 		if stat.ID == mockNodeID {
-			initialActiveTasks = stat.ActiveTasks
+			initialActiveTasks = stat.ActiveBackups
 			break
 		}
 	}
@@ -618,16 +625,16 @@ func Test_OnBackupCompleted_WhenTaskIsNotBackup_SkipsProcessing(t *testing.T) {

 	// Call onBackupCompleted with a random UUID (not a backup ID)
 	nonBackupTaskID := uuid.New()
-	GetBackupsScheduler().onBackupCompleted(mockNodeID.String(), nonBackupTaskID)
+	GetBackupsScheduler().onBackupCompleted(mockNodeID, nonBackupTaskID)

 	time.Sleep(100 * time.Millisecond)

 	// Verify: Active tasks counter should remain the same (not decremented)
-	stats, err := nodesRegistry.GetNodesStats()
+	stats, err := backupNodesRegistry.GetBackupNodesStats()
 	assert.NoError(t, err)
 	for _, stat := range stats {
 		if stat.ID == mockNodeID {
-			assert.Equal(t, initialActiveTasks, stat.ActiveTasks,
+			assert.Equal(t, initialActiveTasks, stat.ActiveBackups,
 				"Active tasks should not change for non-backup task")
 		}
 	}
@@ -658,9 +665,9 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {

 	defer func() {
 		// Clean up all mock nodes
-		nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node1ID})
-		nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node2ID})
-		nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node3ID})
+		backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node1ID})
+		backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node2ID})
+		backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node3ID})
 		cache_utils.ClearAllCache()
 	}()

@@ -672,17 +679,17 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
 	assert.NoError(t, err)

 	for range 5 {
-		err = nodesRegistry.IncrementTasksInProgress(node1ID.String())
+		err = backupNodesRegistry.IncrementBackupsInProgress(node1ID)
 		assert.NoError(t, err)
 	}

 	for range 2 {
-		err = nodesRegistry.IncrementTasksInProgress(node2ID.String())
+		err = backupNodesRegistry.IncrementBackupsInProgress(node2ID)
 		assert.NoError(t, err)
 	}

 	for range 8 {
-		err = nodesRegistry.IncrementTasksInProgress(node3ID.String())
+		err = backupNodesRegistry.IncrementBackupsInProgress(node3ID)
 		assert.NoError(t, err)
 	}

@@ -701,8 +708,8 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {

 	defer func() {
 		// Clean up all mock nodes
-		nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node100MBsID})
-		nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node50MBsID})
+		backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node100MBsID})
+		backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node50MBsID})
 		cache_utils.ClearAllCache()
 	}()

@@ -712,11 +719,11 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
 	assert.NoError(t, err)

 	for range 10 {
-		err = nodesRegistry.IncrementTasksInProgress(node100MBsID.String())
+		err = backupNodesRegistry.IncrementBackupsInProgress(node100MBsID)
 		assert.NoError(t, err)
 	}

-	err = nodesRegistry.IncrementTasksInProgress(node50MBsID.String())
+	err = backupNodesRegistry.IncrementBackupsInProgress(node50MBsID)
 	assert.NoError(t, err)

 	leastBusyNodeID, err := GetBackupsScheduler().calculateLeastBusyNode()
@@ -760,7 +767,8 @@ func Test_FailBackupsInProgress_WhenSchedulerStarts_CancelsBackupsAndUpdatesStat
 		TimeOfDay: &timeOfDay,
 	}
 	backupConfig.IsBackupsEnabled = true
-	backupConfig.StorePeriod = period.PeriodWeek
+	backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
+	backupConfig.RetentionTimePeriod = period.PeriodWeek
 	backupConfig.Storage = storage
 	backupConfig.StorageID = &storage.ID

@@ -836,7 +844,8 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T
 	cache_utils.ClearAllCache()

 	// Start scheduler so it can handle task completions
-	schedulerCancel := StartSchedulerForTest(t)
+	scheduler := CreateTestScheduler()
+	schedulerCancel := StartSchedulerForTest(t, scheduler)
 	defer schedulerCancel()

 	backuperNode := CreateTestBackuperNode()
@@ -872,7 +881,8 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T
 		TimeOfDay: &timeOfDay,
 	}
 	backupConfig.IsBackupsEnabled = true
-	backupConfig.StorePeriod = period.PeriodWeek
+	backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
+	backupConfig.RetentionTimePeriod = period.PeriodWeek
 	backupConfig.Storage = storage
 	backupConfig.StorageID = &storage.ID

@@ -880,19 +890,19 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T
 	assert.NoError(t, err)

 	// Get initial active task count
-	stats, err := nodesRegistry.GetNodesStats()
+	stats, err := backupNodesRegistry.GetBackupNodesStats()
 	assert.NoError(t, err)
 	var initialActiveTasks int
 	for _, stat := range stats {
 		if stat.ID == backuperNode.nodeID {
-			initialActiveTasks = stat.ActiveTasks
+			initialActiveTasks = stat.ActiveBackups
 			break
 		}
 	}
 	t.Logf("Initial active tasks: %d", initialActiveTasks)

 	// Start backup
-	GetBackupsScheduler().StartBackup(database.ID, false)
+	scheduler.StartBackup(database, false)

 	// Wait for backup to complete
 	WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
@@ -913,12 +923,12 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T
 	assert.True(t, decreased, "Active task count should have decreased after backup completion")

 	// Verify final active task count equals initial count
-	finalStats, err := nodesRegistry.GetNodesStats()
+	finalStats, err := backupNodesRegistry.GetBackupNodesStats()
 	assert.NoError(t, err)
 	for _, stat := range finalStats {
 		if stat.ID == backuperNode.nodeID {
-			t.Logf("Final active tasks: %d", stat.ActiveTasks)
-			assert.Equal(t, initialActiveTasks, stat.ActiveTasks,
+			t.Logf("Final active tasks: %d", stat.ActiveBackups)
+			assert.Equal(t, initialActiveTasks, stat.ActiveBackups,
 				"Active task count should return to initial value after backup completion")
 			break
 		}
@@ -931,7 +941,8 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
 	cache_utils.ClearAllCache()

 	// Start scheduler so it can handle task completions
-	schedulerCancel := StartSchedulerForTest(t)
+	scheduler := CreateTestScheduler()
+	schedulerCancel := StartSchedulerForTest(t, scheduler)
 	defer schedulerCancel()

 	backuperNode := CreateTestBackuperNode()
@@ -974,7 +985,8 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
 		TimeOfDay: &timeOfDay,
 	}
 	backupConfig.IsBackupsEnabled = true
-	backupConfig.StorePeriod = period.PeriodWeek
+	backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
+	backupConfig.RetentionTimePeriod = period.PeriodWeek
 	backupConfig.Storage = storage
 	backupConfig.StorageID = &storage.ID

@@ -982,19 +994,19 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
 	assert.NoError(t, err)

 	// Get initial active task count
-	stats, err := nodesRegistry.GetNodesStats()
+	stats, err := backupNodesRegistry.GetBackupNodesStats()
 	assert.NoError(t, err)
 	var initialActiveTasks int
 	for _, stat := range stats {
 		if stat.ID == backuperNode.nodeID {
-			initialActiveTasks = stat.ActiveTasks
+			initialActiveTasks = stat.ActiveBackups
 			break
 		}
 	}
 	t.Logf("Initial active tasks: %d", initialActiveTasks)

 	// Start backup
-	GetBackupsScheduler().StartBackup(database.ID, false)
+	scheduler.StartBackup(database, false)

 	// Wait for backup to fail
 	WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
@@ -1019,12 +1031,12 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
 	assert.True(t, decreased, "Active task count should have decreased after backup failure")

 	// Verify final active task count equals initial count
-	finalStats, err := nodesRegistry.GetNodesStats()
+	finalStats, err := backupNodesRegistry.GetBackupNodesStats()
 	assert.NoError(t, err)
 	for _, stat := range finalStats {
 		if stat.ID == backuperNode.nodeID {
-			t.Logf("Final active tasks: %d", stat.ActiveTasks)
-			assert.Equal(t, initialActiveTasks, stat.ActiveTasks,
+			t.Logf("Final active tasks: %d", stat.ActiveBackups)
+			assert.Equal(t, initialActiveTasks, stat.ActiveBackups,
 				"Active task count should return to initial value after backup failure")
 			break
 		}
@@ -1032,3 +1044,293 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {

 	time.Sleep(200 * time.Millisecond)
 }
+
+func Test_StartBackup_WhenBackupAlreadyInProgress_SkipsNewBackup(t *testing.T) {
+	cache_utils.ClearAllCache()
+	backuperNode := CreateTestBackuperNode()
+	cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
|
||||
backupConfig.RetentionTimePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create an in-progress backup manually
|
||||
inProgressBackup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
BackupSizeMb: 0,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err = backupRepository.Save(inProgressBackup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Try to start a new backup - should be skipped
|
||||
GetBackupsScheduler().StartBackup(database, false)
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Verify only 1 backup exists (the original in-progress one)
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1)
|
||||
assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status)
|
||||
assert.Equal(t, inProgressBackup.ID, backups[0].ID)
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_RunPendingBackups_WhenLastBackupFailedWithIsSkipRetry_SkipsBackupEvenWithRetriesEnabled(
|
||||
t *testing.T,
|
||||
) {
|
||||
cache_utils.ClearAllCache()
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Enable backups with retries enabled and high retry count
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
|
||||
backupConfig.RetentionTimePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
backupConfig.IsRetryIfFailed = true
|
||||
backupConfig.MaxFailedTriesCount = 5
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create a failed backup with IsSkipRetry set to true
|
||||
failMessage := "backup failed due to size limit exceeded"
|
||||
backupRepository.Save(&backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: backups_core.BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
IsSkipRetry: true,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
|
||||
// Verify GetRemainedBackupTryCount returns 0 even though retries are enabled
|
||||
lastBackup, err := backupRepository.FindLastByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, lastBackup)
|
||||
|
||||
remainedTries := GetBackupsScheduler().GetRemainedBackupTryCount(lastBackup)
|
||||
assert.Equal(t, 0, remainedTries, "Should return 0 tries when IsSkipRetry is true")
|
||||
|
||||
// Run the scheduler
|
||||
GetBackupsScheduler().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify no new backup was created (still only 1 backup exists)
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1, "No retry should be attempted when IsSkipRetry is true")
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
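The test above pins down the retry contract: a failed backup flagged IsSkipRetry is never retried, no matter what IsRetryIfFailed and MaxFailedTriesCount say. A minimal sketch of the rule the assertions imply; GetRemainedBackupTryCount's real body is not part of this diff, and the failedSoFar parameter is an assumption:

// Sketch of the retry rule implied by the assertions above, not the real implementation.
func remainedBackupTries(last *backups_core.Backup, cfg *backups_config.BackupConfig, failedSoFar int) int {
	if last.IsSkipRetry {
		return 0 // permanent failures (e.g. size limit exceeded) are never retried
	}
	if !cfg.IsRetryIfFailed {
		return 0
	}
	if remaining := cfg.MaxFailedTriesCount - failedSoFar; remaining > 0 {
		return remaining
	}
	return 0
}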
func Test_StartBackup_When2BackupsStartedForDifferentDatabases_BothUseCasesAreCalled(t *testing.T) {
cache_utils.ClearAllCache()

// Create mock tracking use case
mockUseCase := NewMockTrackingBackupUsecase()

// Create BackuperNode with mock use case
backuperNode := CreateTestBackuperNodeWithUseCase(mockUseCase)
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)

// Create scheduler
scheduler := CreateTestScheduler()
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()

// Setup test data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)

// Create 2 separate databases
database1 := databases.CreateTestDatabase(workspace.ID, storage, notifier)
database2 := databases.CreateTestDatabase(workspace.ID, storage, notifier)

defer func() {
// Cleanup backups for database1
backups1, _ := backupRepository.FindByDatabaseID(database1.ID)
for _, backup := range backups1 {
backupRepository.DeleteByID(backup.ID)
}

// Cleanup backups for database2
backups2, _ := backupRepository.FindByDatabaseID(database2.ID)
for _, backup := range backups2 {
backupRepository.DeleteByID(backup.ID)
}

databases.RemoveTestDatabase(database1)
databases.RemoveTestDatabase(database2)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()

// Enable backups for database1
backupConfig1, err := backups_config.GetBackupConfigService().
GetBackupConfigByDbId(database1.ID)
assert.NoError(t, err)

timeOfDay := "04:00"
backupConfig1.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig1.IsBackupsEnabled = true
backupConfig1.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig1.RetentionTimePeriod = period.PeriodWeek
backupConfig1.Storage = storage
backupConfig1.StorageID = &storage.ID

_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig1)
assert.NoError(t, err)

// Enable backups for database2
backupConfig2, err := backups_config.GetBackupConfigService().
GetBackupConfigByDbId(database2.ID)
assert.NoError(t, err)

backupConfig2.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig2.IsBackupsEnabled = true
backupConfig2.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig2.RetentionTimePeriod = period.PeriodWeek
backupConfig2.Storage = storage
backupConfig2.StorageID = &storage.ID

_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig2)
assert.NoError(t, err)

// Start 2 backups simultaneously
t.Log("Starting backup for database1")
scheduler.StartBackup(database1, false)

t.Log("Starting backup for database2")
scheduler.StartBackup(database2, false)

// Wait up to 10 seconds for both backups to complete
t.Log("Waiting for both backups to complete...")

success := assert.Eventually(t, func() bool {
callCount := mockUseCase.GetCallCount()
t.Logf("Current call count: %d/2", callCount)
return callCount == 2
}, 10*time.Second, 200*time.Millisecond, "Both use cases should be called within 10 seconds")

if !success {
t.Logf("Test failed: Only %d out of 2 use cases were called", mockUseCase.GetCallCount())
}

// Verify both backup IDs were received
calledBackupIDs := mockUseCase.GetCalledBackupIDs()
t.Logf("Called backup IDs: %v", calledBackupIDs)
assert.Len(t, calledBackupIDs, 2, "Both backup IDs should be tracked")

// Verify both backups exist in repository and are completed
backups1, err := backupRepository.FindByDatabaseID(database1.ID)
assert.NoError(t, err)
assert.Len(t, backups1, 1, "Database1 should have 1 backup")
if len(backups1) > 0 {
t.Logf("Database1 backup status: %s", backups1[0].Status)
}

backups2, err := backupRepository.FindByDatabaseID(database2.ID)
assert.NoError(t, err)
assert.Len(t, backups2, 1, "Database2 should have 1 backup")
if len(backups2) > 0 {
t.Logf("Database2 backup status: %s", backups2[0].Status)
}

// Verify both backups completed successfully
if len(backups1) > 0 {
assert.Equal(t, backups_core.BackupStatusCompleted, backups1[0].Status,
"Database1 backup should be completed")
}

if len(backups2) > 0 {
assert.Equal(t, backups_core.BackupStatusCompleted, backups2[0].Status,
"Database2 backup should be completed")
}

time.Sleep(200 * time.Millisecond)
}
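NewMockTrackingBackupUsecase is referenced above but not shown in this diff. One plausible shape, inferred from the GetCallCount and GetCalledBackupIDs calls and from the CreateBackupUsecase interface excerpted further down; the exact Execute signature and return type here are assumptions:

// Hypothetical sketch of the tracking mock; only GetCallCount and
// GetCalledBackupIDs are confirmed by the test above.
type MockTrackingBackupUsecase struct {
	mu        sync.Mutex
	callCount atomic.Int64
	backupIDs []uuid.UUID
}

func NewMockTrackingBackupUsecase() *MockTrackingBackupUsecase {
	return &MockTrackingBackupUsecase{}
}

// Execute records the backup ID and reports success. The real interface
// takes at least ctx, backup, backupConfig, database, and storage (see the
// CreateBackupUsecase hunk below); the error return is an assumption.
func (m *MockTrackingBackupUsecase) Execute(
	ctx context.Context,
	backup *backups_core.Backup,
	backupConfig *backups_config.BackupConfig,
	database *databases.Database,
	storage *storages.Storage,
) error {
	m.mu.Lock()
	m.backupIDs = append(m.backupIDs, backup.ID)
	m.mu.Unlock()
	m.callCount.Add(1)
	return nil
}

func (m *MockTrackingBackupUsecase) GetCallCount() int {
	return int(m.callCount.Load())
}

func (m *MockTrackingBackupUsecase) GetCalledBackupIDs() []uuid.UUID {
	m.mu.Lock()
	defer m.mu.Unlock()
	return append([]uuid.UUID(nil), m.backupIDs...)
}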
@@ -3,6 +3,8 @@ package backuping
import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"

@@ -12,7 +14,6 @@ import (
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
task_registry "databasus-backend/internal/features/tasks/registry"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_services "databasus-backend/internal/features/workspaces/services"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
@@ -36,19 +37,56 @@ func CreateTestRouter() *gin.Engine {

func CreateTestBackuperNode() *BackuperNode {
return &BackuperNode{
databases.GetDatabaseService(),
encryption.GetFieldEncryptor(),
workspaces_services.GetWorkspaceService(),
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
notifiers.GetNotifierService(),
taskCancelManager,
nodesRegistry,
logger.GetLogger(),
usecases.GetCreateBackupUsecase(),
uuid.New(),
time.Time{},
databaseService: databases.GetDatabaseService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
workspaceService: workspaces_services.GetWorkspaceService(),
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
notificationSender: notifiers.GetNotifierService(),
backupCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
logger: logger.GetLogger(),
createBackupUseCase: usecases.GetCreateBackupUsecase(),
nodeID: uuid.New(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}

func CreateTestBackuperNodeWithUseCase(useCase backups_core.CreateBackupUsecase) *BackuperNode {
return &BackuperNode{
databaseService: databases.GetDatabaseService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
workspaceService: workspaces_services.GetWorkspaceService(),
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
notificationSender: notifiers.GetNotifierService(),
backupCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
logger: logger.GetLogger(),
createBackupUseCase: useCase,
nodeID: uuid.New(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}

func CreateTestScheduler() *BackupsScheduler {
return &BackupsScheduler{
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
taskCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
lastBackupTime: time.Now().UTC(),
logger: logger.GetLogger(),
backupToNodeRelations: make(map[uuid.UUID]BackupToNodeRelation),
backuperNode: CreateTestBackuperNode(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}

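Worth noting in this hunk: the factories switch from positional to keyed struct literals. A keyed literal does not depend on the struct's field order, so adding runOnce and hasRun (or any future field) cannot silently shift later values into the wrong slots, and partial literals remain legal:

// Illustrative only: with keyed fields, unnamed fields default to zero values
// and the literal survives field reordering in BackuperNode.
node := &BackuperNode{
	nodeID:        uuid.New(),
	lastHeartbeat: time.Time{},
}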
@@ -114,7 +152,7 @@ func StartBackuperNodeForTest(t *testing.T, backuperNode *BackuperNode) context.
// Poll registry for node presence instead of fixed sleep
deadline := time.Now().UTC().Add(5 * time.Second)
for time.Now().UTC().Before(deadline) {
nodes, err := nodesRegistry.GetAvailableNodes()
nodes, err := backupNodesRegistry.GetAvailableNodes()
if err == nil {
for _, node := range nodes {
if node.ID == backuperNode.nodeID {
@@ -142,13 +180,13 @@ func StartBackuperNodeForTest(t *testing.T, backuperNode *BackuperNode) context.
// StartSchedulerForTest starts the BackupsScheduler in a goroutine for testing.
// The scheduler subscribes to task completions and manages backup lifecycle.
// Returns a context cancel function that should be deferred to stop the scheduler.
func StartSchedulerForTest(t *testing.T) context.CancelFunc {
func StartSchedulerForTest(t *testing.T, scheduler *BackupsScheduler) context.CancelFunc {
ctx, cancel := context.WithCancel(context.Background())

done := make(chan struct{})

go func() {
GetBackupsScheduler().Run(ctx)
scheduler.Run(ctx)
close(done)
}()

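StartSchedulerForTest now takes the scheduler instance instead of reaching for the process-wide GetBackupsScheduler() singleton, so each test runs exactly the instance it built. The wiring used throughout the tests above:

scheduler := CreateTestScheduler()
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()

// Backups must then be started on the same instance the goroutine is running:
scheduler.StartBackup(database, false)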
@@ -175,7 +213,7 @@ func StopBackuperNodeForTest(t *testing.T, cancel context.CancelFunc, backuperNo
// Wait for node to unregister from registry
deadline := time.Now().UTC().Add(2 * time.Second)
for time.Now().UTC().Before(deadline) {
nodes, err := nodesRegistry.GetAvailableNodes()
nodes, err := backupNodesRegistry.GetAvailableNodes()
if err == nil {
found := false
for _, node := range nodes {
@@ -196,13 +234,13 @@ ...
}

func CreateMockNodeInRegistry(nodeID uuid.UUID, throughputMBs int, lastHeartbeat time.Time) error {
backupNode := task_registry.TaskNode{
backupNode := BackupNode{
ID: nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: lastHeartbeat,
}

return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
return backupNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
}

func UpdateNodeHeartbeatDirectly(
@@ -210,17 +248,17 @@ func UpdateNodeHeartbeatDirectly(
throughputMBs int,
lastHeartbeat time.Time,
) error {
backupNode := task_registry.TaskNode{
backupNode := BackupNode{
ID: nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: lastHeartbeat,
}

return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
return backupNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
}

func GetNodeFromRegistry(nodeID uuid.UUID) (*task_registry.TaskNode, error) {
nodes, err := nodesRegistry.GetAvailableNodes()
func GetNodeFromRegistry(nodeID uuid.UUID) (*BackupNode, error) {
nodes, err := backupNodesRegistry.GetAvailableNodes()
if err != nil {
return nil, err
}
@@ -246,7 +284,7 @@ func WaitForActiveTasksDecrease(
deadline := time.Now().UTC().Add(timeout)

for time.Now().UTC().Before(deadline) {
stats, err := nodesRegistry.GetNodesStats()
stats, err := backupNodesRegistry.GetBackupNodesStats()
if err != nil {
t.Logf("WaitForActiveTasksDecrease: error getting node stats: %v", err)
time.Sleep(500 * time.Millisecond)
@@ -257,14 +295,14 @@ ...
if stat.ID == nodeID {
t.Logf(
"WaitForActiveTasksDecrease: current active tasks = %d (initial = %d)",
stat.ActiveTasks,
stat.ActiveBackups,
initialCount,
)
if stat.ActiveTasks < initialCount {
if stat.ActiveBackups < initialCount {
t.Logf(
"WaitForActiveTasksDecrease: active tasks decreased from %d to %d",
initialCount,
stat.ActiveTasks,
stat.ActiveBackups,
)
return true
}

@@ -1,17 +1,37 @@
package common

import backups_config "databasus-backend/internal/features/backups/config"
import (
backups_config "databasus-backend/internal/features/backups/config"
"errors"

type BackupType string

const (
BackupTypeDefault BackupType = "DEFAULT" // For MySQL, MongoDB, PostgreSQL legacy (-Fc)
BackupTypeDirectory BackupType = "DIRECTORY" // PostgreSQL directory type (-Fd)
"github.com/google/uuid"
)

type BackupMetadata struct {
EncryptionSalt *string
EncryptionIV *string
Encryption backups_config.BackupEncryption
Type BackupType
BackupID uuid.UUID `json:"backupId"`
EncryptionSalt *string `json:"encryptionSalt"`
EncryptionIV *string `json:"encryptionIV"`
Encryption backups_config.BackupEncryption `json:"encryption"`
}

func (m *BackupMetadata) Validate() error {
if m.BackupID == uuid.Nil {
return errors.New("backup ID is required")
}

if m.Encryption == "" {
return errors.New("encryption is required")
}

if m.Encryption == backups_config.BackupEncryptionEncrypted {
if m.EncryptionSalt == nil {
return errors.New("encryption salt is required when encryption is enabled")
}

if m.EncryptionIV == nil {
return errors.New("encryption IV is required when encryption is enabled")
}
}

return nil
}

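BackupMetadata now identifies the backup it belongs to and carries JSON tags, so the sidecar .metadata file can be checked against the backup row before a restore or download trusts it. A minimal round trip using only what the struct and Validate define; backupID, salt, and iv are placeholder variables:

// Sketch: marshal, unmarshal, and validate a metadata sidecar.
meta := BackupMetadata{
	BackupID:       backupID, // uuid.UUID of the owning backup row
	Encryption:     backups_config.BackupEncryptionEncrypted,
	EncryptionSalt: &salt, // required by Validate when encrypted
	EncryptionIV:   &iv,
}
payload, _ := json.Marshal(meta) // {"backupId":"...","encryptionSalt":"...",...}

var parsed BackupMetadata
if err := json.Unmarshal(payload, &parsed); err == nil {
	err = parsed.Validate() // rejects nil BackupID, empty Encryption, or missing salt/IV
}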
@@ -6,6 +6,7 @@ import (
backups_download "databasus-backend/internal/features/backups/backups/download"
"databasus-backend/internal/features/databases"
users_middleware "databasus-backend/internal/features/users/middleware"
files_utils "databasus-backend/internal/util/files"
"fmt"
"io"
"net/http"
@@ -304,7 +305,6 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
_, err = io.Copy(ctx.Writer, rateLimitedReader)
if err != nil {
fmt.Printf("Error streaming file: %v\n", err)
return
}

c.backupService.WriteAuditLogForDownload(downloadToken.UserID, backup, database)
@@ -322,7 +322,7 @@ func (c *BackupController) generateBackupFilename(
timestamp := backup.CreatedAt.Format("2006-01-02_15-04-05")

// Sanitize database name for filename (replace spaces and special chars)
safeName := sanitizeFilename(database.Name)
safeName := files_utils.SanitizeFilename(database.Name)

// Determine extension based on database type
extension := c.getBackupExtension(database.Type)
@@ -346,33 +346,6 @@ func (c *BackupController) getBackupExtension(
}
}

func sanitizeFilename(name string) string {
// Replace characters that are invalid in filenames
replacer := map[rune]rune{
' ': '_',
'/': '-',
'\\': '-',
':': '-',
'*': '-',
'?': '-',
'"': '-',
'<': '-',
'>': '-',
'|': '-',
}

result := make([]rune, 0, len(name))
for _, char := range name {
if replacement, exists := replacer[char]; exists {
result = append(result, replacement)
} else {
result = append(result, char)
}
}

return string(result)
}

func (c *BackupController) startDownloadHeartbeat(ctx context.Context, userID uuid.UUID) {
ticker := time.NewTicker(backups_download.GetDownloadHeartbeatInterval())
defer ticker.Stop()

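The sanitizer moved to files_utils, so behavior still follows the replacer table above: spaces map to underscores, path- and shell-hostile characters map to hyphens, everything else passes through. Expected values derived from that table:

// Examples derived from the replacer map shown above.
files_utils.SanitizeFilename("my database")  // "my_database"
files_utils.SanitizeFilename("prod/db:main") // "prod-db-main"
files_utils.SanitizeFilename("a<b>|c?")      // "a-b--c-"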
@@ -7,6 +7,8 @@ import (
"io"
"log/slog"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
@@ -18,6 +20,8 @@ import (

"databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/backuping"
backups_common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
backups_config "databasus-backend/internal/features/backups/config"
@@ -32,6 +36,7 @@ import (
workspaces_models "databasus-backend/internal/features/workspaces/models"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/util/encryption"
files_utils "databasus-backend/internal/util/files"
test_utils "databasus-backend/internal/util/testing"
"databasus-backend/internal/util/tools"
)
@@ -80,7 +85,7 @@ func Test_GetBackups_PermissionsEnforced(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)

database, _ := createTestDatabaseWithBackups(workspace, owner, router)
database, _, storage := createTestDatabaseWithBackups(workspace, owner, router)

var testUserToken string
if tt.isGlobalAdmin {
@@ -122,6 +127,12 @@ ...
} else {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}

// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -218,6 +229,10 @@ func Test_CreateBackup_PermissionsEnforced(t *testing.T) {
} else {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}

// Cleanup
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -261,6 +276,10 @@ func Test_CreateBackup_AuditLogWritten(t *testing.T) {
}
}
assert.True(t, found, "Audit log for backup creation not found")

// Cleanup
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}

func Test_DeleteBackup_PermissionsEnforced(t *testing.T) {
@@ -314,7 +333,7 @@ ...
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)

database, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)

var testUserToken string
if tt.isGlobalAdmin {
@@ -358,6 +377,12 @@ ...
assert.NoError(t, err)
assert.Equal(t, 0, len(response.Backups))
}

// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -367,7 +392,7 @@ func Test_DeleteBackup_AuditLogWritten(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)

database, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)

test_utils.MakeDeleteRequest(
t,
@@ -398,6 +423,12 @@ ...
}
}
assert.True(t, found, "Audit log for backup deletion not found")

// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}

func Test_GenerateDownloadToken_PermissionsEnforced(t *testing.T) {
@@ -444,7 +475,7 @@ ...
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)

_, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)

var testUserToken string
if tt.isGlobalAdmin {
@@ -488,6 +519,12 @@ ...
} else {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}

// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -497,7 +534,7 @@ func Test_DownloadBackup_WithValidToken_Success(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)

_, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)

// Generate download token
var tokenResponse backups_download.GenerateDownloadTokenResponse
@@ -524,6 +561,12 @@ ...
contentDisposition := testResp.Headers.Get("Content-Disposition")
assert.Contains(t, contentDisposition, "attachment")
assert.Contains(t, contentDisposition, tokenResponse.Filename)

// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}

func Test_DownloadBackup_WithoutToken_Unauthorized(t *testing.T) {
@@ -531,7 +574,7 @@ ...
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)

_, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)

// Try to download without token
testResp := test_utils.MakeGetRequest(
@@ -543,6 +586,12 @@ ...
)

assert.Contains(t, string(testResp.Body), "download token is required")

// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}

func Test_DownloadBackup_WithInvalidToken_Unauthorized(t *testing.T) {
@@ -550,7 +599,7 @@ ...
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)

_, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)

// Try to download with invalid token
testResp := test_utils.MakeGetRequest(
@@ -562,6 +611,12 @@ ...
)

assert.Contains(t, string(testResp.Body), "invalid or expired download token")

// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}

func Test_DownloadBackup_WithExpiredToken_Unauthorized(t *testing.T) {
@@ -569,7 +624,7 @@ ...
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)

database, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)

// Get user for token generation
userService := users_services.GetUserService()
@@ -611,6 +666,12 @@ ...
}
}
assert.False(t, found, "Audit log should NOT be created for failed download with expired token")

// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}

func Test_DownloadBackup_TokenUsedOnce_CannotReuseToken(t *testing.T) {
@@ -618,7 +679,7 @@ ...
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)

_, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)

// Generate download token
var tokenResponse backups_download.GenerateDownloadTokenResponse
@@ -651,6 +712,12 @@ ...
)

assert.Contains(t, string(testResp.Body), "invalid or expired download token")

// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}

func Test_DownloadBackup_WithDifferentBackupToken_Unauthorized(t *testing.T) {
@@ -705,6 +772,13 @@ ...
)

assert.Contains(t, string(testResp.Body), "invalid or expired download token")

// Cleanup
databases.RemoveTestDatabase(database1)
databases.RemoveTestDatabase(database2)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}

func Test_DownloadBackup_AuditLogWritten(t *testing.T) {
@@ -712,7 +786,7 @@ ...
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)

database, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)

// Generate download token
var tokenResponse backups_download.GenerateDownloadTokenResponse
@@ -756,6 +830,12 @@ ...
}
}
assert.True(t, found, "Audit log for backup download not found")

// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}

func Test_DownloadBackup_ProperFilenameForPostgreSQL(t *testing.T) {
@@ -856,6 +936,12 @@ ...
contentDisposition,
"Filename should contain timestamp",
)

// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -875,7 +961,7 @@ func Test_SanitizeFilename(t *testing.T) {

for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := sanitizeFilename(tt.input)
result := files_utils.SanitizeFilename(tt.input)
assert.Equal(t, tt.expected, result)
})
}
@@ -948,6 +1034,12 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
}
}
assert.True(t, foundCancelLog, "Cancel audit log should be created")

// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}

func Test_ConcurrentDownloadPrevention(t *testing.T) {
@@ -955,7 +1047,7 @@ ...
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)

database, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)

var token1Response backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
@@ -1003,6 +1095,12 @@ ...
if !service.IsDownloadInProgress(owner.UserID) {
t.Log("Warning: First download completed before we could test concurrency")
<-downloadComplete

// Cleanup before early return
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
return
}

@@ -1049,6 +1147,12 @@ ...
t.Log(
"Successfully prevented concurrent downloads and allowed subsequent downloads after completion",
)

// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}

func Test_GenerateDownloadToken_BlockedWhenDownloadInProgress(t *testing.T) {
@@ -1056,7 +1160,7 @@ ...
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)

database, backup := createTestDatabaseWithBackups(workspace, owner, router)
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)

var token1Response backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
@@ -1092,6 +1196,12 @@ ...
if !service.IsDownloadInProgress(owner.UserID) {
t.Log("Warning: First download completed before we could test token generation blocking")
<-downloadComplete

// Cleanup before early return
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
return
}

@@ -1131,6 +1241,92 @@ ...
t.Log(
"Successfully blocked token generation during download and allowed generation after completion",
)

// Cleanup
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}

func Test_MakeBackup_VerifyBackupAndMetadataFilesExistInStorage(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)

database, _, storage := createTestDatabaseWithBackups(workspace, owner, router)

backuperNode := backuping.CreateTestBackuperNode()
backuperCancel := backuping.StartBackuperNodeForTest(t, backuperNode)
defer backuping.StopBackuperNodeForTest(t, backuperCancel, backuperNode)

scheduler := backuping.CreateTestScheduler()
schedulerCancel := backuping.StartSchedulerForTest(t, scheduler)
defer schedulerCancel()

backupRepo := &backups_core.BackupRepository{}
initialBackups, err := backupRepo.FindByDatabaseID(database.ID)
assert.NoError(t, err)

request := MakeBackupRequest{DatabaseID: database.ID}
test_utils.MakePostRequest(
t,
router,
"/api/v1/backups",
"Bearer "+owner.Token,
request,
http.StatusOK,
)

backuping.WaitForBackupCompletion(t, database.ID, len(initialBackups), 30*time.Second)

backups, err := backupRepo.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Greater(t, len(backups), len(initialBackups))

backup := backups[0]
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)

storageService := storages.GetStorageService()
backupStorage, err := storageService.GetStorageByID(backup.StorageID)
assert.NoError(t, err)

encryptor := encryption.GetFieldEncryptor()

backupFile, err := backupStorage.GetFile(encryptor, backup.FileName)
assert.NoError(t, err)
backupFile.Close()

metadataFile, err := backupStorage.GetFile(encryptor, backup.FileName+".metadata")
assert.NoError(t, err)

metadataContent, err := io.ReadAll(metadataFile)
assert.NoError(t, err)
metadataFile.Close()

var storageMetadata backups_common.BackupMetadata
err = json.Unmarshal(metadataContent, &storageMetadata)
assert.NoError(t, err)

assert.Equal(t, backup.ID, storageMetadata.BackupID)

if backup.EncryptionSalt != nil && storageMetadata.EncryptionSalt != nil {
assert.Equal(t, *backup.EncryptionSalt, *storageMetadata.EncryptionSalt)
}

if backup.EncryptionIV != nil && storageMetadata.EncryptionIV != nil {
assert.Equal(t, *backup.EncryptionIV, *storageMetadata.EncryptionIV)
}

assert.Equal(t, backup.Encryption, storageMetadata.Encryption)

err = backupRepo.DeleteByID(backup.ID)
assert.NoError(t, err)

databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}

func createTestRouter() *gin.Engine {
@@ -1156,7 +1352,7 @@ func createTestDatabase(
|
||||
Type: databases.DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
@@ -1222,7 +1418,7 @@ func createTestDatabaseWithBackups(
|
||||
workspace *workspaces_models.Workspace,
|
||||
owner *users_dto.SignInResponseDTO,
|
||||
router *gin.Engine,
|
||||
) (*databases.Database, *backups_core.Backup) {
|
||||
) (*databases.Database, *backups_core.Backup, *storages.Storage) {
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
@@ -1242,7 +1438,7 @@ func createTestDatabaseWithBackups(
|
||||
|
||||
backup := createTestBackup(database, owner)
|
||||
|
||||
return database, backup
|
||||
return database, backup, storage
|
||||
}
|
||||
|
||||
func createTestBackup(
|
||||
@@ -1255,11 +1451,24 @@ func createTestBackup(
|
||||
panic(err)
|
||||
}
|
||||
|
||||
storages, err := storages.GetStorageService().GetStorages(user, *database.WorkspaceID)
|
||||
if err != nil || len(storages) == 0 {
|
||||
loadedStorages, err := storages.GetStorageService().GetStorages(user, *database.WorkspaceID)
|
||||
if err != nil || len(loadedStorages) == 0 {
|
||||
panic("No storage found for workspace")
|
||||
}
|
||||
|
||||
// Filter out system storages
|
||||
var nonSystemStorages []*storages.Storage
|
||||
for _, storage := range loadedStorages {
|
||||
if !storage.IsSystem {
|
||||
nonSystemStorages = append(nonSystemStorages, storage)
|
||||
}
|
||||
}
|
||||
if len(nonSystemStorages) == 0 {
|
||||
panic("No non-system storage found for workspace")
|
||||
}
|
||||
|
||||
storages := nonSystemStorages
|
||||
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
@@ -1283,7 +1492,7 @@ func createTestBackup(
|
||||
context.Background(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
logger,
|
||||
backup.ID,
|
||||
backup.ID.String(),
|
||||
reader,
|
||||
); err != nil {
|
||||
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
|
||||
@@ -1320,7 +1529,7 @@ func Test_BandwidthThrottling_SingleDownload_Uses75Percent(t *testing.T) {
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
database, backup, storage := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
bandwidthManager := backups_download.GetBandwidthManager()
|
||||
initialCount := bandwidthManager.GetActiveDownloadCount()
|
||||
@@ -1370,6 +1579,12 @@ func Test_BandwidthThrottling_SingleDownload_Uses75Percent(t *testing.T) {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
finalCount := bandwidthManager.GetActiveDownloadCount()
|
||||
assert.Equal(t, initialCount, finalCount, "Download should be unregistered after completion")
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_BandwidthThrottling_MultipleDownloads_ShareBandwidth(t *testing.T) {
|
||||
@@ -1489,6 +1704,12 @@ func Test_BandwidthThrottling_MultipleDownloads_ShareBandwidth(t *testing.T) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
finalCount := bandwidthManager.GetActiveDownloadCount()
|
||||
assert.Equal(t, initialCount, finalCount, "All downloads should be unregistered")
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_BandwidthThrottling_DynamicAdjustment(t *testing.T) {
|
||||
@@ -1577,4 +1798,91 @@ func Test_BandwidthThrottling_DynamicAdjustment(t *testing.T) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
finalCount := bandwidthManager.GetActiveDownloadCount()
|
||||
assert.Equal(t, initialCount, finalCount, "All downloads completed and unregistered")
|
||||
|
||||
// Cleanup
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_DeleteBackup_RemovesBackupAndMetadataFilesFromDisk(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
configService := backups_config.GetBackupConfigService()
|
||||
backupConfig, err := configService.GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorageID = &storage.ID
|
||||
backupConfig.Storage = storage
|
||||
_, err = configService.SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backuperNode := backuping.CreateTestBackuperNode()
|
||||
backuperCancel := backuping.StartBackuperNodeForTest(t, backuperNode)
|
||||
defer backuping.StopBackuperNodeForTest(t, backuperCancel, backuperNode)
|
||||
|
||||
scheduler := backuping.CreateTestScheduler()
|
||||
schedulerCancel := backuping.StartSchedulerForTest(t, scheduler)
|
||||
defer schedulerCancel()
|
||||
|
||||
backupRepo := &backups_core.BackupRepository{}
|
||||
initialBackups, err := backupRepo.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
request := MakeBackupRequest{DatabaseID: database.ID}
|
||||
test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backups",
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusOK,
|
||||
)
|
||||
|
||||
backuping.WaitForBackupCompletion(t, database.ID, len(initialBackups), 30*time.Second)
|
||||
|
||||
backups, err := backupRepo.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, len(backups), len(initialBackups))
|
||||
|
||||
backup := backups[0]
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
dataFolder := config.GetEnv().DataFolder
|
||||
backupFilePath := filepath.Join(dataFolder, backup.FileName)
|
||||
metadataFilePath := filepath.Join(dataFolder, backup.FileName+".metadata")
|
||||
|
||||
_, err = os.Stat(backupFilePath)
|
||||
assert.NoError(t, err, "backup file should exist on disk before deletion")
|
||||
|
||||
_, err = os.Stat(metadataFilePath)
|
||||
assert.NoError(t, err, "metadata file should exist on disk before deletion")
|
||||
|
||||
test_utils.MakeDeleteRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusNoContent,
|
||||
)
|
||||
|
||||
_, err = os.Stat(backupFilePath)
|
||||
assert.True(t, os.IsNotExist(err), "backup file should be removed from disk after deletion")
|
||||
|
||||
_, err = os.Stat(metadataFilePath)
|
||||
assert.True(t, os.IsNotExist(err), "metadata file should be removed from disk after deletion")
|
||||
}
|
||||
|
||||
@@ -8,8 +8,6 @@ import (
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type NotificationSender interface {
|
||||
@@ -23,7 +21,7 @@ type NotificationSender interface {
|
||||
type CreateBackupUsecase interface {
|
||||
Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
|
||||
@@ -8,13 +8,15 @@ import (
|
||||
)
|
||||
|
||||
type Backup struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
|
||||
FileName string `json:"fileName" gorm:"column:file_name;type:text;not null"`
|
||||
|
||||
DatabaseID uuid.UUID `json:"databaseId" gorm:"column:database_id;type:uuid;not null"`
|
||||
StorageID uuid.UUID `json:"storageId" gorm:"column:storage_id;type:uuid;not null"`
|
||||
|
||||
Status BackupStatus `json:"status" gorm:"column:status;not null"`
|
||||
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`
|
||||
IsSkipRetry bool `json:"isSkipRetry" gorm:"column:is_skip_retry;type:boolean;not null"`
|
||||
|
||||
BackupSizeMb float64 `json:"backupSizeMb" gorm:"column:backup_size_mb;default:0"`
|
||||
|
||||
|
||||
@@ -212,3 +212,36 @@ func (r *BackupRepository) CountByDatabaseID(databaseID uuid.UUID) (int64, error) {

 	return count, nil
 }
+
+func (r *BackupRepository) GetTotalSizeByDatabase(databaseID uuid.UUID) (float64, error) {
+	var totalSize float64
+
+	if err := storage.
+		GetDb().
+		Model(&Backup{}).
+		Select("COALESCE(SUM(backup_size_mb), 0)").
+		Where("database_id = ? AND status != ?", databaseID, BackupStatusInProgress).
+		Scan(&totalSize).Error; err != nil {
+		return 0, err
+	}
+
+	return totalSize, nil
+}
+
+func (r *BackupRepository) FindOldestByDatabaseExcludingInProgress(
+	databaseID uuid.UUID,
+	limit int,
+) ([]*Backup, error) {
+	var backups []*Backup
+
+	if err := storage.
+		GetDb().
+		Where("database_id = ? AND status != ?", databaseID, BackupStatusInProgress).
+		Order("created_at ASC").
+		Limit(limit).
+		Find(&backups).Error; err != nil {
+		return nil, err
+	}
+
+	return backups, nil
+}
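These two additions are the query half of a size-based retention policy: one reports how much non-in-progress backup data a database currently holds, the other picks deletion candidates oldest-first while never touching in-progress backups. A minimal sketch of how a caller could combine them; the DeleteBackup helper and the batch size of 10 are assumptions for illustration, not part of this changeset:

	// Hypothetical sketch: enforce a total-size cap using the two new
	// repository methods above.
	func enforceSizeCap(repo *BackupRepository, databaseID uuid.UUID, maxTotalMb float64) error {
		for {
			total, err := repo.GetTotalSizeByDatabase(databaseID)
			if err != nil {
				return err
			}
			if total <= maxTotalMb {
				return nil // under the cap, nothing to delete
			}

			// Delete oldest completed/failed backups first; in-progress ones are excluded.
			oldest, err := repo.FindOldestByDatabaseExcludingInProgress(databaseID, 10)
			if err != nil {
				return err
			}
			if len(oldest) == 0 {
				return nil
			}
			for _, b := range oldest {
				if err := DeleteBackup(b); err != nil { // assumed cleanup helper
					return err
				}
			}
		}
	}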
@@ -1,6 +1,9 @@
 package backups

 import (
+	"sync"
+	"sync/atomic"
+
 	audit_logs "databasus-backend/internal/features/audit_logs"
 	"databasus-backend/internal/features/backups/backups/backuping"
 	backups_core "databasus-backend/internal/features/backups/backups/core"
@@ -22,22 +25,23 @@ var backupRepository = &backups_core.BackupRepository{}
 var taskCancelManager = task_cancellation.GetTaskCancelManager()

 var backupService = &BackupService{
-	databaseService:        databases.GetDatabaseService(),
-	storageService:         storages.GetStorageService(),
-	backupRepository:       backupRepository,
-	notifierService:        notifiers.GetNotifierService(),
-	notificationSender:     notifiers.GetNotifierService(),
-	backupConfigService:    backups_config.GetBackupConfigService(),
-	secretKeyService:       encryption_secrets.GetSecretKeyService(),
-	fieldEncryptor:         encryption.GetFieldEncryptor(),
-	createBackupUseCase:    usecases.GetCreateBackupUsecase(),
-	logger:                 logger.GetLogger(),
-	backupRemoveListeners:  []backups_core.BackupRemoveListener{},
-	workspaceService:       workspaces_services.GetWorkspaceService(),
-	auditLogService:        audit_logs.GetAuditLogService(),
-	taskCancelManager:      taskCancelManager,
-	downloadTokenService:   backups_download.GetDownloadTokenService(),
-	backupSchedulerService: backuping.GetBackupsScheduler(),
+	databases.GetDatabaseService(),
+	storages.GetStorageService(),
+	backupRepository,
+	notifiers.GetNotifierService(),
+	notifiers.GetNotifierService(),
+	backups_config.GetBackupConfigService(),
+	encryption_secrets.GetSecretKeyService(),
+	encryption.GetFieldEncryptor(),
+	usecases.GetCreateBackupUsecase(),
+	logger.GetLogger(),
+	[]backups_core.BackupRemoveListener{},
+	workspaces_services.GetWorkspaceService(),
+	audit_logs.GetAuditLogService(),
+	taskCancelManager,
+	backups_download.GetDownloadTokenService(),
+	backuping.GetBackupsScheduler(),
+	backuping.GetBackupCleaner(),
 }

 var backupController = &BackupController{
@@ -52,11 +56,26 @@ func GetBackupController() *BackupController {
 	return backupController
 }

-func SetupDependencies() {
-	backups_config.
-		GetBackupConfigService().
-		SetDatabaseStorageChangeListener(backupService)
-
-	databases.GetDatabaseService().AddDbRemoveListener(backupService)
-	databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
+var (
+	setupOnce sync.Once
+	isSetup   atomic.Bool
+)
+
+func SetupDependencies() {
+	wasAlreadySetup := isSetup.Load()
+
+	setupOnce.Do(func() {
+		backups_config.
+			GetBackupConfigService().
+			SetDatabaseStorageChangeListener(backupService)
+
+		databases.GetDatabaseService().AddDbRemoveListener(backupService)
+		databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
+
+		isSetup.Store(true)
+	})
+
+	if wasAlreadySetup {
+		logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
+	}
 }
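SetupDependencies is now guarded so that repeated calls wire the listeners exactly once and merely log a warning afterwards. The pattern is small enough to show in isolation; a self-contained illustration with hypothetical names, not taken from this changeset:

	package main

	import (
		"fmt"
		"sync"
		"sync/atomic"
	)

	var (
		setupOnce sync.Once
		isSetup   atomic.Bool
	)

	// setup wires listeners exactly once; later calls only report the repeat.
	func setup() {
		wasAlreadySetup := isSetup.Load()

		setupOnce.Do(func() {
			fmt.Println("wiring listeners")
			isSetup.Store(true)
		})

		if wasAlreadySetup {
			fmt.Println("setup called multiple times, ignoring subsequent call")
		}
	}

	func main() {
		setup() // wiring listeners
		setup() // setup called multiple times, ignoring subsequent call
	}

Note that the Load happens before the Do, so a second caller racing the first may skip the warning; the wiring itself is still guaranteed to run only once.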
@@ -2,33 +2,49 @@ package backups_download

 import (
 	"context"
+	"fmt"
 	"log/slog"
+	"sync"
+	"sync/atomic"
 	"time"
 )

 type DownloadTokenBackgroundService struct {
 	downloadTokenService *DownloadTokenService
 	logger               *slog.Logger
+
+	runOnce sync.Once
+	hasRun  atomic.Bool
 }

 func (s *DownloadTokenBackgroundService) Run(ctx context.Context) {
-	s.logger.Info("Starting download token cleanup background service")
-
-	if ctx.Err() != nil {
-		return
-	}
-
-	ticker := time.NewTicker(1 * time.Minute)
-	defer ticker.Stop()
-
-	for {
-		select {
-		case <-ctx.Done():
-			return
-		case <-ticker.C:
-			if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
-				s.logger.Error("Failed to clean expired download tokens", "error", err)
-			}
-		}
-	}
+	wasAlreadyRun := s.hasRun.Load()
+
+	s.runOnce.Do(func() {
+		s.hasRun.Store(true)
+
+		s.logger.Info("Starting download token cleanup background service")
+
+		if ctx.Err() != nil {
+			return
+		}
+
+		ticker := time.NewTicker(1 * time.Minute)
+		defer ticker.Stop()
+
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case <-ticker.C:
+				if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
+					s.logger.Error("Failed to clean expired download tokens", "error", err)
+				}
+			}
+		}
+	})
+
+	if wasAlreadyRun {
+		panic(fmt.Sprintf("%T.Run() called multiple times", s))
+	}
 }
@@ -1,6 +1,9 @@
 package backups_download

 import (
+	"sync"
+	"sync/atomic"
+
 	"databasus-backend/internal/config"
 	cache_utils "databasus-backend/internal/util/cache"
 	"databasus-backend/internal/util/logger"
@@ -30,8 +33,10 @@ func init() {
 	}

 	downloadTokenBackgroundService = &DownloadTokenBackgroundService{
-		downloadTokenService,
-		logger.GetLogger(),
+		downloadTokenService: downloadTokenService,
+		logger:               logger.GetLogger(),
+		runOnce:              sync.Once{},
+		hasRun:               atomic.Bool{},
 	}
 }
@@ -21,6 +21,7 @@ import (
 	users_models "databasus-backend/internal/features/users/models"
 	workspaces_services "databasus-backend/internal/features/workspaces/services"
 	util_encryption "databasus-backend/internal/util/encryption"
+	files_utils "databasus-backend/internal/util/files"

 	"github.com/google/uuid"
 )
@@ -46,6 +47,7 @@ type BackupService struct {
 	taskCancelManager      *task_cancellation.TaskCancelManager
 	downloadTokenService   *backups_download.DownloadTokenService
 	backupSchedulerService *backuping.BackupsScheduler
+	backupCleaner          *backuping.BackupCleaner
 }

 func (s *BackupService) AddBackupRemoveListener(listener backups_core.BackupRemoveListener) {
@@ -91,7 +93,7 @@ func (s *BackupService) MakeBackupWithAuth(
 		return errors.New("insufficient permissions to create backup for this database")
 	}

-	s.backupSchedulerService.StartBackup(databaseID, true)
+	s.backupSchedulerService.StartBackup(database, true)

 	s.auditLogService.WriteAuditLog(
 		fmt.Sprintf("Backup manually initiated for database: %s", database.Name),
@@ -180,16 +182,12 @@ func (s *BackupService) DeleteBackup(
 	}

 	s.auditLogService.WriteAuditLog(
-		fmt.Sprintf(
-			"Backup deleted for database: %s (ID: %s)",
-			database.Name,
-			backupID.String(),
-		),
+		fmt.Sprintf("Backup deleted for database: %s", database.Name),
 		&user.ID,
 		database.WorkspaceID,
 	)

-	return s.deleteBackup(backup)
+	return s.backupCleaner.DeleteBackup(backup)
 }

 func (s *BackupService) GetBackup(backupID uuid.UUID) (*backups_core.Backup, error) {
@@ -231,11 +229,7 @@ func (s *BackupService) CancelBackup(
 	}

 	s.auditLogService.WriteAuditLog(
-		fmt.Sprintf(
-			"Backup cancelled for database: %s (ID: %s)",
-			database.Name,
-			backupID.String(),
-		),
+		fmt.Sprintf("Backup cancelled for database: %s", database.Name),
 		&user.ID,
 		database.WorkspaceID,
 	)
@@ -275,11 +269,7 @@ func (s *BackupService) GetBackupFile(
 	}

 	s.auditLogService.WriteAuditLog(
-		fmt.Sprintf(
-			"Backup file downloaded for database: %s (ID: %s)",
-			database.Name,
-			backupID.String(),
-		),
+		fmt.Sprintf("Backup file downloaded for database: %s", database.Name),
 		&user.ID,
 		database.WorkspaceID,
 	)
@@ -292,29 +282,6 @@ func (s *BackupService) GetBackupFile(
 	return reader, backup, database, nil
 }

-func (s *BackupService) deleteBackup(backup *backups_core.Backup) error {
-	for _, listener := range s.backupRemoveListeners {
-		if err := listener.OnBeforeBackupRemove(backup); err != nil {
-			return err
-		}
-	}
-
-	storage, err := s.storageService.GetStorageByID(backup.StorageID)
-	if err != nil {
-		return err
-	}
-
-	err = storage.DeleteFile(s.fieldEncryptor, backup.ID)
-	if err != nil {
-		// we do not return error here, because sometimes clean up performed
-		// before unavailable storage removal or change - therefore we should
-		// proceed even in case of error
-		s.logger.Error("Failed to delete backup file", "error", err)
-	}
-
-	return s.backupRepository.DeleteByID(backup.ID)
-}
-
 func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
 	dbBackupsInProgress, err := s.backupRepository.FindByDatabaseIdAndStatus(
 		databaseID,
@@ -336,7 +303,7 @@ func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
 	}

 	for _, dbBackup := range dbBackups {
-		err := s.deleteBackup(dbBackup)
+		err := s.backupCleaner.DeleteBackup(dbBackup)
 		if err != nil {
 			return err
 		}
@@ -358,7 +325,7 @@ func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, error) {
 		return nil, fmt.Errorf("failed to get storage: %w", err)
 	}

-	fileReader, err := storage.GetFile(s.fieldEncryptor, backup.ID)
+	fileReader, err := storage.GetFile(s.fieldEncryptor, backup.FileName)
 	if err != nil {
 		return nil, fmt.Errorf("failed to get backup file: %w", err)
 	}
@@ -512,11 +479,7 @@ func (s *BackupService) WriteAuditLogForDownload(
 	database *databases.Database,
 ) {
 	s.auditLogService.WriteAuditLog(
-		fmt.Sprintf(
-			"Backup file downloaded for database: %s (ID: %s)",
-			database.Name,
-			backup.ID.String(),
-		),
+		fmt.Sprintf("Backup file downloaded for database: %s", database.Name),
 		&userID,
 		database.WorkspaceID,
 	)
@@ -543,7 +506,7 @@ func (s *BackupService) generateBackupFilename(
 	database *databases.Database,
 ) string {
 	timestamp := backup.CreatedAt.Format("2006-01-02_15-04-05")
-	safeName := sanitizeFilename(database.Name)
+	safeName := files_utils.SanitizeFilename(database.Name)
 	extension := s.getBackupExtension(database.Type)
 	return fmt.Sprintf("%s_backup_%s%s", safeName, timestamp, extension)
 }
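The private deleteBackup helper removed above has moved into backuping.BackupCleaner, which the service now delegates to. The cleaner's actual implementation is not shown in this diff; a plausible minimal sketch, under the assumption that it keeps the same listener-then-storage-then-repository order, the same best-effort handling of storage failures, and the new filename-keyed file API. All field and method names below are inferred from the removed helper:

	// Hypothetical sketch of backuping.BackupCleaner.DeleteBackup.
	type BackupCleaner struct {
		backupRemoveListeners []backups_core.BackupRemoveListener
		storageService        *storages.StorageService
		backupRepository      *backups_core.BackupRepository
		fieldEncryptor        encryption.FieldEncryptor
		logger                *slog.Logger
	}

	func (c *BackupCleaner) DeleteBackup(backup *backups_core.Backup) error {
		// Give listeners a chance to veto or clean up first.
		for _, listener := range c.backupRemoveListeners {
			if err := listener.OnBeforeBackupRemove(backup); err != nil {
				return err
			}
		}

		storage, err := c.storageService.GetStorageByID(backup.StorageID)
		if err != nil {
			return err
		}

		// Best effort: an unreachable storage must not block record deletion.
		if err := storage.DeleteFile(c.fieldEncryptor, backup.FileName); err != nil {
			c.logger.Error("Failed to delete backup file", "error", err)
		}

		return c.backupRepository.DeleteByID(backup.ID)
	}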
@@ -75,3 +75,23 @@ func WaitForBackupCompletion(

 	t.Logf("WaitForBackupCompletion: timeout waiting for backup to complete")
 }
+
+// CreateTestBackup creates a simple test backup record for testing purposes
+func CreateTestBackup(databaseID, storageID uuid.UUID) *backups_core.Backup {
+	backup := &backups_core.Backup{
+		ID:               uuid.New(),
+		DatabaseID:       databaseID,
+		StorageID:        storageID,
+		Status:           backups_core.BackupStatusCompleted,
+		BackupSizeMb:     10.5,
+		BackupDurationMs: 1000,
+		CreatedAt:        time.Now().UTC(),
+	}
+
+	repo := &backups_core.BackupRepository{}
+	if err := repo.Save(backup); err != nil {
+		panic(err)
+	}
+
+	return backup
+}
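CreateTestBackup is meant to be called from other feature tests that need a completed backup row without running a real dump. A sketch of typical usage; the test name and the assumption that the referenced database and storage rows already exist are illustrative, not from this changeset:

	// Illustrative usage only.
	func Test_GetTotalSizeByDatabase_SumsCompletedBackups(t *testing.T) {
		databaseID := uuid.New() // assumed to reference an existing test database
		storageID := uuid.New()  // assumed to reference an existing test storage

		b1 := CreateTestBackup(databaseID, storageID) // 10.5 MB each
		b2 := CreateTestBackup(databaseID, storageID)

		repo := &backups_core.BackupRepository{}
		total, err := repo.GetTotalSizeByDatabase(databaseID)
		assert.NoError(t, err)
		assert.Equal(t, b1.BackupSizeMb+b2.BackupSizeMb, total)
	}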
@@ -5,6 +5,7 @@ import (
 	"errors"

 	common "databasus-backend/internal/features/backups/backups/common"
+	backups_core "databasus-backend/internal/features/backups/backups/core"
 	usecases_mariadb "databasus-backend/internal/features/backups/backups/usecases/mariadb"
 	usecases_mongodb "databasus-backend/internal/features/backups/backups/usecases/mongodb"
 	usecases_mysql "databasus-backend/internal/features/backups/backups/usecases/mysql"
@@ -12,8 +13,6 @@ import (
 	backups_config "databasus-backend/internal/features/backups/config"
 	"databasus-backend/internal/features/databases"
 	"databasus-backend/internal/features/storages"
-
-	"github.com/google/uuid"
 )

 type CreateBackupUsecase struct {
@@ -25,7 +24,7 @@ type CreateBackupUsecase struct {

 func (uc *CreateBackupUsecase) Execute(
 	ctx context.Context,
-	backupID uuid.UUID,
+	backup *backups_core.Backup,
 	backupConfig *backups_config.BackupConfig,
 	database *databases.Database,
 	storage *storages.Storage,
@@ -35,7 +34,7 @@ func (uc *CreateBackupUsecase) Execute(
 	case databases.DatabaseTypePostgres:
 		return uc.CreatePostgresqlBackupUsecase.Execute(
 			ctx,
-			backupID,
+			backup,
 			backupConfig,
 			database,
 			storage,
@@ -45,7 +44,7 @@ func (uc *CreateBackupUsecase) Execute(
 	case databases.DatabaseTypeMysql:
 		return uc.CreateMysqlBackupUsecase.Execute(
 			ctx,
-			backupID,
+			backup,
 			backupConfig,
 			database,
 			storage,
@@ -55,7 +54,7 @@ func (uc *CreateBackupUsecase) Execute(
 	case databases.DatabaseTypeMariadb:
 		return uc.CreateMariadbBackupUsecase.Execute(
 			ctx,
-			backupID,
+			backup,
 			backupConfig,
 			database,
 			storage,
@@ -65,7 +64,7 @@ func (uc *CreateBackupUsecase) Execute(
 	case databases.DatabaseTypeMongodb:
 		return uc.CreateMongodbBackupUsecase.Execute(
 			ctx,
-			backupID,
+			backup,
 			backupConfig,
 			database,
 			storage,
@@ -19,6 +19,7 @@ import (

 	"databasus-backend/internal/config"
 	common "databasus-backend/internal/features/backups/backups/common"
+	backups_core "databasus-backend/internal/features/backups/backups/core"
 	backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
 	backups_config "databasus-backend/internal/features/backups/config"
 	"databasus-backend/internal/features/databases"
@@ -52,7 +53,7 @@ type writeResult struct {

 func (uc *CreateMariadbBackupUsecase) Execute(
 	ctx context.Context,
-	backupID uuid.UUID,
+	backup *backups_core.Backup,
 	backupConfig *backups_config.BackupConfig,
 	db *databases.Database,
 	storage *storages.Storage,
@@ -82,7 +83,7 @@ func (uc *CreateMariadbBackupUsecase) Execute(

 	return uc.streamToStorage(
 		ctx,
-		backupID,
+		backup,
 		backupConfig,
 		tools.GetMariadbExecutable(
 			tools.MariadbExecutableMariadbDump,
@@ -108,13 +109,15 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
 		"--single-transaction",
 		"--routines",
 		"--quick",
+		"--skip-extended-insert",
 		"--verbose",
 	}

 	if mdb.HasPrivilege("TRIGGER") {
 		args = append(args, "--triggers")
 	}
-	if mdb.HasPrivilege("EVENT") {
+
+	if mdb.HasPrivilege("EVENT") && !mdb.IsExcludeEvents {
 		args = append(args, "--events")
 	}

@@ -134,7 +137,7 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(

 func (uc *CreateMariadbBackupUsecase) streamToStorage(
 	parentCtx context.Context,
-	backupID uuid.UUID,
+	backup *backups_core.Backup,
 	backupConfig *backups_config.BackupConfig,
 	mariadbBin string,
 	args []string,
@@ -185,7 +188,7 @@ func (uc *CreateMariadbBackupUsecase) streamToStorage(
 	storageReader, storageWriter := io.Pipe()

 	finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
-		backupID,
+		backup.ID,
 		backupConfig,
 		storageWriter,
 	)
@@ -202,7 +205,13 @@ func (uc *CreateMariadbBackupUsecase) streamToStorage(

 	saveErrCh := make(chan error, 1)
 	go func() {
-		saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
+		saveErr := storage.SaveFile(
+			ctx,
+			uc.fieldEncryptor,
+			uc.logger,
+			backup.FileName,
+			storageReader,
+		)
 		saveErrCh <- saveErr
 	}()

@@ -418,7 +427,9 @@ func (uc *CreateMariadbBackupUsecase) setupBackupEncryption(
 	backupConfig *backups_config.BackupConfig,
 	storageWriter io.WriteCloser,
 ) (io.Writer, *backup_encryption.EncryptionWriter, common.BackupMetadata, error) {
-	metadata := common.BackupMetadata{}
+	metadata := common.BackupMetadata{
+		BackupID: backupID,
+	}

 	if backupConfig.Encryption != backups_config.BackupEncryptionEncrypted {
 		metadata.Encryption = backups_config.BackupEncryptionNone
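All four engine usecases share the same streaming shape: the dump command's output is fed into one end of an io.Pipe, while a goroutine drains the other end into storage and reports its result on a buffered error channel. A stripped-down, self-contained sketch of that shape; saveToStorage and the echo command stand in for this repository's storage.SaveFile and pg_dump, and are not its API:

	package main

	import (
		"fmt"
		"io"
		"os/exec"
	)

	// saveToStorage is a placeholder for the storage side of the pipe.
	func saveToStorage(name string, r io.Reader) error {
		n, err := io.Copy(io.Discard, r) // a real store would persist the bytes
		fmt.Printf("stored %s (%d bytes)\n", name, n)
		return err
	}

	func main() {
		cmd := exec.Command("echo", "pretend this is pg_dump output")

		storageReader, storageWriter := io.Pipe()

		// Drain the pipe concurrently; the writer side would block otherwise.
		saveErrCh := make(chan error, 1)
		go func() {
			saveErrCh <- saveToStorage("example_backup.sql", storageReader)
		}()

		cmd.Stdout = storageWriter
		runErr := cmd.Run()

		// Closing the writer lets the storage goroutine observe EOF (or the error).
		storageWriter.CloseWithError(runErr)

		if err := <-saveErrCh; err != nil {
			fmt.Println("save failed:", err)
		}
		if runErr != nil {
			fmt.Println("dump failed:", runErr)
		}
	}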
@@ -16,6 +16,7 @@ import (

 	"databasus-backend/internal/config"
 	common "databasus-backend/internal/features/backups/backups/common"
+	backups_core "databasus-backend/internal/features/backups/backups/core"
 	backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
 	backups_config "databasus-backend/internal/features/backups/config"
 	"databasus-backend/internal/features/databases"
@@ -46,7 +47,7 @@ type writeResult struct {

 func (uc *CreateMongodbBackupUsecase) Execute(
 	ctx context.Context,
-	backupID uuid.UUID,
+	backup *backups_core.Backup,
 	backupConfig *backups_config.BackupConfig,
 	db *databases.Database,
 	storage *storages.Storage,
@@ -76,7 +77,7 @@ func (uc *CreateMongodbBackupUsecase) Execute(

 	return uc.streamToStorage(
 		ctx,
-		backupID,
+		backup,
 		backupConfig,
 		tools.GetMongodbExecutable(
 			tools.MongodbExecutableMongodump,
@@ -114,7 +115,7 @@ func (uc *CreateMongodbBackupUsecase) buildMongodumpArgs(

 func (uc *CreateMongodbBackupUsecase) streamToStorage(
 	parentCtx context.Context,
-	backupID uuid.UUID,
+	backup *backups_core.Backup,
 	backupConfig *backups_config.BackupConfig,
 	mongodumpBin string,
 	args []string,
@@ -163,7 +164,7 @@ func (uc *CreateMongodbBackupUsecase) streamToStorage(
 	storageReader, storageWriter := io.Pipe()

 	finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
-		backupID,
+		backup.ID,
 		backupConfig,
 		storageWriter,
 	)
@@ -175,7 +176,13 @@ func (uc *CreateMongodbBackupUsecase) streamToStorage(

 	saveErrCh := make(chan error, 1)
 	go func() {
-		saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
+		saveErr := storage.SaveFile(
+			ctx,
+			uc.fieldEncryptor,
+			uc.logger,
+			backup.FileName,
+			storageReader,
+		)
 		saveErrCh <- saveErr
 	}()

@@ -262,6 +269,7 @@ func (uc *CreateMongodbBackupUsecase) setupBackupEncryption(
 	storageWriter io.WriteCloser,
 ) (io.Writer, *backup_encryption.EncryptionWriter, common.BackupMetadata, error) {
 	backupMetadata := common.BackupMetadata{
+		BackupID:   backupID,
 		Encryption: backups_config.BackupEncryptionNone,
 	}

@@ -298,6 +306,7 @@ func (uc *CreateMongodbBackupUsecase) setupBackupEncryption(
 	saltBase64 := base64.StdEncoding.EncodeToString(salt)
 	nonceBase64 := base64.StdEncoding.EncodeToString(nonce)

+	backupMetadata.BackupID = backupID
 	backupMetadata.Encryption = backups_config.BackupEncryptionEncrypted
 	backupMetadata.EncryptionSalt = &saltBase64
 	backupMetadata.EncryptionIV = &nonceBase64
@@ -19,6 +19,7 @@ import (

 	"databasus-backend/internal/config"
 	common "databasus-backend/internal/features/backups/backups/common"
+	backups_core "databasus-backend/internal/features/backups/backups/core"
 	backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
 	backups_config "databasus-backend/internal/features/backups/config"
 	"databasus-backend/internal/features/databases"
@@ -52,7 +53,7 @@ type writeResult struct {

 func (uc *CreateMysqlBackupUsecase) Execute(
 	ctx context.Context,
-	backupID uuid.UUID,
+	backup *backups_core.Backup,
 	backupConfig *backups_config.BackupConfig,
 	db *databases.Database,
 	storage *storages.Storage,
@@ -82,7 +83,7 @@ func (uc *CreateMysqlBackupUsecase) Execute(

 	return uc.streamToStorage(
 		ctx,
-		backupID,
+		backup,
 		backupConfig,
 		tools.GetMysqlExecutable(
 			my.Version,
@@ -107,6 +108,7 @@ func (uc *CreateMysqlBackupUsecase) buildMysqldumpArgs(my *mysqltypes.MysqlDatabase) []string {
 		"--routines",
 		"--set-gtid-purged=OFF",
 		"--quick",
+		"--skip-extended-insert",
 		"--verbose",
 	}

@@ -148,7 +150,7 @@ func (uc *CreateMysqlBackupUsecase) getNetworkCompressionArgs(version tools.MysqlVersion) []string {

 func (uc *CreateMysqlBackupUsecase) streamToStorage(
 	parentCtx context.Context,
-	backupID uuid.UUID,
+	backup *backups_core.Backup,
 	backupConfig *backups_config.BackupConfig,
 	mysqlBin string,
 	args []string,
@@ -199,7 +201,7 @@ func (uc *CreateMysqlBackupUsecase) streamToStorage(
 	storageReader, storageWriter := io.Pipe()

 	finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
-		backupID,
+		backup.ID,
 		backupConfig,
 		storageWriter,
 	)
@@ -216,7 +218,13 @@ func (uc *CreateMysqlBackupUsecase) streamToStorage(

 	saveErrCh := make(chan error, 1)
 	go func() {
-		saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
+		saveErr := storage.SaveFile(
+			ctx,
+			uc.fieldEncryptor,
+			uc.logger,
+			backup.FileName,
+			storageReader,
+		)
 		saveErrCh <- saveErr
 	}()

@@ -430,7 +438,9 @@ func (uc *CreateMysqlBackupUsecase) setupBackupEncryption(
 	backupConfig *backups_config.BackupConfig,
 	storageWriter io.WriteCloser,
 ) (io.Writer, *backup_encryption.EncryptionWriter, common.BackupMetadata, error) {
-	metadata := common.BackupMetadata{}
+	metadata := common.BackupMetadata{
+		BackupID: backupID,
+	}

 	if backupConfig.Encryption != backups_config.BackupEncryptionEncrypted {
 		metadata.Encryption = backups_config.BackupEncryptionNone
@@ -16,6 +16,7 @@ import (

 	"databasus-backend/internal/config"
 	common "databasus-backend/internal/features/backups/backups/common"
+	backups_core "databasus-backend/internal/features/backups/backups/core"
 	backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
 	backups_config "databasus-backend/internal/features/backups/config"
 	"databasus-backend/internal/features/databases"
@@ -53,7 +54,7 @@ type writeResult struct {

 func (uc *CreatePostgresqlBackupUsecase) Execute(
 	ctx context.Context,
-	backupID uuid.UUID,
+	backup *backups_core.Backup,
 	backupConfig *backups_config.BackupConfig,
 	db *databases.Database,
 	storage *storages.Storage,
@@ -88,7 +89,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(

 	return uc.streamToStorage(
 		ctx,
-		backupID,
+		backup,
 		backupConfig,
 		tools.GetPostgresqlExecutable(
 			pg.Version,
@@ -107,7 +108,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
 // streamToStorage streams pg_dump output directly to storage
 func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
 	parentCtx context.Context,
-	backupID uuid.UUID,
+	backup *backups_core.Backup,
 	backupConfig *backups_config.BackupConfig,
 	pgBin string,
 	args []string,
@@ -166,7 +167,7 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
 	storageReader, storageWriter := io.Pipe()

 	finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
-		backupID,
+		backup.ID,
 		backupConfig,
 		storageWriter,
 	)
@@ -181,7 +182,13 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
 	// Start streaming into storage in its own goroutine
 	saveErrCh := make(chan error, 1)
 	go func() {
-		saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
+		saveErr := storage.SaveFile(
+			ctx,
+			uc.fieldEncryptor,
+			uc.logger,
+			backup.FileName,
+			storageReader,
+		)
 		saveErrCh <- saveErr
 	}()

@@ -475,7 +482,9 @@ func (uc *CreatePostgresqlBackupUsecase) setupBackupEncryption(
 	backupConfig *backups_config.BackupConfig,
 	storageWriter io.WriteCloser,
 ) (io.Writer, *backup_encryption.EncryptionWriter, common.BackupMetadata, error) {
-	metadata := common.BackupMetadata{}
+	metadata := common.BackupMetadata{
+		BackupID: backupID,
+	}

 	if backupConfig.Encryption != backups_config.BackupEncryptionEncrypted {
 		metadata.Encryption = backups_config.BackupEncryptionNone
@@ -16,6 +16,7 @@ type BackupConfigController struct {

 func (c *BackupConfigController) RegisterRoutes(router *gin.RouterGroup) {
 	router.POST("/backup-configs/save", c.SaveBackupConfig)
+	router.GET("/backup-configs/database/:id/plan", c.GetDatabasePlan)
 	router.GET("/backup-configs/database/:id", c.GetBackupConfigByDbID)
 	router.GET("/backup-configs/storage/:id/is-using", c.IsStorageUsing)
 	router.GET("/backup-configs/storage/:id/databases-count", c.CountDatabasesForStorage)
@@ -92,6 +93,39 @@ func (c *BackupConfigController) GetBackupConfigByDbID(ctx *gin.Context) {
 	ctx.JSON(http.StatusOK, backupConfig)
 }

+// GetDatabasePlan
+// @Summary Get database plan by database ID
+// @Description Get the plan limits for a specific database (max backup size, max total size, max storage period)
+// @Tags backup-configs
+// @Produce json
+// @Param id path string true "Database ID"
+// @Success 200 {object} plans.DatabasePlan
+// @Failure 400 {object} map[string]string "Invalid database ID"
+// @Failure 401 {object} map[string]string "User not authenticated"
+// @Failure 404 {object} map[string]string "Database not found or access denied"
+// @Router /backup-configs/database/{id}/plan [get]
+func (c *BackupConfigController) GetDatabasePlan(ctx *gin.Context) {
+	user, ok := users_middleware.GetUserFromContext(ctx)
+	if !ok {
+		ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
+		return
+	}
+
+	id, err := uuid.Parse(ctx.Param("id"))
+	if err != nil {
+		ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
+		return
+	}
+
+	plan, err := c.backupConfigService.GetDatabasePlan(user, id)
+	if err != nil {
+		ctx.JSON(http.StatusNotFound, gin.H{"error": "database plan not found"})
+		return
+	}
+
+	ctx.JSON(http.StatusOK, plan)
+}
+
 // IsStorageUsing
 // @Summary Check if storage is being used
 // @Description Check if a storage is currently being used by any backup configuration
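The new plan endpoint behaves like any other authenticated GET in this API. A minimal client-side sketch; the host, port, token, and database ID are placeholders, only the route and response shape come from the diff:

	package main

	import (
		"fmt"
		"io"
		"net/http"
	)

	func main() {
		// Placeholder values; substitute a real database ID and bearer token.
		url := "http://localhost:8080/api/v1/backup-configs/database/DATABASE_ID/plan"

		req, err := http.NewRequest(http.MethodGet, url, nil)
		if err != nil {
			panic(err)
		}
		req.Header.Set("Authorization", "Bearer TOKEN")

		resp, err := http.DefaultClient.Do(req)
		if err != nil {
			panic(err)
		}
		defer resp.Body.Close()

		body, _ := io.ReadAll(resp.Body)
		// Expect 200 with a plans.DatabasePlan JSON payload, or 400/401/404 as documented.
		fmt.Println(resp.Status, string(body))
	}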
@@ -6,6 +6,7 @@ import (
 	"net/http"
 	"strconv"
 	"testing"
+	"time"

 	"github.com/gin-gonic/gin"
 	"github.com/google/uuid"
@@ -16,11 +17,14 @@ import (
 	"databasus-backend/internal/features/databases/databases/postgresql"
 	"databasus-backend/internal/features/intervals"
 	"databasus-backend/internal/features/notifiers"
+	plans "databasus-backend/internal/features/plan"
 	"databasus-backend/internal/features/storages"
+	local_storage "databasus-backend/internal/features/storages/models/local"
 	users_enums "databasus-backend/internal/features/users/enums"
 	users_testing "databasus-backend/internal/features/users/testing"
 	workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
 	workspaces_testing "databasus-backend/internal/features/workspaces/testing"
+	"databasus-backend/internal/storage"
 	"databasus-backend/internal/util/period"
 	test_utils "databasus-backend/internal/util/testing"
 	"databasus-backend/internal/util/tools"
@@ -89,6 +93,11 @@ func Test_SaveBackupConfig_PermissionsEnforced(t *testing.T) {

 	database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)

+	defer func() {
+		databases.RemoveTestDatabase(database)
+		workspaces_testing.RemoveTestWorkspace(workspace, router)
+	}()
+
 	var testUserToken string
 	if tt.isGlobalAdmin {
 		admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
@@ -109,9 +118,10 @@ func Test_SaveBackupConfig_PermissionsEnforced(t *testing.T) {

 	timeOfDay := "04:00"
 	request := BackupConfig{
-		DatabaseID:       database.ID,
-		IsBackupsEnabled: true,
-		StorePeriod:      period.PeriodWeek,
+		DatabaseID:          database.ID,
+		IsBackupsEnabled:    true,
+		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+		RetentionTimePeriod: period.PeriodWeek,
 		BackupInterval: &intervals.Interval{
 			Interval:  intervals.IntervalDaily,
 			TimeOfDay: &timeOfDay,
@@ -137,7 +147,7 @@ func Test_SaveBackupConfig_PermissionsEnforced(t *testing.T) {
 	if tt.expectSuccess {
 		assert.Equal(t, database.ID, response.DatabaseID)
 		assert.True(t, response.IsBackupsEnabled)
-		assert.Equal(t, period.PeriodWeek, response.StorePeriod)
+		assert.Equal(t, period.PeriodWeek, response.RetentionTimePeriod)
 	} else {
 		assert.Contains(t, string(testResp.Body), "insufficient permissions")
 	}
@@ -152,13 +162,19 @@ func Test_SaveBackupConfig_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing.T) {

 	database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)

+	defer func() {
+		databases.RemoveTestDatabase(database)
+		workspaces_testing.RemoveTestWorkspace(workspace, router)
+	}()
+
 	nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)

 	timeOfDay := "04:00"
 	request := BackupConfig{
-		DatabaseID:       database.ID,
-		IsBackupsEnabled: true,
-		StorePeriod:      period.PeriodWeek,
+		DatabaseID:          database.ID,
+		IsBackupsEnabled:    true,
+		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+		RetentionTimePeriod: period.PeriodWeek,
 		BackupInterval: &intervals.Interval{
 			Interval:  intervals.IntervalDaily,
 			TimeOfDay: &timeOfDay,
@@ -242,6 +258,11 @@ func Test_GetBackupConfigByDbID_PermissionsEnforced(t *testing.T) {

 	database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)

+	defer func() {
+		databases.RemoveTestDatabase(database)
+		workspaces_testing.RemoveTestWorkspace(workspace, router)
+	}()
+
 	var testUserToken string
 	if tt.isGlobalAdmin {
 		admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
@@ -290,6 +311,11 @@ func Test_GetBackupConfigByDbID_ReturnsDefaultConfigForNewDatabase(t *testing.T) {

 	database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)

+	defer func() {
+		databases.RemoveTestDatabase(database)
+		workspaces_testing.RemoveTestWorkspace(workspace, router)
+	}()
+
 	var response BackupConfig
 	test_utils.MakeGetRequestAndUnmarshal(
 		t,
@@ -300,14 +326,218 @@ func Test_GetBackupConfigByDbID_ReturnsDefaultConfigForNewDatabase(t *testing.T) {
 		&response,
 	)

+	var plan plans.DatabasePlan
+
+	test_utils.MakeGetRequestAndUnmarshal(
+		t,
+		router,
+		"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
+		"Bearer "+owner.Token,
+		http.StatusOK,
+		&plan,
+	)
+
 	assert.Equal(t, database.ID, response.DatabaseID)
 	assert.False(t, response.IsBackupsEnabled)
-	assert.Equal(t, period.PeriodWeek, response.StorePeriod)
+	assert.Equal(t, plan.MaxStoragePeriod, response.RetentionTimePeriod)
+	assert.Equal(t, plan.MaxBackupSizeMB, response.MaxBackupSizeMB)
+	assert.Equal(t, plan.MaxBackupsTotalSizeMB, response.MaxBackupsTotalSizeMB)
 	assert.True(t, response.IsRetryIfFailed)
 	assert.Equal(t, 3, response.MaxFailedTriesCount)
 	assert.NotNil(t, response.BackupInterval)
 }

+func Test_GetDatabasePlan_ForNewDatabase_PlanAlwaysReturned(t *testing.T) {
+	router := createTestRouter()
+	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
+	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
+
+	database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
+
+	defer func() {
+		databases.RemoveTestDatabase(database)
+		workspaces_testing.RemoveTestWorkspace(workspace, router)
+	}()
+
+	var response plans.DatabasePlan
+	test_utils.MakeGetRequestAndUnmarshal(
+		t,
+		router,
+		"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
+		"Bearer "+owner.Token,
+		http.StatusOK,
+		&response,
+	)
+
+	assert.Equal(t, database.ID, response.DatabaseID)
+	assert.NotNil(t, response.MaxBackupSizeMB)
+	assert.NotNil(t, response.MaxBackupsTotalSizeMB)
+	assert.NotEmpty(t, response.MaxStoragePeriod)
+}
+
+func Test_SaveBackupConfig_WhenPlanLimitsAreAdjusted_ValidationEnforced(t *testing.T) {
+	router := createTestRouter()
+	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
+	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
+
+	database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
+
+	defer func() {
+		databases.RemoveTestDatabase(database)
+		workspaces_testing.RemoveTestWorkspace(workspace, router)
+	}()
+
+	// Get plan via API (triggers auto-creation)
+	var plan plans.DatabasePlan
+	test_utils.MakeGetRequestAndUnmarshal(
+		t,
+		router,
+		"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
+		"Bearer "+owner.Token,
+		http.StatusOK,
+		&plan,
+	)
+
+	assert.Equal(t, database.ID, plan.DatabaseID)
+
+	// Adjust plan limits directly in database to fixed restrictive values
+	err := storage.GetDb().Model(&plans.DatabasePlan{}).
+		Where("database_id = ?", database.ID).
+		Updates(map[string]any{
+			"max_backup_size_mb":        100,
+			"max_backups_total_size_mb": 1000,
+			"max_storage_period":        period.PeriodMonth,
+		}).Error
+	assert.NoError(t, err)
+
+	// Test 1: Try to save backup config with exceeded backup size limit
+	timeOfDay := "04:00"
+	backupConfigExceededSize := BackupConfig{
+		DatabaseID:          database.ID,
+		IsBackupsEnabled:    true,
+		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+		RetentionTimePeriod: period.PeriodWeek,
+		BackupInterval: &intervals.Interval{
+			Interval:  intervals.IntervalDaily,
+			TimeOfDay: &timeOfDay,
+		},
+		SendNotificationsOn: []BackupNotificationType{
+			NotificationBackupFailed,
+		},
+		IsRetryIfFailed:       true,
+		MaxFailedTriesCount:   3,
+		Encryption:            BackupEncryptionNone,
+		MaxBackupSizeMB:       200, // Exceeds limit of 100
+		MaxBackupsTotalSizeMB: 800,
+	}
+
+	respExceededSize := test_utils.MakePostRequest(
+		t,
+		router,
+		"/api/v1/backup-configs/save",
+		"Bearer "+owner.Token,
+		backupConfigExceededSize,
+		http.StatusBadRequest,
+	)
+	assert.Contains(t, string(respExceededSize.Body), "max backup size exceeds plan limit")
+
+	// Test 2: Try to save backup config with exceeded total size limit
+	backupConfigExceededTotal := BackupConfig{
+		DatabaseID:          database.ID,
+		IsBackupsEnabled:    true,
+		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+		RetentionTimePeriod: period.PeriodWeek,
+		BackupInterval: &intervals.Interval{
+			Interval:  intervals.IntervalDaily,
+			TimeOfDay: &timeOfDay,
+		},
+		SendNotificationsOn: []BackupNotificationType{
+			NotificationBackupFailed,
+		},
+		IsRetryIfFailed:       true,
+		MaxFailedTriesCount:   3,
+		Encryption:            BackupEncryptionNone,
+		MaxBackupSizeMB:       50,
+		MaxBackupsTotalSizeMB: 2000, // Exceeds limit of 1000
+	}
+
+	respExceededTotal := test_utils.MakePostRequest(
+		t,
+		router,
+		"/api/v1/backup-configs/save",
+		"Bearer "+owner.Token,
+		backupConfigExceededTotal,
+		http.StatusBadRequest,
+	)
+	assert.Contains(t, string(respExceededTotal.Body), "max total backups size exceeds plan limit")
+
+	// Test 3: Try to save backup config with exceeded storage period limit
+	backupConfigExceededPeriod := BackupConfig{
+		DatabaseID:          database.ID,
+		IsBackupsEnabled:    true,
+		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+		RetentionTimePeriod: period.PeriodYear, // Exceeds limit of Month
+		BackupInterval: &intervals.Interval{
+			Interval:  intervals.IntervalDaily,
+			TimeOfDay: &timeOfDay,
+		},
+		SendNotificationsOn: []BackupNotificationType{
+			NotificationBackupFailed,
+		},
+		IsRetryIfFailed:       true,
+		MaxFailedTriesCount:   3,
+		Encryption:            BackupEncryptionNone,
+		MaxBackupSizeMB:       80,
+		MaxBackupsTotalSizeMB: 800,
+	}
+
+	respExceededPeriod := test_utils.MakePostRequest(
+		t,
+		router,
+		"/api/v1/backup-configs/save",
+		"Bearer "+owner.Token,
+		backupConfigExceededPeriod,
+		http.StatusBadRequest,
+	)
+	assert.Contains(t, string(respExceededPeriod.Body), "storage period exceeds plan limit")
+
+	// Test 4: Save backup config within all limits - should succeed
+	backupConfigValid := BackupConfig{
+		DatabaseID:          database.ID,
+		IsBackupsEnabled:    true,
+		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+		RetentionTimePeriod: period.PeriodWeek, // Within Month limit
+		BackupInterval: &intervals.Interval{
+			Interval:  intervals.IntervalDaily,
+			TimeOfDay: &timeOfDay,
+		},
+		SendNotificationsOn: []BackupNotificationType{
+			NotificationBackupFailed,
+		},
+		IsRetryIfFailed:       true,
+		MaxFailedTriesCount:   3,
+		Encryption:            BackupEncryptionNone,
+		MaxBackupSizeMB:       80,  // Within 100 limit
+		MaxBackupsTotalSizeMB: 800, // Within 1000 limit
+	}
+
+	var responseValid BackupConfig
+	test_utils.MakePostRequestAndUnmarshal(
+		t,
+		router,
+		"/api/v1/backup-configs/save",
+		"Bearer "+owner.Token,
+		backupConfigValid,
+		http.StatusOK,
+		&responseValid,
+	)
+
+	assert.Equal(t, database.ID, responseValid.DatabaseID)
+	assert.Equal(t, int64(80), responseValid.MaxBackupSizeMB)
+	assert.Equal(t, int64(800), responseValid.MaxBackupsTotalSizeMB)
+	assert.Equal(t, period.PeriodWeek, responseValid.RetentionTimePeriod)
+}
+
 func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
 	tests := []struct {
 		name string
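The three error strings asserted above imply a validation step in the service before a config is persisted. That service code is not part of this excerpt; a minimal sketch of what such a check could look like, with comparePeriods as an assumed helper that orders retention periods:

	// Hypothetical validation sketch; only the error strings are taken from
	// the tests above, everything else is an assumption for illustration.
	func validateAgainstPlan(cfg *BackupConfig, plan *plans.DatabasePlan) error {
		if cfg.MaxBackupSizeMB > plan.MaxBackupSizeMB {
			return errors.New("max backup size exceeds plan limit")
		}
		if cfg.MaxBackupsTotalSizeMB > plan.MaxBackupsTotalSizeMB {
			return errors.New("max total backups size exceeds plan limit")
		}
		if comparePeriods(cfg.RetentionTimePeriod, plan.MaxStoragePeriod) > 0 {
			return errors.New("storage period exceeds plan limit")
		}
		return nil
	}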
@@ -340,6 +570,10 @@ func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
 	)
 	storage := createTestStorage(workspace.ID)

+	defer func() {
+		workspaces_testing.RemoveTestWorkspace(workspace, router)
+	}()
+
 	var testUserToken string
 	if tt.isStorageOwner {
 		testUserToken = storageOwner.Token
@@ -372,10 +606,6 @@ func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
 	)
 	assert.Contains(t, string(testResp.Body), "error")
 }
-
-// Cleanup
-storages.RemoveTestStorage(storage.ID)
-workspaces_testing.RemoveTestWorkspace(workspace, router)
 })
 }
 }
@@ -387,11 +617,17 @@ func Test_SaveBackupConfig_WithEncryptionNone_ConfigSaved(t *testing.T) {

 	database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)

+	defer func() {
+		databases.RemoveTestDatabase(database)
+		workspaces_testing.RemoveTestWorkspace(workspace, router)
+	}()
+
 	timeOfDay := "04:00"
 	request := BackupConfig{
-		DatabaseID:       database.ID,
-		IsBackupsEnabled: true,
-		StorePeriod:      period.PeriodWeek,
+		DatabaseID:          database.ID,
+		IsBackupsEnabled:    true,
+		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+		RetentionTimePeriod: period.PeriodWeek,
 		BackupInterval: &intervals.Interval{
 			Interval:  intervals.IntervalDaily,
 			TimeOfDay: &timeOfDay,
@@ -426,11 +662,17 @@ func Test_SaveBackupConfig_WithEncryptionEncrypted_ConfigSaved(t *testing.T) {

 	database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)

+	defer func() {
+		databases.RemoveTestDatabase(database)
+		workspaces_testing.RemoveTestWorkspace(workspace, router)
+	}()
+
 	timeOfDay := "04:00"
 	request := BackupConfig{
-		DatabaseID:       database.ID,
-		IsBackupsEnabled: true,
-		StorePeriod:      period.PeriodWeek,
+		DatabaseID:          database.ID,
+		IsBackupsEnabled:    true,
+		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+		RetentionTimePeriod: period.PeriodWeek,
 		BackupInterval: &intervals.Interval{
 			Interval:  intervals.IntervalDaily,
 			TimeOfDay: &timeOfDay,
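Across all of these tests the old single StorePeriod field is consistently replaced by the pair RetentionPolicyType plus RetentionTimePeriod, which leaves room for retention strategies other than a fixed time window. Only the time-period variant appears in this excerpt; a sketch of the shape the type plausibly takes, where the second variant and both string values are assumptions:

	// Sketch of the retention type as used by the tests; only
	// RetentionPolicyTypeTimePeriod is confirmed by this diff.
	type RetentionPolicyType string

	const (
		RetentionPolicyTypeTimePeriod RetentionPolicyType = "TIME_PERIOD" // assumed value
		RetentionPolicyTypeCount      RetentionPolicyType = "COUNT"       // assumed variant
	)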
@@ -536,6 +778,15 @@ func Test_TransferDatabase_PermissionsEnforced(t *testing.T) {

 	targetStorage := createTestStorage(targetWorkspace.ID)

+	defer func() {
+		// Cleanup in correct order to avoid foreign key violations
+		databases.RemoveTestDatabase(database)
+		time.Sleep(50 * time.Millisecond) // Wait for cascade delete of backup_config
+		storages.RemoveTestStorage(targetStorage.ID)
+		workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
+		workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
+	}()
+
 	var testUserToken string
 	if tt.isGlobalAdmin {
 		admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
@@ -628,6 +879,12 @@ func Test_TransferDatabase_NonMemberInSourceWorkspace_CannotTransfer(t *testing.T) {
 		router,
 	)

+	defer func() {
+		databases.RemoveTestDatabase(database)
+		workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
+		workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
+	}()
+
 	request := TransferDatabaseRequest{
 		TargetWorkspaceID: targetWorkspace.ID,
 	}
@@ -668,6 +925,12 @@ func Test_TransferDatabase_NonMemberInTargetWorkspace_CannotTransfer(t *testing.T) {
 		router,
 	)

+	defer func() {
+		databases.RemoveTestDatabase(database)
+		workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
+		workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
+	}()
+
 	request := TransferDatabaseRequest{
 		TargetWorkspaceID: targetWorkspace.ID,
 	}
@@ -695,11 +958,19 @@ func Test_TransferDatabase_ToNewStorage_DatabaseTransferd(t *testing.T) {
 	sourceStorage := createTestStorage(sourceWorkspace.ID)
 	targetStorage := createTestStorage(targetWorkspace.ID)

+	defer func() {
+		databases.RemoveTestDatabase(database)
+		time.Sleep(200 * time.Millisecond) // Wait for cascading deletes
+		workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
+		workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
+	}()
+
 	timeOfDay := "04:00"
 	backupConfigRequest := BackupConfig{
-		DatabaseID:       database.ID,
-		IsBackupsEnabled: true,
-		StorePeriod:      period.PeriodWeek,
+		DatabaseID:          database.ID,
+		IsBackupsEnabled:    true,
+		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+		RetentionTimePeriod: period.PeriodWeek,
 		BackupInterval: &intervals.Interval{
 			Interval:  intervals.IntervalDaily,
 			TimeOfDay: &timeOfDay,
@@ -774,11 +1045,19 @@ func Test_TransferDatabase_WithExistingStorage_DatabaseAndStorageTransferd(t *testing.T) {
 	database := createTestDatabaseViaAPI("Test Database", sourceWorkspace.ID, owner.Token, router)
 	storage := createTestStorage(sourceWorkspace.ID)

+	defer func() {
+		databases.RemoveTestDatabase(database)
+		time.Sleep(200 * time.Millisecond) // Wait for cascading deletes
+		workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
+		workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
+	}()
+
 	timeOfDay := "04:00"
 	backupConfigRequest := BackupConfig{
-		DatabaseID:       database.ID,
-		IsBackupsEnabled: true,
-		StorePeriod:      period.PeriodWeek,
+		DatabaseID:          database.ID,
+		IsBackupsEnabled:    true,
+		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+		RetentionTimePeriod: period.PeriodWeek,
 		BackupInterval: &intervals.Interval{
 			Interval:  intervals.IntervalDaily,
 			TimeOfDay: &timeOfDay,
@@ -863,11 +1142,20 @@ func Test_TransferDatabase_StorageHasOtherDBs_CannotTransfer(t *testing.T) {
 	)
 	storage := createTestStorage(sourceWorkspace.ID)

+	defer func() {
+		databases.RemoveTestDatabase(database1)
+		databases.RemoveTestDatabase(database2)
+		time.Sleep(200 * time.Millisecond) // Wait for cascading deletes
+		workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
+		workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
+	}()
+
 	timeOfDay := "04:00"
 	backupConfigRequest1 := BackupConfig{
-		DatabaseID:       database1.ID,
-		IsBackupsEnabled: true,
-		StorePeriod:      period.PeriodWeek,
+		DatabaseID:          database1.ID,
+		IsBackupsEnabled:    true,
+		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+		RetentionTimePeriod: period.PeriodWeek,
 		BackupInterval: &intervals.Interval{
 			Interval:  intervals.IntervalDaily,
 			TimeOfDay: &timeOfDay,
@@ -891,9 +1179,10 @@ func Test_TransferDatabase_StorageHasOtherDBs_CannotTransfer(t *testing.T) {
 	)

 	backupConfigRequest2 := BackupConfig{
-		DatabaseID:       database2.ID,
-		IsBackupsEnabled: true,
-		StorePeriod:      period.PeriodWeek,
+		DatabaseID:          database2.ID,
+		IsBackupsEnabled:    true,
+		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+		RetentionTimePeriod: period.PeriodWeek,
 		BackupInterval: &intervals.Interval{
 			Interval:  intervals.IntervalDaily,
 			TimeOfDay: &timeOfDay,
@@ -945,6 +1234,14 @@ func Test_TransferDatabase_WithNotifiers_NotifiersTransferred(t *testing.T) {
 	targetStorage := createTestStorage(targetWorkspace.ID)
 	notifier := notifiers.CreateTestNotifier(sourceWorkspace.ID)

+	defer func() {
+		databases.RemoveTestDatabase(database)
+		time.Sleep(200 * time.Millisecond)
+		notifiers.RemoveTestNotifier(notifier)
+		workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
+		workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
+	}()
+
 	database.Notifiers = []notifiers.Notifier{*notifier}
 	var updatedDatabase databases.Database
 	test_utils.MakePostRequestAndUnmarshal(
@@ -959,9 +1256,10 @@ func Test_TransferDatabase_WithNotifiers_NotifiersTransferred(t *testing.T) {

 	timeOfDay := "04:00"
 	backupConfigRequest := BackupConfig{
-		DatabaseID:       database.ID,
-		IsBackupsEnabled: true,
-		StorePeriod:      period.PeriodWeek,
+		DatabaseID:          database.ID,
+		IsBackupsEnabled:    true,
+		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+		RetentionTimePeriod: period.PeriodWeek,
 		BackupInterval: &intervals.Interval{
 			Interval:  intervals.IntervalDaily,
 			TimeOfDay: &timeOfDay,
@@ -1048,6 +1346,15 @@ func Test_TransferDatabase_NotifierHasOtherDBs_NotifierSkipped(t *testing.T) {
 	targetStorage := createTestStorage(targetWorkspace.ID)
 	sharedNotifier := notifiers.CreateTestNotifier(sourceWorkspace.ID)

+	defer func() {
+		databases.RemoveTestDatabase(database1)
+		databases.RemoveTestDatabase(database2)
+		time.Sleep(200 * time.Millisecond)
+		notifiers.RemoveTestNotifier(sharedNotifier)
+		workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
+		workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
+	}()
+
 	database1.Notifiers = []notifiers.Notifier{*sharedNotifier}
 	test_utils.MakePostRequest(
 		t,
@@ -1070,9 +1377,10 @@ func Test_TransferDatabase_NotifierHasOtherDBs_NotifierSkipped(t *testing.T) {

 	timeOfDay := "04:00"
 	backupConfigRequest := BackupConfig{
-		DatabaseID:       database1.ID,
-		IsBackupsEnabled: true,
-		StorePeriod:      period.PeriodWeek,
+		DatabaseID:          database1.ID,
+		IsBackupsEnabled:    true,
+		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+		RetentionTimePeriod: period.PeriodWeek,
 		BackupInterval: &intervals.Interval{
 			Interval:  intervals.IntervalDaily,
 			TimeOfDay: &timeOfDay,
@@ -1160,6 +1468,16 @@ func Test_TransferDatabase_WithMultipleNotifiers_OnlyExclusiveOnesTransferred(t *testing.T) {
 	exclusiveNotifier := notifiers.CreateTestNotifier(sourceWorkspace.ID)
 	sharedNotifier := notifiers.CreateTestNotifier(sourceWorkspace.ID)

+	defer func() {
+		databases.RemoveTestDatabase(database1)
+		databases.RemoveTestDatabase(database2)
+		time.Sleep(200 * time.Millisecond)
+		notifiers.RemoveTestNotifier(exclusiveNotifier)
+		notifiers.RemoveTestNotifier(sharedNotifier)
+		workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
+		workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
+	}()
+
 	database1.Notifiers = []notifiers.Notifier{*exclusiveNotifier, *sharedNotifier}
 	test_utils.MakePostRequest(
 		t,
@@ -1182,9 +1500,10 @@ func Test_TransferDatabase_WithMultipleNotifiers_OnlyExclusiveOnesTransferred(t *testing.T) {

 	timeOfDay := "04:00"
 	backupConfigRequest := BackupConfig{
-		DatabaseID:       database1.ID,
-		IsBackupsEnabled: true,
-		StorePeriod:      period.PeriodWeek,
+		DatabaseID:          database1.ID,
+		IsBackupsEnabled:    true,
+		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+		RetentionTimePeriod: period.PeriodWeek,
 		BackupInterval: &intervals.Interval{
 			Interval:  intervals.IntervalDaily,
 			TimeOfDay: &timeOfDay,
@@ -1271,11 +1590,20 @@ func Test_TransferDatabase_WithTargetNotifiers_NotifiersAssigned(t *testing.T) {
 	targetStorage := createTestStorage(targetWorkspace.ID)
 	targetNotifier := notifiers.CreateTestNotifier(targetWorkspace.ID)

+	defer func() {
+		databases.RemoveTestDatabase(database)
+		time.Sleep(200 * time.Millisecond)
+		notifiers.RemoveTestNotifier(targetNotifier)
+		workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
+		workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
+	}()
+
 	timeOfDay := "04:00"
 	backupConfigRequest := BackupConfig{
-		DatabaseID:       database.ID,
-		IsBackupsEnabled: true,
-		StorePeriod:      period.PeriodWeek,
+		DatabaseID:          database.ID,
+		IsBackupsEnabled:    true,
+		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+		RetentionTimePeriod: period.PeriodWeek,
 		BackupInterval: &intervals.Interval{
 			Interval:  intervals.IntervalDaily,
 			TimeOfDay: &timeOfDay,
@@ -1342,11 +1670,21 @@ func Test_TransferDatabase_TargetNotifierFromDifferentWorkspace_ReturnsBadRequest(t *testing.T) {
 	targetStorage := createTestStorage(targetWorkspace.ID)
 	wrongNotifier := notifiers.CreateTestNotifier(otherWorkspace.ID)

+	defer func() {
+		databases.RemoveTestDatabase(database)
+		time.Sleep(200 * time.Millisecond)
+		notifiers.RemoveTestNotifier(wrongNotifier)
+		workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
+		workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
+		workspaces_testing.RemoveTestWorkspace(otherWorkspace, router)
+	}()
+
 	timeOfDay := "04:00"
 	backupConfigRequest := BackupConfig{
-		DatabaseID:       database.ID,
-		IsBackupsEnabled: true,
-		StorePeriod:      period.PeriodWeek,
+		DatabaseID:          database.ID,
+		IsBackupsEnabled:    true,
+		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+		RetentionTimePeriod: period.PeriodWeek,
 		BackupInterval: &intervals.Interval{
 			Interval:  intervals.IntervalDaily,
 			TimeOfDay: &timeOfDay,
@@ -1399,11 +1737,20 @@ func Test_TransferDatabase_TargetStorageFromDifferentWorkspace_ReturnsBadRequest(t *testing.T) {
 	sourceStorage := createTestStorage(sourceWorkspace.ID)
 	wrongStorage := createTestStorage(otherWorkspace.ID)

+	defer func() {
+		databases.RemoveTestDatabase(database)
+		time.Sleep(200 * time.Millisecond)
+		workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
+		workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
+		workspaces_testing.RemoveTestWorkspace(otherWorkspace, router)
+	}()
+
 	timeOfDay := "04:00"
 	backupConfigRequest := BackupConfig{
-		DatabaseID:       database.ID,
-		IsBackupsEnabled: true,
-		StorePeriod:      period.PeriodWeek,
+		DatabaseID:          database.ID,
+		IsBackupsEnabled:    true,
+		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+		RetentionTimePeriod: period.PeriodWeek,
 		BackupInterval: &intervals.Interval{
 			Interval:  intervals.IntervalDaily,
 			TimeOfDay: &timeOfDay,
@@ -1443,6 +1790,117 @@ func Test_TransferDatabase_TargetStorageFromDifferentWorkspace_ReturnsBadRequest(t *testing.T) {
 	assert.Contains(t, string(testResp.Body), "target storage does not belong to target workspace")
 }

+func Test_SaveBackupConfig_WithSystemStorage_CanBeUsedByAnyDatabase(t *testing.T) {
+	router := createTestRouterWithStorageForTransfer()
+
+	owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
+	owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
+	admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
+
+	workspaceA := workspaces_testing.CreateTestWorkspace("Workspace A", owner1, router)
+	workspaceB := workspaces_testing.CreateTestWorkspace("Workspace B", owner2, router)
+
+	databaseA := createTestDatabaseViaAPI("Database A", workspaceA.ID, owner1.Token, router)
+
+	// Test 1: Regular storage from workspace B cannot be used by database in workspace A
+	regularStorageB := createTestStorage(workspaceB.ID)
+
+	timeOfDay := "04:00"
+	backupConfigWithRegularStorage := BackupConfig{
+		DatabaseID:          databaseA.ID,
+		IsBackupsEnabled:    true,
+		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+		RetentionTimePeriod: period.PeriodWeek,
+		BackupInterval: &intervals.Interval{
+			Interval:  intervals.IntervalDaily,
+			TimeOfDay: &timeOfDay,
+		},
+		StorageID: &regularStorageB.ID,
+		Storage:   regularStorageB,
+		SendNotificationsOn: []BackupNotificationType{
+			NotificationBackupFailed,
+		},
+		IsRetryIfFailed:     true,
+		MaxFailedTriesCount: 3,
+		Encryption:          BackupEncryptionNone,
+	}
+
+	respRegular := test_utils.MakePostRequest(
+		t,
+		router,
+		"/api/v1/backup-configs/save",
+		"Bearer "+owner1.Token,
+		backupConfigWithRegularStorage,
+		http.StatusBadRequest,
+	)
+
+	assert.Contains(t, string(respRegular.Body), "storage does not belong to the same workspace")
+
+	// Test 2: System storage from workspace B CAN be used by database in workspace A
+	systemStorageB := &storages.Storage{
+		WorkspaceID:  workspaceB.ID,
+		Type:         storages.StorageTypeLocal,
+		Name:         "Test System Storage " + uuid.New().String(),
+		IsSystem:     true,
+		LocalStorage: &local_storage.LocalStorage{},
+	}
+
+	var savedSystemStorage storages.Storage
+	test_utils.MakePostRequestAndUnmarshal(
+		t,
+		router,
+		"/api/v1/storages",
+		"Bearer "+admin.Token,
+		*systemStorageB,
+		http.StatusOK,
+		&savedSystemStorage,
+	)
+
+	assert.True(t, savedSystemStorage.IsSystem)
+
+	backupConfigWithSystemStorage := BackupConfig{
+		DatabaseID:          databaseA.ID,
+		IsBackupsEnabled:    true,
+		RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+		RetentionTimePeriod: period.PeriodWeek,
+		BackupInterval: &intervals.Interval{
+			Interval:  intervals.IntervalDaily,
+			TimeOfDay: &timeOfDay,
+		},
+		StorageID: &savedSystemStorage.ID,
+		Storage:   &savedSystemStorage,
+		SendNotificationsOn: []BackupNotificationType{
+			NotificationBackupFailed,
+		},
+		IsRetryIfFailed:     true,
+		MaxFailedTriesCount: 3,
+		Encryption:          BackupEncryptionNone,
+	}
+
+	var savedConfig BackupConfig
+	test_utils.MakePostRequestAndUnmarshal(
+		t,
+		router,
+		"/api/v1/backup-configs/save",
+		"Bearer "+owner1.Token,
+		backupConfigWithSystemStorage,
+		http.StatusOK,
+		&savedConfig,
+	)
+
+	assert.Equal(t, databaseA.ID, savedConfig.DatabaseID)
+	assert.NotNil(t, savedConfig.StorageID)
+	assert.Equal(t, savedSystemStorage.ID, *savedConfig.StorageID)
+	assert.True(t, savedConfig.IsBackupsEnabled)
+
+	// Cleanup: database first (cascades to backup_config), then storages, then workspaces
+	databases.RemoveTestDatabase(databaseA)
+	storages.RemoveTestStorage(regularStorageB.ID)
+	storages.RemoveTestStorage(savedSystemStorage.ID)
+	workspaces_testing.RemoveTestWorkspace(workspaceA, router)
+	workspaces_testing.RemoveTestWorkspace(workspaceB, router)
+}
+
func createTestDatabaseViaAPI(
    name string,
    workspaceID uuid.UUID,
@@ -1462,7 +1920,7 @@ func createTestDatabaseViaAPI(
        Type: databases.DatabaseTypePostgres,
        Postgresql: &postgresql.PostgresqlDatabase{
            Version:  tools.PostgresqlVersion16,
-           Host:     "localhost",
+           Host:     config.GetEnv().TestLocalhost,
            Port:     port,
            Username: "testuser",
            Password: "testpassword",

@@ -1,10 +1,15 @@
package backups_config

import (
    "sync"
    "sync/atomic"

    "databasus-backend/internal/features/databases"
    "databasus-backend/internal/features/notifiers"
    plans "databasus-backend/internal/features/plan"
    "databasus-backend/internal/features/storages"
    workspaces_services "databasus-backend/internal/features/workspaces/services"
    "databasus-backend/internal/util/logger"
)

var backupConfigRepository = &BackupConfigRepository{}
@@ -14,6 +19,7 @@ var backupConfigService = &BackupConfigService{
    storages.GetStorageService(),
    notifiers.GetNotifierService(),
    workspaces_services.GetWorkspaceService(),
    plans.GetDatabasePlanService(),
    nil,
}
var backupConfigController = &BackupConfigController{
@@ -28,6 +34,21 @@ func GetBackupConfigService() *BackupConfigService {
    return backupConfigService
}

var (
    setupOnce sync.Once
    isSetup   atomic.Bool
)

func SetupDependencies() {
-   storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService)
+   wasAlreadySetup := isSetup.Load()
+
+   setupOnce.Do(func() {
+       storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService)
+
+       isSetup.Store(true)
+   })
+
+   if wasAlreadySetup {
+       logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
+   }
}
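
Reviewer note: the sync.Once/atomic.Bool pair above makes the dependency wiring idempotent — the first caller performs the setup, later callers only log a warning. A minimal standalone sketch of the same pattern (Setup and wire are illustrative names, not part of this codebase):

package main

import (
    "log"
    "sync"
    "sync/atomic"
)

var (
    setupOnce sync.Once
    isSetup   atomic.Bool
)

// Setup is safe to call from any goroutine; wire runs at most once.
func Setup(wire func()) {
    wasAlreadySetup := isSetup.Load()

    setupOnce.Do(func() {
        wire()
        isSetup.Store(true)
    })

    if wasAlreadySetup {
        log.Println("Setup called multiple times, ignoring subsequent call")
    }
}

One subtlety of this design: a caller that races the very first call blocks inside Do until wiring finishes and does not warn, because it loaded the flag before the store; only calls that start after setup has completed hit the warning path.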

@@ -13,3 +13,11 @@ const (
    BackupEncryptionNone      BackupEncryption = "NONE"
    BackupEncryptionEncrypted BackupEncryption = "ENCRYPTED"
)

type RetentionPolicyType string

const (
    RetentionPolicyTypeTimePeriod RetentionPolicyType = "TIME_PERIOD"
    RetentionPolicyTypeCount      RetentionPolicyType = "COUNT"
    RetentionPolicyTypeGFS        RetentionPolicyType = "GFS"
)
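
Reviewer note: the three policy values map to keeping backups for a rolling time window (TIME_PERIOD), keeping the N most recent backups (COUNT), or keeping a Grandfather-Father-Son rotation (GFS). A hedged example of a GFS configuration using the fields added to BackupConfig in this change — the exact pruning semantics live in the retention code, which is not part of this diff:

cfg := BackupConfig{
    RetentionPolicyType: RetentionPolicyTypeGFS,
    RetentionGfsHours:   24, // keep the last 24 hourly backups
    RetentionGfsDays:    7,  // keep 7 daily backups
    RetentionGfsWeeks:   4,  // keep 4 weekly backups
    RetentionGfsMonths:  12, // keep 12 monthly backups
    RetentionGfsYears:   3,  // keep 3 yearly backups
}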

@@ -1,7 +1,9 @@
package backups_config

import (
    "databasus-backend/internal/config"
    "databasus-backend/internal/features/intervals"
    plans "databasus-backend/internal/features/plan"
    "databasus-backend/internal/features/storages"
    "databasus-backend/internal/util/period"
    "errors"
@@ -16,7 +18,15 @@ type BackupConfig struct {

    IsBackupsEnabled bool `json:"isBackupsEnabled" gorm:"column:is_backups_enabled;type:boolean;not null"`

-   StorePeriod period.Period `json:"storePeriod" gorm:"column:store_period;type:text;not null"`
+   RetentionPolicyType RetentionPolicyType `json:"retentionPolicyType" gorm:"column:retention_policy_type;type:text;not null;default:'TIME_PERIOD'"`
+   RetentionTimePeriod period.TimePeriod   `json:"retentionTimePeriod" gorm:"column:retention_time_period;type:text;not null;default:''"`
+
+   RetentionCount     int `json:"retentionCount" gorm:"column:retention_count;type:int;not null;default:0"`
+   RetentionGfsHours  int `json:"retentionGfsHours" gorm:"column:retention_gfs_hours;type:int;not null;default:0"`
+   RetentionGfsDays   int `json:"retentionGfsDays" gorm:"column:retention_gfs_days;type:int;not null;default:0"`
+   RetentionGfsWeeks  int `json:"retentionGfsWeeks" gorm:"column:retention_gfs_weeks;type:int;not null;default:0"`
+   RetentionGfsMonths int `json:"retentionGfsMonths" gorm:"column:retention_gfs_months;type:int;not null;default:0"`
+   RetentionGfsYears  int `json:"retentionGfsYears" gorm:"column:retention_gfs_years;type:int;not null;default:0"`

    BackupIntervalID uuid.UUID           `json:"backupIntervalId" gorm:"column:backup_interval_id;type:uuid;not null"`
    BackupInterval   *intervals.Interval `json:"backupInterval,omitempty" gorm:"foreignKey:BackupIntervalID"`
@@ -31,6 +41,11 @@ type BackupConfig struct {
    MaxFailedTriesCount int `json:"maxFailedTriesCount" gorm:"column:max_failed_tries_count;type:int;not null"`

    Encryption BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`

+   // MaxBackupSizeMB limits individual backup size. 0 = unlimited.
+   MaxBackupSizeMB int64 `json:"maxBackupSizeMb" gorm:"column:max_backup_size_mb;type:int;not null"`
+   // MaxBackupsTotalSizeMB limits total size of all backups. 0 = unlimited.
+   MaxBackupsTotalSizeMB int64 `json:"maxBackupsTotalSizeMb" gorm:"column:max_backups_total_size_mb;type:int;not null"`
}

func (h *BackupConfig) TableName() string {
@@ -70,14 +85,13 @@ func (b *BackupConfig) AfterFind(tx *gorm.DB) error {
    return nil
}

-func (b *BackupConfig) Validate() error {
+// Backup interval is required either as ID or as object
+func (b *BackupConfig) Validate(plan *plans.DatabasePlan) error {
    if b.BackupIntervalID == uuid.Nil && b.BackupInterval == nil {
        return errors.New("backup interval is required")
    }

-   if b.StorePeriod == "" {
-       return errors.New("store period is required")
+   if err := b.validateRetentionPolicy(plan); err != nil {
+       return err
    }

    if b.IsRetryIfFailed && b.MaxFailedTriesCount <= 0 {
@@ -89,20 +103,87 @@ func (b *BackupConfig) Validate() error {
        return errors.New("encryption must be NONE or ENCRYPTED")
    }

    if config.GetEnv().IsCloud {
        if b.Encryption != BackupEncryptionEncrypted {
            return errors.New("encryption is mandatory for cloud storage")
        }
    }

    if b.MaxBackupSizeMB < 0 {
        return errors.New("max backup size must be non-negative")
    }

    if b.MaxBackupsTotalSizeMB < 0 {
        return errors.New("max backups total size must be non-negative")
    }

    if plan.MaxBackupSizeMB > 0 {
        if b.MaxBackupSizeMB == 0 || b.MaxBackupSizeMB > plan.MaxBackupSizeMB {
            return errors.New("max backup size exceeds plan limit")
        }
    }

    if plan.MaxBackupsTotalSizeMB > 0 {
        if b.MaxBackupsTotalSizeMB == 0 ||
            b.MaxBackupsTotalSizeMB > plan.MaxBackupsTotalSizeMB {
            return errors.New("max total backups size exceeds plan limit")
        }
    }

    return nil
}

func (b *BackupConfig) Copy(newDatabaseID uuid.UUID) *BackupConfig {
    return &BackupConfig{
-       DatabaseID:          newDatabaseID,
-       IsBackupsEnabled:    b.IsBackupsEnabled,
-       StorePeriod:         b.StorePeriod,
-       BackupIntervalID:    uuid.Nil,
-       BackupInterval:      b.BackupInterval.Copy(),
-       StorageID:           b.StorageID,
-       SendNotificationsOn: b.SendNotificationsOn,
-       IsRetryIfFailed:     b.IsRetryIfFailed,
-       MaxFailedTriesCount: b.MaxFailedTriesCount,
-       Encryption:          b.Encryption,
+       DatabaseID:            newDatabaseID,
+       IsBackupsEnabled:      b.IsBackupsEnabled,
+       RetentionPolicyType:   b.RetentionPolicyType,
+       RetentionTimePeriod:   b.RetentionTimePeriod,
+       RetentionCount:        b.RetentionCount,
+       RetentionGfsHours:     b.RetentionGfsHours,
+       RetentionGfsDays:      b.RetentionGfsDays,
+       RetentionGfsWeeks:     b.RetentionGfsWeeks,
+       RetentionGfsMonths:    b.RetentionGfsMonths,
+       RetentionGfsYears:     b.RetentionGfsYears,
+       BackupIntervalID:      uuid.Nil,
+       BackupInterval:        b.BackupInterval.Copy(),
+       StorageID:             b.StorageID,
+       SendNotificationsOn:   b.SendNotificationsOn,
+       IsRetryIfFailed:       b.IsRetryIfFailed,
+       MaxFailedTriesCount:   b.MaxFailedTriesCount,
+       Encryption:            b.Encryption,
+       MaxBackupSizeMB:       b.MaxBackupSizeMB,
+       MaxBackupsTotalSizeMB: b.MaxBackupsTotalSizeMB,
    }
}

func (b *BackupConfig) validateRetentionPolicy(plan *plans.DatabasePlan) error {
    switch b.RetentionPolicyType {
    case RetentionPolicyTypeTimePeriod, "":
        if b.RetentionTimePeriod == "" {
            return errors.New("retention time period is required")
        }

        if plan.MaxStoragePeriod != period.PeriodForever {
            if b.RetentionTimePeriod.CompareTo(plan.MaxStoragePeriod) > 0 {
                return errors.New("storage period exceeds plan limit")
            }
        }

    case RetentionPolicyTypeCount:
        if b.RetentionCount <= 0 {
            return errors.New("retention count must be greater than 0")
        }

    case RetentionPolicyTypeGFS:
        if b.RetentionGfsHours <= 0 && b.RetentionGfsDays <= 0 && b.RetentionGfsWeeks <= 0 &&
            b.RetentionGfsMonths <= 0 && b.RetentionGfsYears <= 0 {
            return errors.New("at least one GFS retention field must be greater than 0")
        }

    default:
        return errors.New("invalid retention policy type")
    }

    return nil
}
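
Reviewer note: in the size checks above, 0 means "unlimited" on both the config and the plan, which makes the comparison easy to misread: a config left at 0 violates a plan that sets a positive cap. A small sketch of the predicate exactly as written in this diff:

// exceedsPlanLimit mirrors the checks in Validate:
//   plan 0 (unlimited) -> any config value passes, including 0;
//   plan 500           -> config must be 1..500; 0 (unlimited) and 501 both fail.
func exceedsPlanLimit(configMB, planMB int64) bool {
    return planMB > 0 && (configMB == 0 || configMB > planMB)
}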

477  backend/internal/features/backups/config/model_test.go  Normal file
@@ -0,0 +1,477 @@
package backups_config

import (
    "testing"

    "databasus-backend/internal/features/intervals"
    plans "databasus-backend/internal/features/plan"
    "databasus-backend/internal/util/period"

    "github.com/google/uuid"
    "github.com/stretchr/testify/assert"
)

func Test_Validate_WhenRetentionTimePeriodIsWeekAndPlanAllowsMonth_ValidationPasses(t *testing.T) {
    config := createValidBackupConfig()
    config.RetentionTimePeriod = period.PeriodWeek

    plan := createUnlimitedPlan()
    plan.MaxStoragePeriod = period.PeriodMonth

    err := config.Validate(plan)
    assert.NoError(t, err)
}

func Test_Validate_WhenRetentionTimePeriodIsYearAndPlanAllowsMonth_ValidationFails(t *testing.T) {
    config := createValidBackupConfig()
    config.RetentionTimePeriod = period.PeriodYear

    plan := createUnlimitedPlan()
    plan.MaxStoragePeriod = period.PeriodMonth

    err := config.Validate(plan)
    assert.EqualError(t, err, "storage period exceeds plan limit")
}

func Test_Validate_WhenRetentionTimePeriodIsForeverAndPlanAllowsForever_ValidationPasses(
    t *testing.T,
) {
    config := createValidBackupConfig()
    config.RetentionTimePeriod = period.PeriodForever

    plan := createUnlimitedPlan()
    plan.MaxStoragePeriod = period.PeriodForever

    err := config.Validate(plan)
    assert.NoError(t, err)
}

func Test_Validate_WhenRetentionTimePeriodIsForeverAndPlanAllowsYear_ValidationFails(t *testing.T) {
    config := createValidBackupConfig()
    config.RetentionTimePeriod = period.PeriodForever

    plan := createUnlimitedPlan()
    plan.MaxStoragePeriod = period.PeriodYear

    err := config.Validate(plan)
    assert.EqualError(t, err, "storage period exceeds plan limit")
}

func Test_Validate_WhenRetentionTimePeriodEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
    config := createValidBackupConfig()
    config.RetentionTimePeriod = period.PeriodMonth

    plan := createUnlimitedPlan()
    plan.MaxStoragePeriod = period.PeriodMonth

    err := config.Validate(plan)
    assert.NoError(t, err)
}

func Test_Validate_WhenBackupSize100MBAndPlanAllows500MB_ValidationPasses(t *testing.T) {
    config := createValidBackupConfig()
    config.MaxBackupSizeMB = 100

    plan := createUnlimitedPlan()
    plan.MaxBackupSizeMB = 500

    err := config.Validate(plan)
    assert.NoError(t, err)
}

func Test_Validate_WhenBackupSize500MBAndPlanAllows100MB_ValidationFails(t *testing.T) {
    config := createValidBackupConfig()
    config.MaxBackupSizeMB = 500

    plan := createUnlimitedPlan()
    plan.MaxBackupSizeMB = 100

    err := config.Validate(plan)
    assert.EqualError(t, err, "max backup size exceeds plan limit")
}

func Test_Validate_WhenBackupSizeIsUnlimitedAndPlanAllowsUnlimited_ValidationPasses(t *testing.T) {
    config := createValidBackupConfig()
    config.MaxBackupSizeMB = 0

    plan := createUnlimitedPlan()
    plan.MaxBackupSizeMB = 0

    err := config.Validate(plan)
    assert.NoError(t, err)
}

func Test_Validate_WhenBackupSizeIsUnlimitedAndPlanHas500MBLimit_ValidationFails(t *testing.T) {
    config := createValidBackupConfig()
    config.MaxBackupSizeMB = 0

    plan := createUnlimitedPlan()
    plan.MaxBackupSizeMB = 500

    err := config.Validate(plan)
    assert.EqualError(t, err, "max backup size exceeds plan limit")
}

func Test_Validate_WhenBackupSizeEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
    config := createValidBackupConfig()
    config.MaxBackupSizeMB = 500

    plan := createUnlimitedPlan()
    plan.MaxBackupSizeMB = 500

    err := config.Validate(plan)
    assert.NoError(t, err)
}

func Test_Validate_WhenTotalSize1GBAndPlanAllows5GB_ValidationPasses(t *testing.T) {
    config := createValidBackupConfig()
    config.MaxBackupsTotalSizeMB = 1000

    plan := createUnlimitedPlan()
    plan.MaxBackupsTotalSizeMB = 5000

    err := config.Validate(plan)
    assert.NoError(t, err)
}

func Test_Validate_WhenTotalSize5GBAndPlanAllows1GB_ValidationFails(t *testing.T) {
    config := createValidBackupConfig()
    config.MaxBackupsTotalSizeMB = 5000

    plan := createUnlimitedPlan()
    plan.MaxBackupsTotalSizeMB = 1000

    err := config.Validate(plan)
    assert.EqualError(t, err, "max total backups size exceeds plan limit")
}

func Test_Validate_WhenTotalSizeIsUnlimitedAndPlanAllowsUnlimited_ValidationPasses(t *testing.T) {
    config := createValidBackupConfig()
    config.MaxBackupsTotalSizeMB = 0

    plan := createUnlimitedPlan()
    plan.MaxBackupsTotalSizeMB = 0

    err := config.Validate(plan)
    assert.NoError(t, err)
}

func Test_Validate_WhenTotalSizeIsUnlimitedAndPlanHas1GBLimit_ValidationFails(t *testing.T) {
    config := createValidBackupConfig()
    config.MaxBackupsTotalSizeMB = 0

    plan := createUnlimitedPlan()
    plan.MaxBackupsTotalSizeMB = 1000

    err := config.Validate(plan)
    assert.EqualError(t, err, "max total backups size exceeds plan limit")
}

func Test_Validate_WhenTotalSizeEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
    config := createValidBackupConfig()
    config.MaxBackupsTotalSizeMB = 5000

    plan := createUnlimitedPlan()
    plan.MaxBackupsTotalSizeMB = 5000

    err := config.Validate(plan)
    assert.NoError(t, err)
}

func Test_Validate_WhenAllLimitsAreUnlimitedInPlan_AnyConfigurationPasses(t *testing.T) {
    config := createValidBackupConfig()
    config.RetentionTimePeriod = period.PeriodForever
    config.MaxBackupSizeMB = 0
    config.MaxBackupsTotalSizeMB = 0

    plan := createUnlimitedPlan()

    err := config.Validate(plan)
    assert.NoError(t, err)
}

func Test_Validate_WhenMultipleLimitsExceeded_ValidationFailsWithFirstError(t *testing.T) {
    config := createValidBackupConfig()
    config.RetentionTimePeriod = period.PeriodYear
    config.MaxBackupSizeMB = 500
    config.MaxBackupsTotalSizeMB = 5000

    plan := createUnlimitedPlan()
    plan.MaxStoragePeriod = period.PeriodMonth
    plan.MaxBackupSizeMB = 100
    plan.MaxBackupsTotalSizeMB = 1000

    err := config.Validate(plan)
    assert.Error(t, err)
    assert.EqualError(t, err, "storage period exceeds plan limit")
}

func Test_Validate_WhenConfigHasInvalidIntervalButPlanIsValid_ValidationFailsOnInterval(
    t *testing.T,
) {
    config := createValidBackupConfig()
    config.BackupIntervalID = uuid.Nil
    config.BackupInterval = nil

    plan := createUnlimitedPlan()

    err := config.Validate(plan)
    assert.EqualError(t, err, "backup interval is required")
}

func Test_Validate_WhenIntervalIsMissing_ValidationFailsRegardlessOfPlan(t *testing.T) {
    config := createValidBackupConfig()
    config.BackupIntervalID = uuid.Nil
    config.BackupInterval = nil

    plan := createUnlimitedPlan()

    err := config.Validate(plan)
    assert.EqualError(t, err, "backup interval is required")
}

func Test_Validate_WhenRetryEnabledButMaxTriesIsZero_ValidationFailsRegardlessOfPlan(t *testing.T) {
    config := createValidBackupConfig()
    config.IsRetryIfFailed = true
    config.MaxFailedTriesCount = 0

    plan := createUnlimitedPlan()

    err := config.Validate(plan)
    assert.EqualError(t, err, "max failed tries count must be greater than 0")
}

func Test_Validate_WhenEncryptionIsInvalid_ValidationFailsRegardlessOfPlan(t *testing.T) {
    config := createValidBackupConfig()
    config.Encryption = "INVALID"

    plan := createUnlimitedPlan()

    err := config.Validate(plan)
    assert.EqualError(t, err, "encryption must be NONE or ENCRYPTED")
}

func Test_Validate_WhenRetentionTimePeriodIsEmpty_ValidationFails(t *testing.T) {
    config := createValidBackupConfig()
    config.RetentionTimePeriod = ""

    plan := createUnlimitedPlan()

    err := config.Validate(plan)
    assert.EqualError(t, err, "retention time period is required")
}

func Test_Validate_WhenMaxBackupSizeIsNegative_ValidationFails(t *testing.T) {
    config := createValidBackupConfig()
    config.MaxBackupSizeMB = -100

    plan := createUnlimitedPlan()

    err := config.Validate(plan)
    assert.EqualError(t, err, "max backup size must be non-negative")
}

func Test_Validate_WhenMaxTotalSizeIsNegative_ValidationFails(t *testing.T) {
    config := createValidBackupConfig()
    config.MaxBackupsTotalSizeMB = -1000

    plan := createUnlimitedPlan()

    err := config.Validate(plan)
    assert.EqualError(t, err, "max backups total size must be non-negative")
}

func Test_Validate_WhenPlanLimitsAreAtBoundary_ValidationWorks(t *testing.T) {
    tests := []struct {
        name          string
        configPeriod  period.TimePeriod
        planPeriod    period.TimePeriod
        configSize    int64
        planSize      int64
        configTotal   int64
        planTotal     int64
        shouldSucceed bool
    }{
        {
            name:          "all values just under limit",
            configPeriod:  period.PeriodWeek,
            planPeriod:    period.PeriodMonth,
            configSize:    99,
            planSize:      100,
            configTotal:   999,
            planTotal:     1000,
            shouldSucceed: true,
        },
        {
            name:          "all values equal to limit",
            configPeriod:  period.PeriodMonth,
            planPeriod:    period.PeriodMonth,
            configSize:    100,
            planSize:      100,
            configTotal:   1000,
            planTotal:     1000,
            shouldSucceed: true,
        },
        {
            name:          "period just over limit",
            configPeriod:  period.Period3Month,
            planPeriod:    period.PeriodMonth,
            configSize:    100,
            planSize:      100,
            configTotal:   1000,
            planTotal:     1000,
            shouldSucceed: false,
        },
        {
            name:          "size just over limit",
            configPeriod:  period.PeriodMonth,
            planPeriod:    period.PeriodMonth,
            configSize:    101,
            planSize:      100,
            configTotal:   1000,
            planTotal:     1000,
            shouldSucceed: false,
        },
        {
            name:          "total size just over limit",
            configPeriod:  period.PeriodMonth,
            planPeriod:    period.PeriodMonth,
            configSize:    100,
            planSize:      100,
            configTotal:   1001,
            planTotal:     1000,
            shouldSucceed: false,
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            config := createValidBackupConfig()
            config.RetentionTimePeriod = tt.configPeriod
            config.MaxBackupSizeMB = tt.configSize
            config.MaxBackupsTotalSizeMB = tt.configTotal

            plan := createUnlimitedPlan()
            plan.MaxStoragePeriod = tt.planPeriod
            plan.MaxBackupSizeMB = tt.planSize
            plan.MaxBackupsTotalSizeMB = tt.planTotal

            err := config.Validate(plan)
            if tt.shouldSucceed {
                assert.NoError(t, err)
            } else {
                assert.Error(t, err)
            }
        })
    }
}

func Test_Validate_WhenPolicyTypeIsCount_RequiresPositiveCount(t *testing.T) {
    config := createValidBackupConfig()
    config.RetentionPolicyType = RetentionPolicyTypeCount
    config.RetentionCount = 0

    plan := createUnlimitedPlan()

    err := config.Validate(plan)
    assert.EqualError(t, err, "retention count must be greater than 0")
}

func Test_Validate_WhenPolicyTypeIsCount_WithPositiveCount_ValidationPasses(t *testing.T) {
    config := createValidBackupConfig()
    config.RetentionPolicyType = RetentionPolicyTypeCount
    config.RetentionCount = 10

    plan := createUnlimitedPlan()

    err := config.Validate(plan)
    assert.NoError(t, err)
}

func Test_Validate_WhenPolicyTypeIsGFS_RequiresAtLeastOneField(t *testing.T) {
    config := createValidBackupConfig()
    config.RetentionPolicyType = RetentionPolicyTypeGFS
    config.RetentionGfsDays = 0
    config.RetentionGfsWeeks = 0
    config.RetentionGfsMonths = 0
    config.RetentionGfsYears = 0

    plan := createUnlimitedPlan()

    err := config.Validate(plan)
    assert.EqualError(t, err, "at least one GFS retention field must be greater than 0")
}

func Test_Validate_WhenPolicyTypeIsGFS_WithOnlyHours_ValidationPasses(t *testing.T) {
    config := createValidBackupConfig()
    config.RetentionPolicyType = RetentionPolicyTypeGFS
    config.RetentionGfsHours = 24

    plan := createUnlimitedPlan()

    err := config.Validate(plan)
    assert.NoError(t, err)
}

func Test_Validate_WhenPolicyTypeIsGFS_WithOnlyDays_ValidationPasses(t *testing.T) {
    config := createValidBackupConfig()
    config.RetentionPolicyType = RetentionPolicyTypeGFS
    config.RetentionGfsDays = 7

    plan := createUnlimitedPlan()

    err := config.Validate(plan)
    assert.NoError(t, err)
}

func Test_Validate_WhenPolicyTypeIsGFS_WithAllFields_ValidationPasses(t *testing.T) {
    config := createValidBackupConfig()
    config.RetentionPolicyType = RetentionPolicyTypeGFS
    config.RetentionGfsHours = 24
    config.RetentionGfsDays = 7
    config.RetentionGfsWeeks = 4
    config.RetentionGfsMonths = 12
    config.RetentionGfsYears = 3

    plan := createUnlimitedPlan()

    err := config.Validate(plan)
    assert.NoError(t, err)
}

func Test_Validate_WhenPolicyTypeIsInvalid_ValidationFails(t *testing.T) {
    config := createValidBackupConfig()
    config.RetentionPolicyType = "INVALID"

    plan := createUnlimitedPlan()

    err := config.Validate(plan)
    assert.EqualError(t, err, "invalid retention policy type")
}

func createValidBackupConfig() *BackupConfig {
    intervalID := uuid.New()
    return &BackupConfig{
        DatabaseID:            uuid.New(),
        IsBackupsEnabled:      true,
        RetentionPolicyType:   RetentionPolicyTypeTimePeriod,
        RetentionTimePeriod:   period.PeriodMonth,
        BackupIntervalID:      intervalID,
        BackupInterval:        &intervals.Interval{ID: intervalID},
        SendNotificationsOn:   []BackupNotificationType{},
        IsRetryIfFailed:       false,
        MaxFailedTriesCount:   3,
        Encryption:            BackupEncryptionNone,
        MaxBackupSizeMB:       100,
        MaxBackupsTotalSizeMB: 1000,
    }
}

func createUnlimitedPlan() *plans.DatabasePlan {
    return &plans.DatabasePlan{
        DatabaseID:            uuid.New(),
        MaxBackupSizeMB:       0,
        MaxBackupsTotalSizeMB: 0,
        MaxStoragePeriod:      period.PeriodForever,
    }
}
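
Reviewer note: model_test.go exercises Validate purely in memory (no database, no router), so — assuming the standard layout implied by the file header above — it should run from the backend module root with the plain Go toolchain:

go test ./internal/features/backups/config/ -run Test_Validate -v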
@@ -26,6 +26,12 @@ func Test_AttachNotifierFromSameWorkspace_SuccessfullyAttached(t *testing.T) {
    database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
    notifier := notifiers.CreateTestNotifier(workspace.ID)

    defer func() {
        databases.RemoveTestDatabase(database)
        notifiers.RemoveTestNotifier(notifier)
        workspaces_testing.RemoveTestWorkspace(workspace, router)
    }()

    database.Notifiers = []notifiers.Notifier{*notifier}

    var response databases.Database
@@ -55,6 +61,13 @@ func Test_AttachNotifierFromDifferentWorkspace_ReturnsForbidden(t *testing.T) {
    workspace2 := workspaces_testing.CreateTestWorkspace("Workspace 2", owner2, router)
    notifier := notifiers.CreateTestNotifier(workspace2.ID)

    defer func() {
        databases.RemoveTestDatabase(database)
        notifiers.RemoveTestNotifier(notifier)
        workspaces_testing.RemoveTestWorkspace(workspace1, router)
        workspaces_testing.RemoveTestWorkspace(workspace2, router)
    }()

    database.Notifiers = []notifiers.Notifier{*notifier}

    testResp := test_utils.MakePostRequest(
@@ -77,6 +90,12 @@ func Test_DeleteNotifierWithAttachedDatabases_CannotDelete(t *testing.T) {
    database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
    notifier := notifiers.CreateTestNotifier(workspace.ID)

    defer func() {
        databases.RemoveTestDatabase(database)
        notifiers.RemoveTestNotifier(notifier)
        workspaces_testing.RemoveTestWorkspace(workspace, router)
    }()

    database.Notifiers = []notifiers.Notifier{*notifier}

    var response databases.Database
@@ -114,6 +133,13 @@ func Test_TransferNotifierWithAttachedDatabase_CannotTransfer(t *testing.T) {
    database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
    notifier := notifiers.CreateTestNotifier(workspace.ID)

    defer func() {
        databases.RemoveTestDatabase(database)
        notifiers.RemoveTestNotifier(notifier)
        workspaces_testing.RemoveTestWorkspace(workspace, router)
        workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
    }()

    database.Notifiers = []notifiers.Notifier{*notifier}

    var response databases.Database

@@ -6,10 +6,10 @@ import (
    "databasus-backend/internal/features/databases"
    "databasus-backend/internal/features/intervals"
    "databasus-backend/internal/features/notifiers"
    plans "databasus-backend/internal/features/plan"
    "databasus-backend/internal/features/storages"
    users_models "databasus-backend/internal/features/users/models"
    workspaces_services "databasus-backend/internal/features/workspaces/services"
    "databasus-backend/internal/util/period"

    "github.com/google/uuid"
)
@@ -20,6 +20,7 @@ type BackupConfigService struct {
    storageService      *storages.StorageService
    notifierService     *notifiers.NotifierService
    workspaceService    *workspaces_services.WorkspaceService
    databasePlanService *plans.DatabasePlanService

    dbStorageChangeListener BackupConfigStorageChangeListener
}
@@ -45,7 +46,12 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
    user *users_models.User,
    backupConfig *BackupConfig,
) (*BackupConfig, error) {
-   if err := backupConfig.Validate(); err != nil {
+   plan, err := s.databasePlanService.GetDatabasePlan(backupConfig.DatabaseID)
+   if err != nil {
+       return nil, err
+   }
+
+   if err := backupConfig.Validate(plan); err != nil {
        return nil, err
    }

@@ -71,7 +77,7 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
    if err != nil {
        return nil, err
    }
-   if storage.WorkspaceID != *database.WorkspaceID {
+   if storage.WorkspaceID != *database.WorkspaceID && !storage.IsSystem {
        return nil, errors.New("storage does not belong to the same workspace as the database")
    }
}
@@ -82,7 +88,12 @@ func (s *BackupConfigService) SaveBackupConfig(
    backupConfig *BackupConfig,
) (*BackupConfig, error) {
-   if err := backupConfig.Validate(); err != nil {
+   plan, err := s.databasePlanService.GetDatabasePlan(backupConfig.DatabaseID)
+   if err != nil {
+       return nil, err
+   }
+
+   if err := backupConfig.Validate(plan); err != nil {
        return nil, err
    }

@@ -120,6 +131,18 @@ func (s *BackupConfigService) GetBackupConfigByDbIdWithAuth(
    return s.GetBackupConfigByDbId(databaseID)
}

func (s *BackupConfigService) GetDatabasePlan(
    user *users_models.User,
    databaseID uuid.UUID,
) (*plans.DatabasePlan, error) {
    _, err := s.databaseService.GetDatabase(user, databaseID)
    if err != nil {
        return nil, err
    }

    return s.databasePlanService.GetDatabasePlan(databaseID)
}

func (s *BackupConfigService) GetBackupConfigByDbId(
    databaseID uuid.UUID,
) (*BackupConfig, error) {
@@ -194,12 +217,20 @@ func (s *BackupConfigService) CreateDisabledBackupConfig(databaseID uuid.UUID) e
func (s *BackupConfigService) initializeDefaultConfig(
    databaseID uuid.UUID,
) error {
    plan, err := s.databasePlanService.GetDatabasePlan(databaseID)
    if err != nil {
        return err
    }

    timeOfDay := "04:00"

-   _, err := s.backupConfigRepository.Save(&BackupConfig{
-       DatabaseID:       databaseID,
-       IsBackupsEnabled: false,
-       StorePeriod:      period.PeriodWeek,
+   _, err = s.backupConfigRepository.Save(&BackupConfig{
+       DatabaseID:            databaseID,
+       IsBackupsEnabled:      false,
+       RetentionPolicyType:   RetentionPolicyTypeTimePeriod,
+       RetentionTimePeriod:   plan.MaxStoragePeriod,
+       MaxBackupSizeMB:       plan.MaxBackupSizeMB,
+       MaxBackupsTotalSizeMB: plan.MaxBackupsTotalSizeMB,
        BackupInterval: &intervals.Interval{
            Interval:  intervals.IntervalDaily,
            TimeOfDay: &timeOfDay,
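
Reviewer note: both save paths now resolve the database's plan before validating, and the ownership check gained a system-storage escape hatch. Condensed, the flow in SaveBackupConfigWithAuth is (error handling trimmed, names exactly as in this diff):

plan, _ := s.databasePlanService.GetDatabasePlan(backupConfig.DatabaseID)
if err := backupConfig.Validate(plan); err != nil {
    return nil, err
}
// System storages are shared across workspaces; only regular storages
// must belong to the database's workspace.
if storage.WorkspaceID != *database.WorkspaceID && !storage.IsSystem {
    return nil, errors.New("storage does not belong to the same workspace as the database")
}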
@@ -27,11 +27,18 @@ func Test_AttachStorageFromSameWorkspace_SuccessfullyAttached(t *testing.T) {
    database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
    storage := createTestStorage(workspace.ID)

    defer func() {
        databases.RemoveTestDatabase(database)
        storages.RemoveTestStorage(storage.ID)
        workspaces_testing.RemoveTestWorkspace(workspace, router)
    }()

    timeOfDay := "04:00"
    request := BackupConfig{
-       DatabaseID:       database.ID,
-       IsBackupsEnabled: true,
-       StorePeriod:      period.PeriodWeek,
+       DatabaseID:          database.ID,
+       IsBackupsEnabled:    true,
+       RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+       RetentionTimePeriod: period.PeriodWeek,
        BackupInterval: &intervals.Interval{
            Interval:  intervals.IntervalDaily,
            TimeOfDay: &timeOfDay,
@@ -72,11 +79,19 @@ func Test_AttachStorageFromDifferentWorkspace_ReturnsForbidden(t *testing.T) {
    workspace2 := workspaces_testing.CreateTestWorkspace("Workspace 2", owner2, router)
    storage := createTestStorage(workspace2.ID)

    defer func() {
        databases.RemoveTestDatabase(database)
        storages.RemoveTestStorage(storage.ID)
        workspaces_testing.RemoveTestWorkspace(workspace1, router)
        workspaces_testing.RemoveTestWorkspace(workspace2, router)
    }()

    timeOfDay := "04:00"
    request := BackupConfig{
-       DatabaseID:       database.ID,
-       IsBackupsEnabled: true,
-       StorePeriod:      period.PeriodWeek,
+       DatabaseID:          database.ID,
+       IsBackupsEnabled:    true,
+       RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+       RetentionTimePeriod: period.PeriodWeek,
        BackupInterval: &intervals.Interval{
            Interval:  intervals.IntervalDaily,
            TimeOfDay: &timeOfDay,
@@ -110,11 +125,18 @@ func Test_DeleteStorageWithAttachedDatabases_CannotDelete(t *testing.T) {
    database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
    storage := createTestStorage(workspace.ID)

    defer func() {
        databases.RemoveTestDatabase(database)
        storages.RemoveTestStorage(storage.ID)
        workspaces_testing.RemoveTestWorkspace(workspace, router)
    }()

    timeOfDay := "04:00"
    request := BackupConfig{
-       DatabaseID:       database.ID,
-       IsBackupsEnabled: true,
-       StorePeriod:      period.PeriodWeek,
+       DatabaseID:          database.ID,
+       IsBackupsEnabled:    true,
+       RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+       RetentionTimePeriod: period.PeriodWeek,
        BackupInterval: &intervals.Interval{
            Interval:  intervals.IntervalDaily,
            TimeOfDay: &timeOfDay,
@@ -163,11 +185,19 @@ func Test_TransferStorageWithAttachedDatabase_CannotTransfer(t *testing.T) {
    database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
    storage := createTestStorage(workspace.ID)

    defer func() {
        databases.RemoveTestDatabase(database)
        storages.RemoveTestStorage(storage.ID)
        workspaces_testing.RemoveTestWorkspace(workspace, router)
        workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
    }()

    timeOfDay := "04:00"
    request := BackupConfig{
-       DatabaseID:       database.ID,
-       IsBackupsEnabled: true,
-       StorePeriod:      period.PeriodWeek,
+       DatabaseID:          database.ID,
+       IsBackupsEnabled:    true,
+       RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+       RetentionTimePeriod: period.PeriodWeek,
        BackupInterval: &intervals.Interval{
            Interval:  intervals.IntervalDaily,
            TimeOfDay: &timeOfDay,

@@ -15,9 +15,10 @@ func EnableBackupsForTestDatabase(
    timeOfDay := "16:00"

    backupConfig := &BackupConfig{
-       DatabaseID:       databaseID,
-       IsBackupsEnabled: true,
-       StorePeriod:      period.PeriodDay,
+       DatabaseID:          databaseID,
+       IsBackupsEnabled:    true,
+       RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+       RetentionTimePeriod: period.PeriodDay,
        BackupInterval: &intervals.Interval{
            Interval:  intervals.IntervalDaily,
            TimeOfDay: &timeOfDay,

@@ -25,80 +25,6 @@ import (
    "databasus-backend/internal/util/tools"
)

func createTestRouter() *gin.Engine {
    router := workspaces_testing.CreateTestRouter(
        workspaces_controllers.GetWorkspaceController(),
        workspaces_controllers.GetMembershipController(),
        GetDatabaseController(),
    )
    return router
}

func getTestPostgresConfig() *postgresql.PostgresqlDatabase {
    env := config.GetEnv()
    port, err := strconv.Atoi(env.TestPostgres16Port)
    if err != nil {
        panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
    }

    testDbName := "testdb"
    return &postgresql.PostgresqlDatabase{
        Version:  tools.PostgresqlVersion16,
        Host:     "localhost",
        Port:     port,
        Username: "testuser",
        Password: "testpassword",
        Database: &testDbName,
        CpuCount: 1,
    }
}

func getTestMariadbConfig() *mariadb.MariadbDatabase {
    env := config.GetEnv()
    portStr := env.TestMariadb1011Port
    if portStr == "" {
        portStr = "33111"
    }
    port, err := strconv.Atoi(portStr)
    if err != nil {
        panic(fmt.Sprintf("Failed to parse TEST_MARIADB_1011_PORT: %v", err))
    }

    testDbName := "testdb"
    return &mariadb.MariadbDatabase{
        Version:  tools.MariadbVersion1011,
        Host:     "localhost",
        Port:     port,
        Username: "testuser",
        Password: "testpassword",
        Database: &testDbName,
    }
}

func getTestMongodbConfig() *mongodb.MongodbDatabase {
    env := config.GetEnv()
    portStr := env.TestMongodb70Port
    if portStr == "" {
        portStr = "27070"
    }
    port, err := strconv.Atoi(portStr)
    if err != nil {
        panic(fmt.Sprintf("Failed to parse TEST_MONGODB_70_PORT: %v", err))
    }

    return &mongodb.MongodbDatabase{
        Version:      tools.MongodbVersion7,
        Host:         "localhost",
        Port:         port,
        Username:     "root",
        Password:     "rootpassword",
        Database:     "testdb",
        AuthDatabase: "admin",
        IsHttps:      false,
        CpuCount:     1,
    }
}

func Test_CreateDatabase_PermissionsEnforced(t *testing.T) {
    tests := []struct {
        name string
@@ -142,6 +68,7 @@ func Test_CreateDatabase_PermissionsEnforced(t *testing.T) {
    router := createTestRouter()
    owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
    workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
+   defer workspaces_testing.RemoveTestWorkspace(workspace, router)

    var testUserToken string
    if tt.isGlobalAdmin {
@@ -180,6 +107,7 @@ func Test_CreateDatabase_PermissionsEnforced(t *testing.T) {
    )

    if tt.expectSuccess {
+       defer RemoveTestDatabase(&response)
        assert.Equal(t, "Test Database", response.Name)
        assert.NotEqual(t, uuid.Nil, response.ID)
    } else {
@@ -193,6 +121,7 @@ func Test_CreateDatabase_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testin
    router := createTestRouter()
    owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
    workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
+   defer workspaces_testing.RemoveTestWorkspace(workspace, router)

    nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)

@@ -258,8 +187,10 @@ func Test_UpdateDatabase_PermissionsEnforced(t *testing.T) {
    router := createTestRouter()
    owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
    workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
+   defer workspaces_testing.RemoveTestWorkspace(workspace, router)

    database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
+   defer RemoveTestDatabase(database)

    var testUserToken string
    if tt.isGlobalAdmin {
@@ -305,8 +236,10 @@ func Test_UpdateDatabase_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testin
    router := createTestRouter()
    owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
    workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
+   defer workspaces_testing.RemoveTestWorkspace(workspace, router)

    database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
+   defer RemoveTestDatabase(database)

    nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
    database.Name = "Hacked Name"
@@ -366,6 +299,7 @@ func Test_DeleteDatabase_PermissionsEnforced(t *testing.T) {
    router := createTestRouter()
    owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
    workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
+   defer workspaces_testing.RemoveTestWorkspace(workspace, router)

    database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)

@@ -396,6 +330,7 @@ func Test_DeleteDatabase_PermissionsEnforced(t *testing.T) {
    )

    if !tt.expectSuccess {
+       defer RemoveTestDatabase(database)
        assert.Contains(t, string(testResp.Body), "insufficient permissions")
    }
    })
@@ -439,8 +374,10 @@ func Test_GetDatabase_PermissionsEnforced(t *testing.T) {
    router := createTestRouter()
    owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
    workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
+   defer workspaces_testing.RemoveTestWorkspace(workspace, router)

    database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
+   defer RemoveTestDatabase(database)

    var testUser string
    if tt.isGlobalAdmin {
@@ -517,9 +454,12 @@ func Test_GetDatabasesByWorkspace_PermissionsEnforced(t *testing.T) {
    router := createTestRouter()
    owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
    workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
+   defer workspaces_testing.RemoveTestWorkspace(workspace, router)

-   createTestDatabaseViaAPI("Database 1", workspace.ID, owner.Token, router)
-   createTestDatabaseViaAPI("Database 2", workspace.ID, owner.Token, router)
+   db1 := createTestDatabaseViaAPI("Database 1", workspace.ID, owner.Token, router)
+   defer RemoveTestDatabase(db1)
+   db2 := createTestDatabaseViaAPI("Database 2", workspace.ID, owner.Token, router)
+   defer RemoveTestDatabase(db2)

    var testUser string
    if tt.isGlobalAdmin {
@@ -561,10 +501,14 @@ func Test_GetDatabasesByWorkspace_WhenMultipleDatabasesExist_ReturnsCorrectCount
    router := createTestRouter()
    owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
    workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
+   defer workspaces_testing.RemoveTestWorkspace(workspace, router)

-   createTestDatabaseViaAPI("Database 1", workspace.ID, owner.Token, router)
-   createTestDatabaseViaAPI("Database 2", workspace.ID, owner.Token, router)
-   createTestDatabaseViaAPI("Database 3", workspace.ID, owner.Token, router)
+   db1 := createTestDatabaseViaAPI("Database 1", workspace.ID, owner.Token, router)
+   defer RemoveTestDatabase(db1)
+   db2 := createTestDatabaseViaAPI("Database 2", workspace.ID, owner.Token, router)
+   defer RemoveTestDatabase(db2)
+   db3 := createTestDatabaseViaAPI("Database 3", workspace.ID, owner.Token, router)
+   defer RemoveTestDatabase(db3)

    var response []Database
    test_utils.MakeGetRequestAndUnmarshal(
@@ -583,14 +527,19 @@ func Test_GetDatabasesByWorkspace_EnsuresCrossWorkspaceIsolation(t *testing.T) {
    router := createTestRouter()
    owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
    workspace1 := workspaces_testing.CreateTestWorkspace("Workspace 1", owner1, router)
+   defer workspaces_testing.RemoveTestWorkspace(workspace1, router)

    owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
    workspace2 := workspaces_testing.CreateTestWorkspace("Workspace 2", owner2, router)
+   defer workspaces_testing.RemoveTestWorkspace(workspace2, router)

-   createTestDatabaseViaAPI("Workspace1 DB1", workspace1.ID, owner1.Token, router)
-   createTestDatabaseViaAPI("Workspace1 DB2", workspace1.ID, owner1.Token, router)
+   workspace1Db1 := createTestDatabaseViaAPI("Workspace1 DB1", workspace1.ID, owner1.Token, router)
+   defer RemoveTestDatabase(workspace1Db1)
+   workspace1Db2 := createTestDatabaseViaAPI("Workspace1 DB2", workspace1.ID, owner1.Token, router)
+   defer RemoveTestDatabase(workspace1Db2)

-   createTestDatabaseViaAPI("Workspace2 DB1", workspace2.ID, owner2.Token, router)
+   workspace2Db1 := createTestDatabaseViaAPI("Workspace2 DB1", workspace2.ID, owner2.Token, router)
+   defer RemoveTestDatabase(workspace2Db1)

    var workspace1Dbs []Database
    test_utils.MakeGetRequestAndUnmarshal(
@@ -667,8 +616,10 @@ func Test_CopyDatabase_PermissionsEnforced(t *testing.T) {
    router := createTestRouter()
    owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
    workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
+   defer workspaces_testing.RemoveTestWorkspace(workspace, router)

    database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
+   defer RemoveTestDatabase(database)

    var testUserToken string
    if tt.isGlobalAdmin {
@@ -700,6 +651,7 @@ func Test_CopyDatabase_PermissionsEnforced(t *testing.T) {
    )

    if tt.expectSuccess {
+       defer RemoveTestDatabase(&response)
        assert.NotEqual(t, database.ID, response.ID)
        assert.Contains(t, response.Name, "(Copy)")
    } else {
@@ -713,8 +665,10 @@ func Test_CopyDatabase_CopyStaysInSameWorkspace(t *testing.T) {
    router := createTestRouter()
    owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
    workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
+   defer workspaces_testing.RemoveTestWorkspace(workspace, router)

    database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
+   defer RemoveTestDatabase(database)

    var response Database
    test_utils.MakePostRequestAndUnmarshal(
@@ -727,139 +681,14 @@ func Test_CopyDatabase_CopyStaysInSameWorkspace(t *testing.T) {
        &response,
    )

+   defer RemoveTestDatabase(&response)

    assert.NotEqual(t, database.ID, response.ID)
    assert.Equal(t, "Test Database (Copy)", response.Name)
    assert.Equal(t, workspace.ID, *response.WorkspaceID)
    assert.Equal(t, database.Type, response.Type)
}

func Test_TestConnection_PermissionsEnforced(t *testing.T) {
    tests := []struct {
        name                    string
        isMember                bool
        isGlobalAdmin           bool
        expectAccessGranted     bool
        expectedStatusCodeOnErr int
    }{
        {
            name:                    "workspace member can test connection",
            isMember:                true,
            isGlobalAdmin:           false,
            expectAccessGranted:     true,
            expectedStatusCodeOnErr: http.StatusBadRequest,
        },
        {
            name:                    "non-member cannot test connection",
            isMember:                false,
            isGlobalAdmin:           false,
            expectAccessGranted:     false,
            expectedStatusCodeOnErr: http.StatusBadRequest,
        },
        {
            name:                    "global admin can test connection",
            isMember:                false,
            isGlobalAdmin:           true,
            expectAccessGranted:     true,
            expectedStatusCodeOnErr: http.StatusBadRequest,
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            router := createTestRouter()
            owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
            workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)

            database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)

            var testUser string
            if tt.isGlobalAdmin {
                admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
                testUser = admin.Token
            } else if tt.isMember {
                testUser = owner.Token
            } else {
                nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
                testUser = nonMember.Token
            }

            w := workspaces_testing.MakeAPIRequest(
                router,
                "POST",
                "/api/v1/databases/"+database.ID.String()+"/test-connection",
                "Bearer "+testUser,
                nil,
            )

            body := w.Body.String()

            if tt.expectAccessGranted {
                assert.True(
                    t,
                    w.Code == http.StatusOK ||
                        (w.Code == http.StatusBadRequest && strings.Contains(body, "connect")),
                    "Expected 200 OK or 400 with connection error, got %d: %s",
                    w.Code,
                    body,
                )
            } else {
                assert.Equal(t, tt.expectedStatusCodeOnErr, w.Code)
                assert.Contains(t, body, "insufficient permissions")
            }
        })
    }
}

func createTestDatabaseViaAPI(
    name string,
    workspaceID uuid.UUID,
    token string,
    router *gin.Engine,
) *Database {
    env := config.GetEnv()
    port, err := strconv.Atoi(env.TestPostgres16Port)
    if err != nil {
        panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
    }

    testDbName := "testdb"
    request := Database{
        Name:        name,
        WorkspaceID: &workspaceID,
        Type:        DatabaseTypePostgres,
        Postgresql: &postgresql.PostgresqlDatabase{
            Version:  tools.PostgresqlVersion16,
            Host:     "localhost",
            Port:     port,
            Username: "testuser",
            Password: "testpassword",
            Database: &testDbName,
            CpuCount: 1,
        },
    }

    w := workspaces_testing.MakeAPIRequest(
        router,
        "POST",
        "/api/v1/databases/create",
        "Bearer "+token,
        request,
    )

    if w.Code != http.StatusCreated {
        panic(
            fmt.Sprintf("Failed to create database. Status: %d, Body: %s", w.Code, w.Body.String()),
        )
    }

    var database Database
    if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
        panic(err)
    }

    return &database
}

func Test_CreateDatabase_PasswordIsEncryptedInDB(t *testing.T) {
    router := createTestRouter()
    owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -1141,3 +970,207 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
    })
}
}

func Test_TestConnection_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
isMember bool
|
||||
isGlobalAdmin bool
|
||||
expectAccessGranted bool
|
||||
expectedStatusCodeOnErr int
|
||||
}{
|
||||
{
|
||||
name: "workspace member can test connection",
|
||||
isMember: true,
|
||||
isGlobalAdmin: false,
|
||||
expectAccessGranted: true,
|
||||
expectedStatusCodeOnErr: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "non-member cannot test connection",
|
||||
isMember: false,
|
||||
isGlobalAdmin: false,
|
||||
expectAccessGranted: false,
|
||||
expectedStatusCodeOnErr: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "global admin can test connection",
|
||||
isMember: false,
|
||||
isGlobalAdmin: true,
|
||||
expectAccessGranted: true,
|
||||
expectedStatusCodeOnErr: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
defer RemoveTestDatabase(database)
|
||||
|
||||
var testUser string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
testUser = admin.Token
|
||||
} else if tt.isMember {
|
||||
testUser = owner.Token
|
||||
} else {
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
testUser = nonMember.Token
|
||||
}
|
||||
|
||||
w := workspaces_testing.MakeAPIRequest(
|
||||
router,
|
||||
"POST",
|
||||
"/api/v1/databases/"+database.ID.String()+"/test-connection",
|
||||
"Bearer "+testUser,
|
||||
nil,
|
||||
)
|
||||
|
||||
body := w.Body.String()
|
||||
|
||||
if tt.expectAccessGranted {
|
||||
assert.True(
|
||||
t,
|
||||
w.Code == http.StatusOK ||
|
||||
(w.Code == http.StatusBadRequest && strings.Contains(body, "connect")),
|
||||
"Expected 200 OK or 400 with connection error, got %d: %s",
|
||||
w.Code,
|
||||
body,
|
||||
)
|
||||
} else {
|
||||
assert.Equal(t, tt.expectedStatusCodeOnErr, w.Code)
|
||||
assert.Contains(t, body, "insufficient permissions")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
func createTestDatabaseViaAPI(
	name string,
	workspaceID uuid.UUID,
	token string,
	router *gin.Engine,
) *Database {
	env := config.GetEnv()
	port, err := strconv.Atoi(env.TestPostgres16Port)
	if err != nil {
		panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
	}

	testDbName := "testdb"
	request := Database{
		Name:        name,
		WorkspaceID: &workspaceID,
		Type:        DatabaseTypePostgres,
		Postgresql: &postgresql.PostgresqlDatabase{
			Version:  tools.PostgresqlVersion16,
			Host:     config.GetEnv().TestLocalhost,
			Port:     port,
			Username: "testuser",
			Password: "testpassword",
			Database: &testDbName,
			CpuCount: 1,
		},
	}

	w := workspaces_testing.MakeAPIRequest(
		router,
		"POST",
		"/api/v1/databases/create",
		"Bearer "+token,
		request,
	)

	if w.Code != http.StatusCreated {
		panic(
			fmt.Sprintf("Failed to create database. Status: %d, Body: %s", w.Code, w.Body.String()),
		)
	}

	var database Database
	if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
		panic(err)
	}

	return &database
}

func createTestRouter() *gin.Engine {
	router := workspaces_testing.CreateTestRouter(
		workspaces_controllers.GetWorkspaceController(),
		workspaces_controllers.GetMembershipController(),
		GetDatabaseController(),
	)
	return router
}

func getTestPostgresConfig() *postgresql.PostgresqlDatabase {
	env := config.GetEnv()
	port, err := strconv.Atoi(env.TestPostgres16Port)
	if err != nil {
		panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
	}

	testDbName := "testdb"
	return &postgresql.PostgresqlDatabase{
		Version:  tools.PostgresqlVersion16,
		Host:     config.GetEnv().TestLocalhost,
		Port:     port,
		Username: "testuser",
		Password: "testpassword",
		Database: &testDbName,
		CpuCount: 1,
	}
}

func getTestMariadbConfig() *mariadb.MariadbDatabase {
	env := config.GetEnv()
	portStr := env.TestMariadb1011Port
	if portStr == "" {
		portStr = "33111"
	}
	port, err := strconv.Atoi(portStr)
	if err != nil {
		panic(fmt.Sprintf("Failed to parse TEST_MARIADB_1011_PORT: %v", err))
	}

	testDbName := "testdb"
	return &mariadb.MariadbDatabase{
		Version:  tools.MariadbVersion1011,
		Host:     config.GetEnv().TestLocalhost,
		Port:     port,
		Username: "testuser",
		Password: "testpassword",
		Database: &testDbName,
	}
}

func getTestMongodbConfig() *mongodb.MongodbDatabase {
	env := config.GetEnv()
	portStr := env.TestMongodb70Port
	if portStr == "" {
		portStr = "27070"
	}
	port, err := strconv.Atoi(portStr)
	if err != nil {
		panic(fmt.Sprintf("Failed to parse TEST_MONGODB_70_PORT: %v", err))
	}

	return &mongodb.MongodbDatabase{
		Version:      tools.MongodbVersion7,
		Host:         config.GetEnv().TestLocalhost,
		Port:         &port,
		Username:     "root",
		Password:     "rootpassword",
		Database:     "testdb",
		AuthDatabase: "admin",
		IsHttps:      false,
		IsSrv:        false,
		CpuCount:     1,
	}
}
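The three config helpers above repeat the same pattern: read a port from the test environment, fall back to a default, and panic on a malformed value. A generic helper like the sketch below could fold that into one place; envPort is hypothetical and not part of the repository (strconv and fmt are already imported in this file).

func envPort(value, fallback, envName string) int {
	// Fall back to the default port when the environment variable is unset.
	if value == "" {
		value = fallback
	}
	port, err := strconv.Atoi(value)
	if err != nil {
		panic(fmt.Sprintf("Failed to parse %s: %v", envName, err))
	}
	return port
}

Usage would then read, for example: port := envPort(env.TestMariadb1011Port, "33111", "TEST_MARIADB_1011_PORT").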
@@ -25,13 +25,14 @@ type MariadbDatabase struct {

	Version tools.MariadbVersion `json:"version" gorm:"type:text;not null"`

	Host       string  `json:"host" gorm:"type:text;not null"`
	Port       int     `json:"port" gorm:"type:int;not null"`
	Username   string  `json:"username" gorm:"type:text;not null"`
	Password   string  `json:"password" gorm:"type:text;not null"`
	Database   *string `json:"database" gorm:"type:text"`
	IsHttps    bool    `json:"isHttps" gorm:"type:boolean;default:false"`
	Privileges string  `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
	Host            string  `json:"host" gorm:"type:text;not null"`
	Port            int     `json:"port" gorm:"type:int;not null"`
	Username        string  `json:"username" gorm:"type:text;not null"`
	Password        string  `json:"password" gorm:"type:text;not null"`
	Database        *string `json:"database" gorm:"type:text"`
	IsHttps         bool    `json:"isHttps" gorm:"type:boolean;default:false"`
	IsExcludeEvents bool    `json:"isExcludeEvents" gorm:"type:boolean;default:false"`
	Privileges      string  `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
}

func (m *MariadbDatabase) TableName() string {
@@ -124,6 +125,7 @@ func (m *MariadbDatabase) Update(incoming *MariadbDatabase) {
	m.Username = incoming.Username
	m.Database = incoming.Database
	m.IsHttps = incoming.IsHttps
	m.IsExcludeEvents = incoming.IsExcludeEvents
	m.Privileges = incoming.Privileges

	if incoming.Password != "" {
@@ -515,9 +517,13 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
	hasProcess := false
	hasAllPrivileges := false

	// Escape underscores to match MariaDB's grant output format
	// MariaDB escapes _ as \_ in SHOW GRANTS output
	// Pattern matches either literal _ or escaped \_
	escapedDbName := strings.ReplaceAll(regexp.QuoteMeta(database), "_", `(_|\\_)`)
	dbPatternStr := fmt.Sprintf(
		`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
		regexp.QuoteMeta(database),
		escapedDbName,
	)
	dbPattern := regexp.MustCompile(dbPatternStr)
	globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)
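The underscore-escaping change above can be checked in isolation: after QuoteMeta, every underscore in the database name is rewritten so the pattern matches both the literal and the backslash-escaped form that SHOW GRANTS may emit. A minimal sketch, with sample grant strings that are illustrative rather than captured from a real server:

package main

import (
	"fmt"
	"regexp"
	"strings"
)

func main() {
	database := "test_all_db"

	// Same construction as detectPrivileges above: QuoteMeta first, then let
	// every underscore match either "_" or the escaped "\_" form.
	escapedDbName := strings.ReplaceAll(regexp.QuoteMeta(database), "_", `(_|\\_)`)
	dbPattern := regexp.MustCompile(fmt.Sprintf(
		`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
		escapedDbName,
	))

	// MariaDB/MySQL may or may not escape underscores in SHOW GRANTS output;
	// the pattern accepts both spellings.
	grants := []string{
		"GRANT ALL PRIVILEGES ON `test\\_all\\_db`.* TO 'allpriv'@'%'",
		"GRANT SELECT, SHOW VIEW ON `test_all_db`.* TO 'reader'@'%'",
	}
	for _, grant := range grants {
		fmt.Println(dbPattern.MatchString(grant)) // true for both
	}
}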
@@ -694,6 +694,115 @@ func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
	assert.NoError(t, err)
}

func Test_TestConnection_DatabaseWithUnderscoresAndAllPrivileges_Success(t *testing.T) {
	env := config.GetEnv()
	cases := []struct {
		name    string
		version tools.MariadbVersion
		port    string
	}{
		{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
		{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
		{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
		{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
		{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
		{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
		{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
		{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
		{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
		{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
		{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			container := connectToMariadbContainer(t, tc.port, tc.version)
			defer container.DB.Close()

			underscoreDbName := "test_all_db"

			_, err := container.DB.Exec(
				fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
			)
			assert.NoError(t, err)

			_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
			assert.NoError(t, err)

			defer func() {
				_, _ = container.DB.Exec(
					fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
				)
			}()

			underscoreDSN := fmt.Sprintf(
				"%s:%s@tcp(%s:%d)/%s?parseTime=true",
				container.Username,
				container.Password,
				container.Host,
				container.Port,
				underscoreDbName,
			)
			underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
			assert.NoError(t, err)
			defer underscoreDB.Close()

			_, err = underscoreDB.Exec(`
				CREATE TABLE all_priv_test (
					id INT AUTO_INCREMENT PRIMARY KEY,
					data VARCHAR(255) NOT NULL
				)
			`)
			assert.NoError(t, err)

			_, err = underscoreDB.Exec(`INSERT INTO all_priv_test (data) VALUES ('test1')`)
			assert.NoError(t, err)

			allPrivUsername := fmt.Sprintf("allpriv%s", uuid.New().String()[:8])
			allPrivPassword := "allprivpass123"

			_, err = underscoreDB.Exec(fmt.Sprintf(
				"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
				allPrivUsername,
				allPrivPassword,
			))
			assert.NoError(t, err)

			_, err = underscoreDB.Exec(fmt.Sprintf(
				"GRANT ALL PRIVILEGES ON `%s`.* TO '%s'@'%%'",
				underscoreDbName,
				allPrivUsername,
			))
			assert.NoError(t, err)

			_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
			assert.NoError(t, err)

			defer dropUserSafe(underscoreDB, allPrivUsername)

			mariadbModel := &MariadbDatabase{
				Version:  tc.version,
				Host:     container.Host,
				Port:     container.Port,
				Username: allPrivUsername,
				Password: allPrivPassword,
				Database: &underscoreDbName,
				IsHttps:  false,
			}

			logger := slog.New(slog.NewTextHandler(os.Stdout, nil))

			err = mariadbModel.TestConnection(logger, nil, uuid.New())
			assert.NoError(t, err)
			assert.NotEmpty(t, mariadbModel.Privileges)
			assert.Contains(t, mariadbModel.Privileges, "SELECT")
			assert.Contains(t, mariadbModel.Privileges, "SHOW VIEW")
		})
	}
}

type MariadbContainer struct {
	Host string
	Port int
@@ -714,7 +823,7 @@ func connectToMariadbContainer(
	}

	dbName := "testdb"
	host := "127.0.0.1"
	host := config.GetEnv().TestLocalhost
	username := "root"
	password := "rootpassword"

@@ -26,12 +26,13 @@ type MongodbDatabase struct {
	Version tools.MongodbVersion `json:"version" gorm:"type:text;not null"`

	Host         string `json:"host" gorm:"type:text;not null"`
	Port         int    `json:"port" gorm:"type:int;not null"`
	Port         *int   `json:"port" gorm:"type:int"`
	Username     string `json:"username" gorm:"type:text;not null"`
	Password     string `json:"password" gorm:"type:text;not null"`
	Database     string `json:"database" gorm:"type:text;not null"`
	AuthDatabase string `json:"authDatabase" gorm:"type:text;not null;default:'admin'"`
	IsHttps      bool   `json:"isHttps" gorm:"type:boolean;default:false"`
	IsSrv        bool   `json:"isSrv" gorm:"column:is_srv;type:boolean;not null;default:false"`
	CpuCount     int    `json:"cpuCount" gorm:"column:cpu_count;type:int;not null;default:1"`
}

@@ -43,9 +44,13 @@ func (m *MongodbDatabase) Validate() error {
	if m.Host == "" {
		return errors.New("host is required")
	}
	if m.Port == 0 {
		return errors.New("port is required")

	if !m.IsSrv {
		if m.Port == nil || *m.Port == 0 {
			return errors.New("port is required for standard connections")
		}
	}

	if m.Username == "" {
		return errors.New("username is required")
	}
@@ -58,6 +63,7 @@ func (m *MongodbDatabase) Validate() error {
	if m.CpuCount <= 0 {
		return errors.New("cpu count must be greater than 0")
	}

	return nil
}

@@ -125,6 +131,7 @@ func (m *MongodbDatabase) Update(incoming *MongodbDatabase) {
	m.Database = incoming.Database
	m.AuthDatabase = incoming.AuthDatabase
	m.IsHttps = incoming.IsHttps
	m.IsSrv = incoming.IsSrv
	m.CpuCount = incoming.CpuCount

	if incoming.Password != "" {
@@ -455,12 +462,29 @@ func (m *MongodbDatabase) buildConnectionURI(password string) string {
		tlsParams = "&tls=true&tlsInsecure=true"
	}

	if m.IsSrv {
		return fmt.Sprintf(
			"mongodb+srv://%s:%s@%s/%s?authSource=%s&connectTimeoutMS=15000%s",
			url.QueryEscape(m.Username),
			url.QueryEscape(password),
			m.Host,
			m.Database,
			authDB,
			tlsParams,
		)
	}

	port := 27017
	if m.Port != nil {
		port = *m.Port
	}

	return fmt.Sprintf(
		"mongodb://%s:%s@%s:%d/%s?authSource=%s&connectTimeoutMS=15000%s",
		url.QueryEscape(m.Username),
		url.QueryEscape(password),
		m.Host,
		m.Port,
		port,
		m.Database,
		authDB,
		tlsParams,
@@ -479,12 +503,28 @@ func (m *MongodbDatabase) BuildMongodumpURI(password string) string {
		tlsParams = "&tls=true&tlsInsecure=true"
	}

	if m.IsSrv {
		return fmt.Sprintf(
			"mongodb+srv://%s:%s@%s/?authSource=%s&connectTimeoutMS=15000%s",
			url.QueryEscape(m.Username),
			url.QueryEscape(password),
			m.Host,
			authDB,
			tlsParams,
		)
	}

	port := 27017
	if m.Port != nil {
		port = *m.Port
	}

	return fmt.Sprintf(
		"mongodb://%s:%s@%s:%d/?authSource=%s&connectTimeoutMS=15000%s",
		url.QueryEscape(m.Username),
		url.QueryEscape(password),
		m.Host,
		m.Port,
		port,
		authDB,
		tlsParams,
	)

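The branching introduced above has two facts worth keeping in mind: mongodb+srv connection strings never carry a port (the DNS SRV record supplies it), and when the nullable Port is unset the standard scheme falls back to MongoDB's default 27017. A minimal standalone restatement, with placeholder credentials and hosts:

package main

import (
	"fmt"
	"net/url"
)

// Minimal re-statement of the branching in buildConnectionURI above.
func buildURI(user, pass, host, db, authDB string, port *int, isSrv bool) string {
	if isSrv {
		// SRV connection strings never include a port.
		return fmt.Sprintf("mongodb+srv://%s:%s@%s/%s?authSource=%s",
			url.QueryEscape(user), url.QueryEscape(pass), host, db, authDB)
	}
	p := 27017 // MongoDB default port
	if port != nil {
		p = *port
	}
	return fmt.Sprintf("mongodb://%s:%s@%s:%d/%s?authSource=%s",
		url.QueryEscape(user), url.QueryEscape(pass), host, p, db, authDB)
}

func main() {
	custom := 28000
	fmt.Println(buildURI("u", "p@ss", "cluster0.example.net", "mydb", "admin", nil, true))
	fmt.Println(buildURI("u", "p@ss", "localhost", "mydb", "admin", &custom, false))
}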
@@ -9,6 +9,7 @@ import (
	"strconv"
	"strings"
	"testing"
	"time"

	"github.com/google/uuid"
	"github.com/stretchr/testify/assert"
@@ -63,15 +64,17 @@ func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {

	defer dropUserSafe(container.Client, limitedUsername, container.AuthDatabase)

	port := container.Port
	mongodbModel := &MongodbDatabase{
		Version:      tc.version,
		Host:         container.Host,
		Port:         container.Port,
		Port:         &port,
		Username:     limitedUsername,
		Password:     limitedPassword,
		Database:     container.Database,
		AuthDatabase: container.AuthDatabase,
		IsHttps:      false,
		IsSrv:        false,
		CpuCount:     1,
	}

@@ -132,15 +135,17 @@ func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {

	defer dropUserSafe(container.Client, backupUsername, container.AuthDatabase)

	port := container.Port
	mongodbModel := &MongodbDatabase{
		Version:      tc.version,
		Host:         container.Host,
		Port:         container.Port,
		Port:         &port,
		Username:     backupUsername,
		Password:     backupPassword,
		Database:     container.Database,
		AuthDatabase: container.AuthDatabase,
		IsHttps:      false,
		IsSrv:        false,
		CpuCount:     1,
	}

@@ -397,7 +402,7 @@ func connectToMongodbContainer(
	}

	dbName := "testdb"
	host := "127.0.0.1"
	host := config.GetEnv().TestLocalhost
	username := "root"
	password := "rootpassword"
	authDatabase := "admin"
@@ -406,11 +411,18 @@ func connectToMongodbContainer(
	assert.NoError(t, err)

	uri := fmt.Sprintf(
		"mongodb://%s:%s@%s:%d/%s?authSource=%s",
		username, password, host, portInt, dbName, authDatabase,
		"mongodb://%s:%s@%s:%d/%s?authSource=%s&serverSelectionTimeoutMS=5000&connectTimeoutMS=5000",
		username,
		password,
		host,
		portInt,
		dbName,
		authDatabase,
	)

	ctx := context.Background()
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	clientOptions := options.Client().ApplyURI(uri)
	client, err := mongo.Connect(ctx, clientOptions)
	if err != nil {
@@ -434,15 +446,17 @@ func connectToMongodbContainer(
	}

func createMongodbModel(container *MongodbContainer) *MongodbDatabase {
	port := container.Port
	return &MongodbDatabase{
		Version:      container.Version,
		Host:         container.Host,
		Port:         container.Port,
		Port:         &port,
		Username:     container.Username,
		Password:     container.Password,
		Database:     container.Database,
		AuthDatabase: container.AuthDatabase,
		IsHttps:      false,
		IsSrv:        false,
		CpuCount:     1,
	}
}
@@ -481,3 +495,157 @@ func assertWriteDenied(t *testing.T, err error) {
		strings.Contains(errStr, "permission denied"),
		"Expected authorization error, got: %v", err)
}

func Test_BuildConnectionURI_WithSrvFormat_ReturnsCorrectUri(t *testing.T) {
	port := 27017
	model := &MongodbDatabase{
		Host:         "cluster0.example.mongodb.net",
		Port:         &port,
		Username:     "testuser",
		Password:     "testpass123",
		Database:     "mydb",
		AuthDatabase: "admin",
		IsHttps:      false,
		IsSrv:        true,
	}

	uri := model.buildConnectionURI("testpass123")

	assert.Contains(t, uri, "mongodb+srv://")
	assert.Contains(t, uri, "testuser")
	assert.Contains(t, uri, "testpass123")
	assert.Contains(t, uri, "cluster0.example.mongodb.net")
	assert.Contains(t, uri, "/mydb")
	assert.Contains(t, uri, "authSource=admin")
	assert.Contains(t, uri, "connectTimeoutMS=15000")
	assert.NotContains(t, uri, ":27017")
}

func Test_BuildConnectionURI_WithStandardFormat_ReturnsCorrectUri(t *testing.T) {
	port := 27017
	model := &MongodbDatabase{
		Host:         "localhost",
		Port:         &port,
		Username:     "testuser",
		Password:     "testpass123",
		Database:     "mydb",
		AuthDatabase: "admin",
		IsHttps:      false,
		IsSrv:        false,
	}

	uri := model.buildConnectionURI("testpass123")

	assert.Contains(t, uri, "mongodb://")
	assert.Contains(t, uri, "testuser")
	assert.Contains(t, uri, "testpass123")
	assert.Contains(t, uri, "localhost:27017")
	assert.Contains(t, uri, "/mydb")
	assert.Contains(t, uri, "authSource=admin")
	assert.Contains(t, uri, "connectTimeoutMS=15000")
	assert.NotContains(t, uri, "mongodb+srv://")
}

func Test_BuildConnectionURI_WithNullPort_UsesDefault(t *testing.T) {
	model := &MongodbDatabase{
		Host:         "localhost",
		Port:         nil,
		Username:     "testuser",
		Password:     "testpass123",
		Database:     "mydb",
		AuthDatabase: "admin",
		IsHttps:      false,
		IsSrv:        false,
	}

	uri := model.buildConnectionURI("testpass123")

	assert.Contains(t, uri, "localhost:27017")
}

func Test_BuildMongodumpURI_WithSrvFormat_ReturnsCorrectUri(t *testing.T) {
	port := 27017
	model := &MongodbDatabase{
		Host:         "cluster0.example.mongodb.net",
		Port:         &port,
		Username:     "testuser",
		Password:     "testpass123",
		Database:     "mydb",
		AuthDatabase: "admin",
		IsHttps:      false,
		IsSrv:        true,
	}

	uri := model.BuildMongodumpURI("testpass123")

	assert.Contains(t, uri, "mongodb+srv://")
	assert.Contains(t, uri, "testuser")
	assert.Contains(t, uri, "testpass123")
	assert.Contains(t, uri, "cluster0.example.mongodb.net")
	assert.Contains(t, uri, "/?authSource=admin")
	assert.Contains(t, uri, "connectTimeoutMS=15000")
	assert.NotContains(t, uri, ":27017")
	assert.NotContains(t, uri, "/mydb")
}

func Test_BuildMongodumpURI_WithStandardFormat_ReturnsCorrectUri(t *testing.T) {
	port := 27017
	model := &MongodbDatabase{
		Host:         "localhost",
		Port:         &port,
		Username:     "testuser",
		Password:     "testpass123",
		Database:     "mydb",
		AuthDatabase: "admin",
		IsHttps:      false,
		IsSrv:        false,
	}

	uri := model.BuildMongodumpURI("testpass123")

	assert.Contains(t, uri, "mongodb://")
	assert.Contains(t, uri, "testuser")
	assert.Contains(t, uri, "testpass123")
	assert.Contains(t, uri, "localhost:27017")
	assert.Contains(t, uri, "/?authSource=admin")
	assert.Contains(t, uri, "connectTimeoutMS=15000")
	assert.NotContains(t, uri, "mongodb+srv://")
	assert.NotContains(t, uri, "/mydb")
}

func Test_Validate_SrvConnection_AllowsNullPort(t *testing.T) {
	model := &MongodbDatabase{
		Host:         "cluster0.example.mongodb.net",
		Port:         nil,
		Username:     "testuser",
		Password:     "testpass123",
		Database:     "mydb",
		AuthDatabase: "admin",
		IsHttps:      false,
		IsSrv:        true,
		CpuCount:     1,
	}

	err := model.Validate()

	assert.NoError(t, err)
}

func Test_Validate_StandardConnection_RequiresPort(t *testing.T) {
	model := &MongodbDatabase{
		Host:         "localhost",
		Port:         nil,
		Username:     "testuser",
		Password:     "testpass123",
		Database:     "mydb",
		AuthDatabase: "admin",
		IsHttps:      false,
		IsSrv:        false,
		CpuCount:     1,
	}

	err := model.Validate()

	assert.Error(t, err)
	assert.Contains(t, err.Error(), "port is required for standard connections")
}

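The BuildMongodumpURI tests above confirm that the dump URI deliberately omits the database path segment. A hedged sketch of handing such a URI to the mongodump CLI; the URI value and archive path are placeholders, and this is not the repository's actual backup invocation:

package main

import "os/exec"

func main() {
	// Shape produced by BuildMongodumpURI above: no database path segment.
	uri := "mongodb://user:pass@localhost:27017/?authSource=admin&connectTimeoutMS=15000"

	// --uri and --archive are standard mongodump flags; without a database
	// filter this dumps every database the user is authorized to read.
	cmd := exec.Command("mongodump", "--uri", uri, "--archive=/tmp/backup.archive")
	if err := cmd.Run(); err != nil {
		panic(err)
	}
}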
@@ -489,9 +489,13 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
	hasProcess := false
	hasAllPrivileges := false

	// Escape underscores to match MySQL's grant output format
	// MySQL escapes _ as \_ in SHOW GRANTS output
	// Pattern matches either literal _ or escaped \_
	escapedDbName := strings.ReplaceAll(regexp.QuoteMeta(database), "_", `(_|\\_)`)
	dbPatternStr := fmt.Sprintf(
		`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
		regexp.QuoteMeta(database),
		escapedDbName,
	)
	dbPattern := regexp.MustCompile(dbPatternStr)
	globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)

@@ -674,6 +674,112 @@ func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
	assert.NoError(t, err)
}

func Test_TestConnection_DatabaseWithUnderscoresAndAllPrivileges_Success(t *testing.T) {
	env := config.GetEnv()
	cases := []struct {
		name    string
		version tools.MysqlVersion
		port    string
	}{
		{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
		{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
		{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
		{"MySQL 9", tools.MysqlVersion9, env.TestMysql90Port},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			container := connectToMysqlContainer(t, tc.port, tc.version)
			defer container.DB.Close()

			underscoreDbName := "test_all_db"

			_, err := container.DB.Exec(
				fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
			)
			assert.NoError(t, err)

			_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
			assert.NoError(t, err)

			defer func() {
				_, _ = container.DB.Exec(
					fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
				)
			}()

			underscoreDSN := fmt.Sprintf(
				"%s:%s@tcp(%s:%d)/%s?parseTime=true",
				container.Username,
				container.Password,
				container.Host,
				container.Port,
				underscoreDbName,
			)
			underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
			assert.NoError(t, err)
			defer underscoreDB.Close()

			_, err = underscoreDB.Exec(`
				CREATE TABLE all_priv_test (
					id INT AUTO_INCREMENT PRIMARY KEY,
					data VARCHAR(255) NOT NULL
				)
			`)
			assert.NoError(t, err)

			_, err = underscoreDB.Exec(`INSERT INTO all_priv_test (data) VALUES ('test1')`)
			assert.NoError(t, err)

			allPrivUsername := fmt.Sprintf("allpriv_%s", uuid.New().String()[:8])
			allPrivPassword := "allprivpass123"

			_, err = underscoreDB.Exec(fmt.Sprintf(
				"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
				allPrivUsername,
				allPrivPassword,
			))
			assert.NoError(t, err)

			_, err = underscoreDB.Exec(fmt.Sprintf(
				"GRANT ALL PRIVILEGES ON `%s`.* TO '%s'@'%%'",
				underscoreDbName,
				allPrivUsername,
			))
			assert.NoError(t, err)

			_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
			assert.NoError(t, err)

			defer func() {
				_, _ = container.DB.Exec(
					fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", allPrivUsername),
				)
			}()

			mysqlModel := &MysqlDatabase{
				Version:  tc.version,
				Host:     container.Host,
				Port:     container.Port,
				Username: allPrivUsername,
				Password: allPrivPassword,
				Database: &underscoreDbName,
				IsHttps:  false,
			}

			logger := slog.New(slog.NewTextHandler(os.Stdout, nil))

			err = mysqlModel.TestConnection(logger, nil, uuid.New())
			assert.NoError(t, err)
			assert.NotEmpty(t, mysqlModel.Privileges)
			assert.Contains(t, mysqlModel.Privileges, "SELECT")
			assert.Contains(t, mysqlModel.Privileges, "SHOW VIEW")
		})
	}
}

type MysqlContainer struct {
	Host string
	Port int
@@ -694,7 +800,7 @@ func connectToMysqlContainer(
	}

	dbName := "testdb"
	host := "127.0.0.1"
	host := config.GetEnv().TestLocalhost
	username := "root"
	password := "rootpassword"

@@ -13,6 +13,7 @@ import (

	"github.com/google/uuid"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
	"gorm.io/gorm"
)

@@ -90,8 +91,18 @@ func (p *PostgresqlDatabase) Validate() error {
	// because it would expose internal metadata to non-system administrators.
	// To properly backup Databasus, see: https://databasus.com/faq#backup-databasus
	if p.Database != nil && *p.Database != "" {
		localhostHosts := []string{"localhost", "127.0.0.1", "172.17.0.1", "host.docker.internal"}
		localhostHosts := []string{
			"localhost",
			"127.0.0.1",
			"172.17.0.1",
			"host.docker.internal",
			"::1",     // IPv6 loopback (equivalent to 127.0.0.1)
			"::",      // IPv6 all interfaces (equivalent to 0.0.0.0)
			"0.0.0.0", // IPv4 all interfaces
		}

		isLocalhost := false

		for _, host := range localhostHosts {
			if strings.EqualFold(p.Host, host) {
				isLocalhost = true
@@ -99,6 +110,11 @@ func (p *PostgresqlDatabase) Validate() error {
			}
		}

		// Also check if the host is in the entire 127.0.0.0/8 loopback range
		if strings.HasPrefix(p.Host, "127.") {
			isLocalhost = true
		}

		if isLocalhost && strings.EqualFold(*p.Database, "databasus") {
			return errors.New(
				"backing up Databasus internal database is not allowed. To backup Databasus itself, see https://databasus.com/faq#backup-databasus",
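The expanded host list plus the "127." prefix check covers named loopbacks, IPv6 loopback and wildcard binds, and the entire 127.0.0.0/8 range. A standalone restatement of the predicate, runnable on its own:

package main

import (
	"fmt"
	"strings"
)

// Standalone restatement of the host check in Validate above.
func isLocalHost(host string) bool {
	local := []string{
		"localhost", "127.0.0.1", "172.17.0.1", "host.docker.internal",
		"::1", "::", "0.0.0.0",
	}
	for _, h := range local {
		if strings.EqualFold(host, h) {
			return true
		}
	}
	// Any 127.x.x.x address is loopback, not just 127.0.0.1.
	return strings.HasPrefix(host, "127.")
}

func main() {
	for _, h := range []string{"::1", "127.0.0.2", "db.example.com"} {
		fmt.Println(h, isLocalHost(h)) // true, true, false
	}
}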
@@ -379,10 +395,13 @@ func (p *PostgresqlDatabase) IsUserReadOnly(
//
// This method performs the following operations atomically in a single transaction:
// 1. Creates a PostgreSQL user with a UUID-based password
// 2. Grants CONNECT privilege on the database
// 3. Grants USAGE on all non-system schemas
// 4. Grants SELECT on all existing tables and sequences
// 5. Sets default privileges for future tables and sequences
// 2. Revokes CREATE privilege on public schema from PUBLIC role
// 3. Grants CONNECT privilege on the database
// 4. Discovers all user-created schemas
// 5. Grants USAGE on all non-system schemas
// 6. Grants SELECT on all existing tables and sequences
// 7. Sets default privileges for future tables and sequences
// 8. Verifies user creation before committing
//
// Security features:
// - Username format: "databasus-{8-char-uuid}" for uniqueness
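For reference, the call shape is the one exercised by the tests later in this diff; the nil argument is passed exactly as those tests pass it (its role is not shown here), and the ctx/logger values are illustrative.

// Usage sketch within this package, mirroring the tests in this change.
func exampleCreateReadOnlyUser(p *PostgresqlDatabase) {
	ctx := context.Background()
	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))

	username, password, err := p.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
	if err != nil {
		// The whole transaction rolls back; no partial grants are left behind.
		panic(err)
	}
	// username has the form "databasus-<8-char-uuid>"; password is UUID-based.
	_ = username
	_ = password
}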
@@ -472,33 +491,56 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
		return "", "", fmt.Errorf("failed to create user: %w", err)
	}

	// Step 1.5: Revoke CREATE privilege from PUBLIC role on public schema
	// Step 2: Check if public schema exists and revoke CREATE privilege if it does
	// This is necessary because all PostgreSQL users inherit CREATE privilege on the
	// public schema through the PUBLIC role. This is a one-time operation that affects
	// the entire database, making it more secure by default.
	// Note: This only affects the public schema; other schemas are unaffected.
	_, err = tx.Exec(ctx, `REVOKE CREATE ON SCHEMA public FROM PUBLIC`)
	if err != nil {
		logger.Error("Failed to revoke CREATE on public from PUBLIC", "error", err)
		if !strings.Contains(err.Error(), "schema \"public\" does not exist") &&
			!strings.Contains(err.Error(), "permission denied") {
			return "", "", fmt.Errorf("failed to revoke CREATE from PUBLIC: %w", err)
		}
	}

	// Now revoke from the specific user as well (belt and suspenders)
	_, err = tx.Exec(ctx, fmt.Sprintf(`REVOKE CREATE ON SCHEMA public FROM "%s"`, baseUsername))
	if err != nil {
		logger.Error(
			"Failed to revoke CREATE on public schema from user",
			"error",
			err,
			"username",
			baseUsername,
	var publicSchemaExists bool
	err = tx.QueryRow(ctx, `
		SELECT EXISTS(
			SELECT 1 FROM information_schema.schemata
			WHERE schema_name = 'public'
		)
	`).Scan(&publicSchemaExists)
	if err != nil {
		return "", "", fmt.Errorf("failed to check if public schema exists: %w", err)
	}

	// Step 2: Grant database connection privilege and revoke TEMP
	if publicSchemaExists {
		// Revoke CREATE from PUBLIC role (affects all users)
		_, err = tx.Exec(ctx, `REVOKE CREATE ON SCHEMA public FROM PUBLIC`)
		if err != nil {
			if strings.Contains(err.Error(), "permission denied") {
				logger.Warn(
					"Failed to revoke CREATE on public from PUBLIC (permission denied)",
					"error",
					err,
				)
			} else {
				return "", "", fmt.Errorf("failed to revoke CREATE from PUBLIC on existing public schema: %w", err)
			}
		}

		// Now revoke from the specific user as well (belt and suspenders)
		_, err = tx.Exec(
			ctx,
			fmt.Sprintf(`REVOKE CREATE ON SCHEMA public FROM "%s"`, baseUsername),
		)
		if err != nil {
			logger.Warn(
				"Failed to revoke CREATE on public schema from user",
				"error",
				err,
				"username",
				baseUsername,
			)
		}
	} else {
		logger.Info("Public schema does not exist, skipping CREATE privilege revocation")
	}

	// Step 3: Grant database connection privilege and revoke TEMP
	_, err = tx.Exec(
		ctx,
		fmt.Sprintf(`GRANT CONNECT ON DATABASE "%s" TO "%s"`, *p.Database, baseUsername),
@@ -522,12 +564,23 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
		logger.Warn("Failed to revoke TEMP privilege", "error", err, "username", baseUsername)
	}

	// Step 3: Discover all user-created schemas
	rows, err := tx.Query(ctx, `
		SELECT schema_name
		FROM information_schema.schemata
		WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
	`)
	// Step 4: Discover schemas to grant privileges on
	// If IncludeSchemas is specified, only use those schemas; otherwise use all non-system schemas
	var rows pgx.Rows
	if len(p.IncludeSchemas) > 0 {
		rows, err = tx.Query(ctx, `
			SELECT schema_name
			FROM information_schema.schemata
			WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
			AND schema_name = ANY($1::text[])
		`, p.IncludeSchemas)
	} else {
		rows, err = tx.Query(ctx, `
			SELECT schema_name
			FROM information_schema.schemata
			WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
		`)
	}
	if err != nil {
		return "", "", fmt.Errorf("failed to get schemas: %w", err)
	}
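The IncludeSchemas branch above relies on pgx mapping a Go string slice directly to a PostgreSQL text[] parameter, which ANY($1::text[]) then filters against. A self-contained sketch of the same query outside a transaction; the DSN is a placeholder:

package main

import (
	"context"
	"fmt"

	"github.com/jackc/pgx/v5"
)

func main() {
	ctx := context.Background()
	conn, err := pgx.Connect(ctx, "postgres://user:pass@localhost:5432/testdb") // placeholder DSN
	if err != nil {
		panic(err)
	}
	defer conn.Close(ctx)

	// A []string argument is sent as text[], so ANY($1::text[]) filters on it.
	include := []string{"app_schema", "data_schema"}
	rows, err := conn.Query(ctx, `
		SELECT schema_name
		FROM information_schema.schemata
		WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
		AND schema_name = ANY($1::text[])
	`, include)
	if err != nil {
		panic(err)
	}
	defer rows.Close()

	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			panic(err)
		}
		fmt.Println(name)
	}
}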
@@ -547,7 +600,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
		return "", "", fmt.Errorf("error iterating schemas: %w", err)
	}

	// Step 4: Grant USAGE on each schema and explicitly prevent CREATE
	// Step 5: Grant USAGE on each schema and explicitly prevent CREATE
	for _, schema := range schemas {
		// Revoke CREATE specifically (handles inheritance from PUBLIC role)
		_, err = tx.Exec(
@@ -576,51 +629,198 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
		}
	}

	// Step 5: Grant SELECT on ALL existing tables and sequences
	grantSelectSQL := fmt.Sprintf(`
		DO $$
		DECLARE
			schema_rec RECORD;
		BEGIN
			FOR schema_rec IN
				SELECT schema_name
				FROM information_schema.schemata
				WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
			LOOP
				EXECUTE format('GRANT SELECT ON ALL TABLES IN SCHEMA %%I TO "%s"', schema_rec.schema_name);
				EXECUTE format('GRANT SELECT ON ALL SEQUENCES IN SCHEMA %%I TO "%s"', schema_rec.schema_name);
			END LOOP;
		END $$;
	`, baseUsername, baseUsername)
	// Step 6: Grant SELECT on ALL existing tables and sequences
	// Use the already-filtered schemas list from Step 4
	for _, schema := range schemas {
		_, err = tx.Exec(
			ctx,
			fmt.Sprintf(
				`GRANT SELECT ON ALL TABLES IN SCHEMA "%s" TO "%s"`,
				schema,
				baseUsername,
			),
		)
		if err != nil {
			return "", "", fmt.Errorf(
				"failed to grant select on tables in schema %s: %w",
				schema,
				err,
			)
		}

	_, err = tx.Exec(ctx, grantSelectSQL)
	if err != nil {
		return "", "", fmt.Errorf("failed to grant select on tables: %w", err)
		_, err = tx.Exec(
			ctx,
			fmt.Sprintf(
				`GRANT SELECT ON ALL SEQUENCES IN SCHEMA "%s" TO "%s"`,
				schema,
				baseUsername,
			),
		)
		if err != nil {
			return "", "", fmt.Errorf(
				"failed to grant select on sequences in schema %s: %w",
				schema,
				err,
			)
		}
	}

	// Step 6: Set default privileges for FUTURE tables and sequences
	defaultPrivilegesSQL := fmt.Sprintf(`
		DO $$
		DECLARE
			schema_rec RECORD;
		BEGIN
			FOR schema_rec IN
				SELECT schema_name
				FROM information_schema.schemata
				WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
			LOOP
				EXECUTE format('ALTER DEFAULT PRIVILEGES IN SCHEMA %%I GRANT SELECT ON TABLES TO "%s"', schema_rec.schema_name);
				EXECUTE format('ALTER DEFAULT PRIVILEGES IN SCHEMA %%I GRANT SELECT ON SEQUENCES TO "%s"', schema_rec.schema_name);
			END LOOP;
		END $$;
	`, baseUsername, baseUsername)
	// Step 7: Set default privileges for FUTURE tables and sequences
	// First, set default privileges for objects created by the current user
	// Use the already-filtered schemas list from Step 4
	for _, schema := range schemas {
		_, err = tx.Exec(
			ctx,
			fmt.Sprintf(
				`ALTER DEFAULT PRIVILEGES IN SCHEMA "%s" GRANT SELECT ON TABLES TO "%s"`,
				schema,
				baseUsername,
			),
		)
		if err != nil {
			return "", "", fmt.Errorf(
				"failed to set default privileges for tables in schema %s: %w",
				schema,
				err,
			)
		}

	_, err = tx.Exec(ctx, defaultPrivilegesSQL)
	if err != nil {
		return "", "", fmt.Errorf("failed to set default privileges: %w", err)
		_, err = tx.Exec(
			ctx,
			fmt.Sprintf(
				`ALTER DEFAULT PRIVILEGES IN SCHEMA "%s" GRANT SELECT ON SEQUENCES TO "%s"`,
				schema,
				baseUsername,
			),
		)
		if err != nil {
			return "", "", fmt.Errorf(
				"failed to set default privileges for sequences in schema %s: %w",
				schema,
				err,
			)
		}
	}

	// Step 7: Verify user creation before committing
	// Step 8: Discover all roles that own objects in each schema
	// This is needed because ALTER DEFAULT PRIVILEGES only applies to objects created by the current role.
	// To handle tables created by OTHER users (like the GitHub issue with partitioned tables),
	// we need to set "ALTER DEFAULT PRIVILEGES FOR ROLE <owner>" for each object owner.
	// Filter by IncludeSchemas if specified.
	type SchemaOwner struct {
		SchemaName string
		RoleName   string
	}

	var ownerRows pgx.Rows
	if len(p.IncludeSchemas) > 0 {
		ownerRows, err = tx.Query(ctx, `
			SELECT DISTINCT n.nspname as schema_name, pg_get_userbyid(c.relowner) as role_name
			FROM pg_class c
			JOIN pg_namespace n ON c.relnamespace = n.oid
			WHERE n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
			AND n.nspname = ANY($1::text[])
			AND c.relkind IN ('r', 'p', 'v', 'm', 'f')
			AND pg_get_userbyid(c.relowner) != current_user
			ORDER BY n.nspname, role_name
		`, p.IncludeSchemas)
	} else {
		ownerRows, err = tx.Query(ctx, `
			SELECT DISTINCT n.nspname as schema_name, pg_get_userbyid(c.relowner) as role_name
			FROM pg_class c
			JOIN pg_namespace n ON c.relnamespace = n.oid
			WHERE n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
			AND c.relkind IN ('r', 'p', 'v', 'm', 'f')
			AND pg_get_userbyid(c.relowner) != current_user
			ORDER BY n.nspname, role_name
		`)
	}

	if err != nil {
		// Log warning but continue - this is a best-effort enhancement
		logger.Warn("Failed to query object owners for default privileges", "error", err)
	} else {
		var schemaOwners []SchemaOwner
		for ownerRows.Next() {
			var so SchemaOwner
			if err := ownerRows.Scan(&so.SchemaName, &so.RoleName); err != nil {
				ownerRows.Close()
				logger.Warn("Failed to scan schema owner", "error", err)
				break
			}
			schemaOwners = append(schemaOwners, so)
		}
		ownerRows.Close()

		if err := ownerRows.Err(); err != nil {
			logger.Warn("Error iterating schema owners", "error", err)
		}

		// Step 9: Set default privileges FOR ROLE for each object owner
		// Note: This may fail for some roles due to permission issues (e.g., roles owned by other superusers)
		// We log warnings but continue - user creation should succeed even if some roles can't be configured
		for _, so := range schemaOwners {
			// Try to set default privileges for tables
			_, err = tx.Exec(
				ctx,
				fmt.Sprintf(
					`ALTER DEFAULT PRIVILEGES FOR ROLE "%s" IN SCHEMA "%s" GRANT SELECT ON TABLES TO "%s"`,
					so.RoleName,
					so.SchemaName,
					baseUsername,
				),
			)
			if err != nil {
				logger.Warn(
					"Failed to set default privileges for role (tables)",
					"error",
					err,
					"role",
					so.RoleName,
					"schema",
					so.SchemaName,
					"readonly_user",
					baseUsername,
				)
			}

			// Try to set default privileges for sequences
			_, err = tx.Exec(
				ctx,
				fmt.Sprintf(
					`ALTER DEFAULT PRIVILEGES FOR ROLE "%s" IN SCHEMA "%s" GRANT SELECT ON SEQUENCES TO "%s"`,
					so.RoleName,
					so.SchemaName,
					baseUsername,
				),
			)
			if err != nil {
				logger.Warn(
					"Failed to set default privileges for role (sequences)",
					"error",
					err,
					"role",
					so.RoleName,
					"schema",
					so.SchemaName,
					"readonly_user",
					baseUsername,
				)
			}
		}

		if len(schemaOwners) > 0 {
			logger.Info(
				"Set default privileges for existing object owners",
				"readonly_user",
				baseUsername,
				"owner_count",
				len(schemaOwners),
			)
		}
	}

	// Step 10: Verify user creation before committing
	var verifyUsername string
	err = tx.QueryRow(ctx, fmt.Sprintf(`SELECT rolname FROM pg_roles WHERE rolname = '%s'`, baseUsername)).
		Scan(&verifyUsername)
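The distinction that makes Step 9 necessary: ALTER DEFAULT PRIVILEGES without a FOR ROLE clause only covers objects later created by the role that ran it. The two statements below differ exactly in who the future-object creator is; the schema and role names are illustrative, not taken from the repository.

const (
	// Applies only to tables that the *current* role creates later.
	defaultsForSelf = `ALTER DEFAULT PRIVILEGES IN SCHEMA "app_schema"
		GRANT SELECT ON TABLES TO "databasus-reader"`

	// Applies to tables that "other_owner" creates later, which is what
	// keeps tables created by other users readable by the read-only user.
	defaultsForOwner = `ALTER DEFAULT PRIVILEGES FOR ROLE "other_owner" IN SCHEMA "app_schema"
		GRANT SELECT ON TABLES TO "databasus-reader"`
)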
@@ -836,7 +1036,15 @@ func checkBackupPermissions(
	}

	if err != nil {
		return fmt.Errorf("cannot check SELECT privileges: %w", err)
		// If the user doesn't have USAGE on the schema, has_table_privilege will fail
		// with "permission denied for schema". This means they definitely don't have
		// SELECT privileges, so treat this as missing permissions rather than an error.
		var pgErr *pgconn.PgError
		if errors.As(err, &pgErr) && pgErr.Code == "42501" { // insufficient_privilege
			selectableTableCount = 0
		} else {
			return fmt.Errorf("cannot check SELECT privileges: %w", err)
		}
	}
	if selectableTableCount == 0 {
		missingPrivileges = append(missingPrivileges, "SELECT on tables")

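The SQLSTATE check above is the standard pgx pattern: server errors are wrapped in *pgconn.PgError, whose Code field carries the five-character SQLSTATE ("42501" is insufficient_privilege). A standalone illustration with a constructed error value:

package main

import (
	"errors"
	"fmt"

	"github.com/jackc/pgx/v5/pgconn"
)

// Same errors.As pattern as checkBackupPermissions above.
func isInsufficientPrivilege(err error) bool {
	var pgErr *pgconn.PgError
	return errors.As(err, &pgErr) && pgErr.Code == "42501"
}

func main() {
	err := &pgconn.PgError{Code: "42501", Message: "permission denied for schema app_schema"}
	fmt.Println(isInsufficientPrivilege(err)) // true
}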
@@ -599,6 +599,10 @@ func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) {
}

func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
	if config.GetEnv().IsSkipExternalResourcesTests {
		t.Skip("Skipping Supabase test: IS_SKIP_EXTERNAL_RESOURCES_TESTS is true")
	}

	env := config.GetEnv()

	if env.TestSupabaseHost == "" {
@@ -705,6 +709,344 @@ func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
	assert.Contains(t, err.Error(), "permission denied")
}

func Test_CreateReadOnlyUser_WithPublicSchema_Success(t *testing.T) {
	env := config.GetEnv()
	cases := []struct {
		name    string
		version string
		port    string
	}{
		{"PostgreSQL 12", "12", env.TestPostgres12Port},
		{"PostgreSQL 13", "13", env.TestPostgres13Port},
		{"PostgreSQL 14", "14", env.TestPostgres14Port},
		{"PostgreSQL 15", "15", env.TestPostgres15Port},
		{"PostgreSQL 16", "16", env.TestPostgres16Port},
		{"PostgreSQL 17", "17", env.TestPostgres17Port},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			container := connectToPostgresContainer(t, tc.port)
			defer container.DB.Close()

			_, err := container.DB.Exec(`
				DROP TABLE IF EXISTS public_schema_test CASCADE;
				CREATE TABLE public_schema_test (
					id SERIAL PRIMARY KEY,
					data TEXT NOT NULL
				);
				INSERT INTO public_schema_test (data) VALUES ('test1'), ('test2');
			`)
			assert.NoError(t, err)

			pgModel := createPostgresModel(container)
			logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
			ctx := context.Background()

			username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
			assert.NoError(t, err)
			assert.NotEmpty(t, username)
			assert.NotEmpty(t, password)
			assert.True(t, strings.HasPrefix(username, "databasus-"))

			readOnlyModel := &PostgresqlDatabase{
				Version:  pgModel.Version,
				Host:     pgModel.Host,
				Port:     pgModel.Port,
				Username: username,
				Password: password,
				Database: pgModel.Database,
				IsHttps:  false,
			}

			isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
				ctx,
				logger,
				nil,
				uuid.New(),
			)
			assert.NoError(t, err)
			assert.True(t, isReadOnly, "User should be read-only")
			assert.Empty(t, privileges, "Read-only user should have no write privileges")

			readOnlyDSN := fmt.Sprintf(
				"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
				container.Host,
				container.Port,
				username,
				password,
				container.Database,
			)
			readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
			assert.NoError(t, err)
			defer readOnlyConn.Close()

			var count int
			err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM public_schema_test")
			assert.NoError(t, err)
			assert.Equal(t, 2, count)

			_, err = readOnlyConn.Exec(
				"INSERT INTO public_schema_test (data) VALUES ('should-fail')",
			)
			assert.Error(t, err)
			assert.Contains(t, err.Error(), "permission denied")

			_, err = readOnlyConn.Exec("CREATE TABLE public.hack_table (id INT)")
			assert.Error(t, err)
			assert.Contains(t, err.Error(), "permission denied")

			_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
			if err != nil {
				t.Logf("Warning: Failed to drop owned objects: %v", err)
			}
			_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
			assert.NoError(t, err)
		})
	}
}

func Test_CreateReadOnlyUser_WithoutPublicSchema_Success(t *testing.T) {
	env := config.GetEnv()
	cases := []struct {
		name    string
		version string
		port    string
	}{
		{"PostgreSQL 12", "12", env.TestPostgres12Port},
		{"PostgreSQL 13", "13", env.TestPostgres13Port},
		{"PostgreSQL 14", "14", env.TestPostgres14Port},
		{"PostgreSQL 15", "15", env.TestPostgres15Port},
		{"PostgreSQL 16", "16", env.TestPostgres16Port},
		{"PostgreSQL 17", "17", env.TestPostgres17Port},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			container := connectToPostgresContainer(t, tc.port)
			defer container.DB.Close()

			_, err := container.DB.Exec(`
				DROP SCHEMA IF EXISTS public CASCADE;
				DROP SCHEMA IF EXISTS app_schema CASCADE;
				DROP SCHEMA IF EXISTS data_schema CASCADE;
				CREATE SCHEMA app_schema;
				CREATE SCHEMA data_schema;
				CREATE TABLE app_schema.users (
					id SERIAL PRIMARY KEY,
					username TEXT NOT NULL
				);
				CREATE TABLE data_schema.records (
					id SERIAL PRIMARY KEY,
					info TEXT NOT NULL
				);
				INSERT INTO app_schema.users (username) VALUES ('user1'), ('user2');
				INSERT INTO data_schema.records (info) VALUES ('record1'), ('record2');
			`)
			assert.NoError(t, err)

			pgModel := createPostgresModel(container)
			logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
			ctx := context.Background()

			username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
			assert.NoError(t, err, "CreateReadOnlyUser should succeed without public schema")
			assert.NotEmpty(t, username)
			assert.NotEmpty(t, password)
			assert.True(t, strings.HasPrefix(username, "databasus-"))

			readOnlyModel := &PostgresqlDatabase{
				Version:  pgModel.Version,
				Host:     pgModel.Host,
				Port:     pgModel.Port,
				Username: username,
				Password: password,
				Database: pgModel.Database,
				IsHttps:  false,
			}

			isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
				ctx,
				logger,
				nil,
				uuid.New(),
			)
			assert.NoError(t, err)
			assert.True(t, isReadOnly, "User should be read-only")
			assert.Empty(t, privileges, "Read-only user should have no write privileges")

			readOnlyDSN := fmt.Sprintf(
				"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
				container.Host,
				container.Port,
				username,
				password,
				container.Database,
			)
			readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
			assert.NoError(t, err)
			defer readOnlyConn.Close()

			var userCount int
			err = readOnlyConn.Get(&userCount, "SELECT COUNT(*) FROM app_schema.users")
			assert.NoError(t, err)
			assert.Equal(t, 2, userCount)

			var recordCount int
			err = readOnlyConn.Get(&recordCount, "SELECT COUNT(*) FROM data_schema.records")
			assert.NoError(t, err)
			assert.Equal(t, 2, recordCount)

			_, err = readOnlyConn.Exec(
				"INSERT INTO app_schema.users (username) VALUES ('should-fail')",
			)
			assert.Error(t, err)
			assert.Contains(t, err.Error(), "permission denied")

			_, err = readOnlyConn.Exec("CREATE TABLE app_schema.hack_table (id INT)")
			assert.Error(t, err)
			assert.Contains(t, err.Error(), "permission denied")

			_, err = readOnlyConn.Exec("CREATE TABLE data_schema.hack_table (id INT)")
			assert.Error(t, err)
			assert.Contains(t, err.Error(), "permission denied")

			_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
			if err != nil {
				t.Logf("Warning: Failed to drop owned objects: %v", err)
			}
			_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
			assert.NoError(t, err)

			_, err = container.DB.Exec(`
				DROP SCHEMA IF EXISTS app_schema CASCADE;
				DROP SCHEMA IF EXISTS data_schema CASCADE;
				CREATE SCHEMA IF NOT EXISTS public;
			`)
			assert.NoError(t, err)
		})
	}
}

func Test_CreateReadOnlyUser_PublicSchemaExistsButNoPermissions_ReturnsError(t *testing.T) {
	env := config.GetEnv()
	cases := []struct {
		name    string
		version string
		port    string
	}{
		{"PostgreSQL 12", "12", env.TestPostgres12Port},
		{"PostgreSQL 13", "13", env.TestPostgres13Port},
		{"PostgreSQL 14", "14", env.TestPostgres14Port},
		{"PostgreSQL 15", "15", env.TestPostgres15Port},
		{"PostgreSQL 16", "16", env.TestPostgres16Port},
		{"PostgreSQL 17", "17", env.TestPostgres17Port},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			container := connectToPostgresContainer(t, tc.port)
			defer container.DB.Close()

			limitedAdminUsername := fmt.Sprintf("limited_admin_%s", uuid.New().String()[:8])
			limitedAdminPassword := "limited_password_123"

			_, err := container.DB.Exec(`
				CREATE SCHEMA IF NOT EXISTS public;
				DROP TABLE IF EXISTS public.permission_test_table CASCADE;
				CREATE TABLE public.permission_test_table (
					id SERIAL PRIMARY KEY,
					data TEXT NOT NULL
				);
				INSERT INTO public.permission_test_table (data) VALUES ('test1');
			`)
			assert.NoError(t, err)

			_, err = container.DB.Exec(`GRANT CREATE ON SCHEMA public TO PUBLIC`)
			assert.NoError(t, err)

			_, err = container.DB.Exec(fmt.Sprintf(
				`CREATE USER "%s" WITH PASSWORD '%s' LOGIN CREATEROLE`,
				limitedAdminUsername,
				limitedAdminPassword,
			))
			assert.NoError(t, err)

			_, err = container.DB.Exec(fmt.Sprintf(
				`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
				container.Database,
				limitedAdminUsername,
			))
			assert.NoError(t, err)

			defer func() {
				_, _ = container.DB.Exec(
					fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, limitedAdminUsername),
				)
				_, _ = container.DB.Exec(
					fmt.Sprintf(`DROP USER IF EXISTS "%s"`, limitedAdminUsername),
				)
				_, _ = container.DB.Exec(`REVOKE CREATE ON SCHEMA public FROM PUBLIC`)
			}()

			limitedAdminDSN := fmt.Sprintf(
				"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
				container.Host,
				container.Port,
				limitedAdminUsername,
				limitedAdminPassword,
				container.Database,
			)
			limitedAdminConn, err := sqlx.Connect("postgres", limitedAdminDSN)
			assert.NoError(t, err)
			defer limitedAdminConn.Close()

			pgModel := &PostgresqlDatabase{
				Version:  tools.GetPostgresqlVersionEnum(tc.version),
				Host:     container.Host,
				Port:     container.Port,
				Username: limitedAdminUsername,
				Password: limitedAdminPassword,
				Database: &container.Database,
				IsHttps:  false,
			}

			logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
			ctx := context.Background()

			username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
			assert.Error(
				t,
				err,
				"CreateReadOnlyUser should fail when admin lacks permissions to secure public schema",
			)
			if err != nil {
				errorMsg := err.Error()
				hasExpectedError := strings.Contains(
					errorMsg,
					"failed to revoke CREATE from PUBLIC on existing public schema",
				) ||
					strings.Contains(errorMsg, "permission denied for schema public") ||
					strings.Contains(errorMsg, "failed to grant")
				assert.True(
					t,
					hasExpectedError,
					"Error should indicate permission issues with public schema, got: %s",
					errorMsg,
				)
			}
			assert.Empty(t, username)
			assert.Empty(t, password)
		})
	}
}

func Test_Validate_WhenLocalhostAndDatabasus_ReturnsError(t *testing.T) {
	testCases := []struct {
		name string
@@ -754,6 +1096,42 @@ func Test_Validate_WhenLocalhostAndDatabasus_ReturnsError(t *testing.T) {
			username: "myuser",
			database: "databasus",
		},
		{
			name:     "::1 (IPv6 loopback) with databasus db",
			host:     "::1",
			username: "postgres",
			database: "databasus",
		},
		{
			name:     ":: (IPv6 all interfaces) with databasus db",
			host:     "::",
			username: "postgres",
			database: "databasus",
		},
		{
			name:     "::1 (uppercase) with DATABASUS db",
			host:     "::1",
			username: "POSTGRES",
			database: "DATABASUS",
		},
		{
			name:     "0.0.0.0 (all IPv4 interfaces) with databasus db",
			host:     "0.0.0.0",
			username: "postgres",
			database: "databasus",
		},
		{
			name:     "127.0.0.2 (loopback range) with databasus db",
			host:     "127.0.0.2",
			username: "postgres",
			database: "databasus",
		},
		{
			name:     "127.255.255.255 (end of loopback range) with databasus db",
			host:     "127.255.255.255",
			username: "postgres",
			database: "databasus",
		},
	}

	for _, tc := range testCases {
@@ -941,11 +1319,351 @@ type PostgresContainer struct {
	DB *sqlx.DB
}

func Test_CreateReadOnlyUser_TablesCreatedByDifferentUser_ReadOnlyUserCanRead(t *testing.T) {
	env := config.GetEnv()
	container := connectToPostgresContainer(t, env.TestPostgres16Port)
	defer container.DB.Close()

	// Step 1: Create a second database user who will create tables
	userCreatorUsername := fmt.Sprintf("user_creator_%s", uuid.New().String()[:8])
	userCreatorPassword := "creator_password_123"

	_, err := container.DB.Exec(fmt.Sprintf(
		`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
		userCreatorUsername,
		userCreatorPassword,
	))
	assert.NoError(t, err)

	defer func() {
		_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, userCreatorUsername))
		_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, userCreatorUsername))
	}()

	// Step 2: Grant the user_creator privileges to connect and create tables
	_, err = container.DB.Exec(fmt.Sprintf(
		`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
		container.Database,
		userCreatorUsername,
	))
	assert.NoError(t, err)

	_, err = container.DB.Exec(fmt.Sprintf(
		`GRANT USAGE ON SCHEMA public TO "%s"`,
		userCreatorUsername,
	))
	assert.NoError(t, err)

	_, err = container.DB.Exec(fmt.Sprintf(
		`GRANT CREATE ON SCHEMA public TO "%s"`,
		userCreatorUsername,
	))
	assert.NoError(t, err)

	// Step 2b: Create an initial table by user_creator so they become an object owner
	// This is important because our fix discovers existing object owners
	userCreatorDSN := fmt.Sprintf(
		"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
		container.Host,
		container.Port,
		userCreatorUsername,
		userCreatorPassword,
		container.Database,
	)
	userCreatorConn, err := sqlx.Connect("postgres", userCreatorDSN)
	assert.NoError(t, err)
	defer userCreatorConn.Close()

	initialTableName := fmt.Sprintf(
		"public.initial_table_%s",
		strings.ReplaceAll(uuid.New().String()[:8], "-", ""),
	)
	_, err = userCreatorConn.Exec(fmt.Sprintf(`
		CREATE TABLE %s (
			id SERIAL PRIMARY KEY,
			data TEXT NOT NULL
		);
		INSERT INTO %s (data) VALUES ('initial_data');
	`, initialTableName, initialTableName))
	assert.NoError(t, err)

	defer func() {
		_, _ = container.DB.Exec(fmt.Sprintf(`DROP TABLE IF EXISTS %s CASCADE`, initialTableName))
	}()

	// Step 3: NOW create read-only user via Databasus (as admin)
	// At this point, user_creator already owns objects, so ALTER DEFAULT PRIVILEGES FOR ROLE should apply
	pgModel := createPostgresModel(container)
	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
	ctx := context.Background()

	readonlyUsername, readonlyPassword, err := pgModel.CreateReadOnlyUser(
		ctx,
		logger,
		nil,
		uuid.New(),
	)
	assert.NoError(t, err)
	assert.NotEmpty(t, readonlyUsername)
	assert.NotEmpty(t, readonlyPassword)

	defer func() {
		_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, readonlyUsername))
		_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, readonlyUsername))
	}()

	// Step 4: user_creator creates a NEW table AFTER the read-only user was created
	// This table should automatically grant SELECT to the read-only user via ALTER DEFAULT PRIVILEGES FOR ROLE
	tableName := fmt.Sprintf(
		"public.future_table_%s",
		strings.ReplaceAll(uuid.New().String()[:8], "-", ""),
	)
	_, err = userCreatorConn.Exec(fmt.Sprintf(`
		CREATE TABLE %s (
			id SERIAL PRIMARY KEY,
			data TEXT NOT NULL
		);
		INSERT INTO %s (data) VALUES ('test_data_1'), ('test_data_2');
	`, tableName, tableName))
	assert.NoError(t, err)

	defer func() {
		_, _ = container.DB.Exec(fmt.Sprintf(`DROP TABLE IF EXISTS %s CASCADE`, tableName))
	}()

	// Step 5: Connect as read-only user and verify it can SELECT from the new table
	readonlyDSN := fmt.Sprintf(
		"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
		container.Host,
		container.Port,
		readonlyUsername,
		readonlyPassword,
		container.Database,
	)
	readonlyConn, err := sqlx.Connect("postgres", readonlyDSN)
	assert.NoError(t, err)
	defer readonlyConn.Close()

	var count int
	err = readonlyConn.Get(&count, fmt.Sprintf("SELECT COUNT(*) FROM %s", tableName))
	assert.NoError(t, err)
	assert.Equal(
		t,
		2,
		count,
		"Read-only user should be able to SELECT from table created by different user",
	)

	// Step 6: Verify read-only user cannot write to the table
	_, err = readonlyConn.Exec(
		fmt.Sprintf("INSERT INTO %s (data) VALUES ('should-fail')", tableName),
	)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "permission denied")

	// Step 7: Verify pg_dump operations (LOCK TABLE) work
	// pg_dump needs to lock tables in ACCESS SHARE MODE for consistent backup
	tx, err := readonlyConn.Begin()
	assert.NoError(t, err)
	defer tx.Rollback()

	_, err = tx.Exec(fmt.Sprintf("LOCK TABLE %s IN ACCESS SHARE MODE", tableName))
	assert.NoError(t, err, "Read-only user should be able to LOCK TABLE (needed for pg_dump)")

	err = tx.Commit()
	assert.NoError(t, err)
}

||||
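Note: the mechanism under test is PostgreSQL default privileges scoped to a creator role. A sketch of the grants such a fix would issue per discovered object owner (illustrative SQL wrapped in Go; db is assumed to be an open *sqlx.DB with admin rights, and both role names are placeholders):

// illustrative only; "user_creator" and "readonly_user" are placeholder names
func grantReadOnly(db *sqlx.DB) error {
	grants := []string{
		`GRANT USAGE ON SCHEMA public TO "readonly_user"`,
		`GRANT SELECT ON ALL TABLES IN SCHEMA public TO "readonly_user"`,
		// scoped to the creator role, so tables it creates later are covered too
		`ALTER DEFAULT PRIVILEGES FOR ROLE "user_creator" IN SCHEMA public GRANT SELECT ON TABLES TO "readonly_user"`,
	}
	for _, q := range grants {
		if _, err := db.Exec(q); err != nil {
			return err
		}
	}
	return nil
}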
func Test_CreateReadOnlyUser_WithIncludeSchemas_OnlyGrantsAccessToSpecifiedSchemas(t *testing.T) {
	env := config.GetEnv()
	container := connectToPostgresContainer(t, env.TestPostgres16Port)
	defer container.DB.Close()

	// Step 1: Create multiple schemas and tables
	_, err := container.DB.Exec(`
		DROP SCHEMA IF EXISTS included_schema CASCADE;
		DROP SCHEMA IF EXISTS excluded_schema CASCADE;
		CREATE SCHEMA included_schema;
		CREATE SCHEMA excluded_schema;

		CREATE TABLE public.public_table (id INT, data TEXT);
		INSERT INTO public.public_table VALUES (1, 'public_data');

		CREATE TABLE included_schema.included_table (id INT, data TEXT);
		INSERT INTO included_schema.included_table VALUES (2, 'included_data');

		CREATE TABLE excluded_schema.excluded_table (id INT, data TEXT);
		INSERT INTO excluded_schema.excluded_table VALUES (3, 'excluded_data');
	`)
	assert.NoError(t, err)

	defer func() {
		_, _ = container.DB.Exec(`DROP SCHEMA IF EXISTS included_schema CASCADE`)
		_, _ = container.DB.Exec(`DROP SCHEMA IF EXISTS excluded_schema CASCADE`)
	}()

	// Step 2: Create a second user who owns tables in both included and excluded schemas
	userCreatorUsername := fmt.Sprintf("user_creator_%s", uuid.New().String()[:8])
	userCreatorPassword := "creator_password_123"

	_, err = container.DB.Exec(fmt.Sprintf(
		`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
		userCreatorUsername,
		userCreatorPassword,
	))
	assert.NoError(t, err)

	defer func() {
		_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, userCreatorUsername))
		_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, userCreatorUsername))
	}()

	// Grant privileges to user_creator
	_, err = container.DB.Exec(fmt.Sprintf(
		`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
		container.Database,
		userCreatorUsername,
	))
	assert.NoError(t, err)

	for _, schema := range []string{"public", "included_schema", "excluded_schema"} {
		_, err = container.DB.Exec(fmt.Sprintf(
			`GRANT USAGE, CREATE ON SCHEMA %s TO "%s"`,
			schema,
			userCreatorUsername,
		))
		assert.NoError(t, err)
	}

	// User_creator creates tables in included and excluded schemas
	userCreatorDSN := fmt.Sprintf(
		"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
		container.Host,
		container.Port,
		userCreatorUsername,
		userCreatorPassword,
		container.Database,
	)
	userCreatorConn, err := sqlx.Connect("postgres", userCreatorDSN)
	assert.NoError(t, err)
	defer userCreatorConn.Close()

	_, err = userCreatorConn.Exec(`
		CREATE TABLE included_schema.user_table (id INT, data TEXT);
		INSERT INTO included_schema.user_table VALUES (4, 'user_included_data');

		CREATE TABLE excluded_schema.user_excluded_table (id INT, data TEXT);
		INSERT INTO excluded_schema.user_excluded_table VALUES (5, 'user_excluded_data');
	`)
	assert.NoError(t, err)

	// Step 3: Create read-only user with IncludeSchemas = ["public", "included_schema"]
	pgModel := createPostgresModel(container)
	pgModel.IncludeSchemas = []string{"public", "included_schema"}

	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
	ctx := context.Background()

	readonlyUsername, readonlyPassword, err := pgModel.CreateReadOnlyUser(
		ctx,
		logger,
		nil,
		uuid.New(),
	)
	assert.NoError(t, err)
	assert.NotEmpty(t, readonlyUsername)
	assert.NotEmpty(t, readonlyPassword)

	defer func() {
		_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, readonlyUsername))
		_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, readonlyUsername))
	}()

	// Step 4: Connect as read-only user
	readonlyDSN := fmt.Sprintf(
		"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
		container.Host,
		container.Port,
		readonlyUsername,
		readonlyPassword,
		container.Database,
	)
	readonlyConn, err := sqlx.Connect("postgres", readonlyDSN)
	assert.NoError(t, err)
	defer readonlyConn.Close()

	// Step 5: Verify read-only user CAN access included schemas
	var publicData string
	err = readonlyConn.Get(&publicData, "SELECT data FROM public.public_table LIMIT 1")
	assert.NoError(t, err)
	assert.Equal(t, "public_data", publicData)

	var includedData string
	err = readonlyConn.Get(&includedData, "SELECT data FROM included_schema.included_table LIMIT 1")
	assert.NoError(t, err)
	assert.Equal(t, "included_data", includedData)

	var userIncludedData string
	err = readonlyConn.Get(&userIncludedData, "SELECT data FROM included_schema.user_table LIMIT 1")
	assert.NoError(t, err)
	assert.Equal(t, "user_included_data", userIncludedData)

	// Step 6: Verify read-only user CANNOT access excluded schema
	var excludedData string
	err = readonlyConn.Get(&excludedData, "SELECT data FROM excluded_schema.excluded_table LIMIT 1")
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "permission denied")

	err = readonlyConn.Get(
		&excludedData,
		"SELECT data FROM excluded_schema.user_excluded_table LIMIT 1",
	)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "permission denied")

	// Step 7: Verify future tables in included schemas are accessible
	_, err = userCreatorConn.Exec(`
		CREATE TABLE included_schema.future_table (id INT, data TEXT);
		INSERT INTO included_schema.future_table VALUES (6, 'future_data');
	`)
	assert.NoError(t, err)

	var futureData string
	err = readonlyConn.Get(&futureData, "SELECT data FROM included_schema.future_table LIMIT 1")
	assert.NoError(t, err)
	assert.Equal(
		t,
		"future_data",
		futureData,
		"Read-only user should access future tables in included schemas via ALTER DEFAULT PRIVILEGES FOR ROLE",
	)

	// Step 8: Verify future tables in excluded schema are NOT accessible
	_, err = userCreatorConn.Exec(`
		CREATE TABLE excluded_schema.future_excluded_table (id INT, data TEXT);
		INSERT INTO excluded_schema.future_excluded_table VALUES (7, 'future_excluded_data');
	`)
	assert.NoError(t, err)

	var futureExcludedData string
	err = readonlyConn.Get(
		&futureExcludedData,
		"SELECT data FROM excluded_schema.future_excluded_table LIMIT 1",
	)
	assert.Error(t, err)
	assert.Contains(
		t,
		err.Error(),
		"permission denied",
		"Read-only user should NOT access tables in excluded schemas",
	)
}

func connectToPostgresContainer(t *testing.T, port string) *PostgresContainer {
	dbName := "testdb"
	password := "testpassword"
	username := "testuser"
-	host := "localhost"
+	host := config.GetEnv().TestLocalhost

	portInt, err := strconv.Atoi(port)
	assert.NoError(t, err)
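Note: routing the test host through config.GetEnv().TestLocalhost lets containerized CI runners point the tests at a non-local address. A sketch of the fallback such a setting typically carries (the TEST_LOCALHOST variable name is an assumption, not taken from the repository):

package config

import "os"

// hypothetical helper; TEST_LOCALHOST is an assumed variable name
func testLocalhost() string {
	if host := os.Getenv("TEST_LOCALHOST"); host != "" {
		return host
	}
	return "localhost"
}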
@@ -1,6 +1,9 @@
 package databases

 import (
+	"sync"
+	"sync/atomic"
+
 	audit_logs "databasus-backend/internal/features/audit_logs"
 	"databasus-backend/internal/features/notifiers"
 	users_services "databasus-backend/internal/features/users/services"
@@ -37,7 +40,22 @@ func GetDatabaseController() *DatabaseController {
	return databaseController
}

+var (
+	setupOnce sync.Once
+	isSetup   atomic.Bool
+)
+
func SetupDependencies() {
-	workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService)
-	notifiers.GetNotifierService().SetNotifierDatabaseCounter(databaseService)
+	wasAlreadySetup := isSetup.Load()
+
+	setupOnce.Do(func() {
+		workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService)
+		notifiers.GetNotifierService().SetNotifierDatabaseCounter(databaseService)
+
+		isSetup.Store(true)
+	})
+
+	if wasAlreadySetup {
+		logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
+	}
}
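Note: the same Once-plus-flag idiom recurs in several packages in this diff. A compact, self-contained sketch of the pattern (generic names, not tied to any package here):

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

var (
	setupOnce sync.Once
	isSetup   atomic.Bool
)

func setup() {
	wasAlreadySetup := isSetup.Load()
	setupOnce.Do(func() {
		fmt.Println("wiring dependencies")
		isSetup.Store(true)
	})
	if wasAlreadySetup {
		fmt.Println("setup called multiple times, ignoring subsequent call")
	}
}

func main() {
	setup() // wires once
	setup() // logs the warning, wires nothing
}

One caveat of reading the flag before Do: a second caller that races the first may skip the warning, though the wiring itself still runs exactly once.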
@@ -1,6 +1,7 @@
 package databases

 import (
+	"context"
 	"databasus-backend/internal/features/databases/databases/mariadb"
 	"databasus-backend/internal/features/databases/databases/mongodb"
 	"databasus-backend/internal/features/databases/databases/mysql"
@@ -84,6 +85,25 @@ func (d *Database) TestConnection(
	return d.getSpecificDatabase().TestConnection(logger, encryptor, d.ID)
}

+func (d *Database) IsUserReadOnly(
+	ctx context.Context,
+	logger *slog.Logger,
+	encryptor encryption.FieldEncryptor,
+) (bool, []string, error) {
+	switch d.Type {
+	case DatabaseTypePostgres:
+		return d.Postgresql.IsUserReadOnly(ctx, logger, encryptor, d.ID)
+	case DatabaseTypeMysql:
+		return d.Mysql.IsUserReadOnly(ctx, logger, encryptor, d.ID)
+	case DatabaseTypeMariadb:
+		return d.Mariadb.IsUserReadOnly(ctx, logger, encryptor, d.ID)
+	case DatabaseTypeMongodb:
+		return d.Mongodb.IsUserReadOnly(ctx, logger, encryptor, d.ID)
+	default:
+		return false, nil, errors.New("read-only check not supported for this database type")
+	}
+}
+
func (d *Database) HideSensitiveData() {
	d.getSpecificDatabase().HideSensitiveData()
}
@@ -7,6 +7,7 @@ import (
 	"log/slog"
 	"time"

+	"databasus-backend/internal/config"
 	audit_logs "databasus-backend/internal/features/audit_logs"
 	"databasus-backend/internal/features/databases/databases/mariadb"
 	"databasus-backend/internal/features/databases/databases/mongodb"
@@ -86,6 +87,23 @@ func (s *DatabaseService) CreateDatabase(
		return nil, fmt.Errorf("failed to auto-detect database data: %w", err)
	}

+	if config.GetEnv().IsCloud {
+		ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
+		defer cancel()
+
+		isReadOnly, permissions, err := database.IsUserReadOnly(ctx, s.logger, s.fieldEncryptor)
+		if err != nil {
+			return nil, fmt.Errorf("failed to verify user permissions: %w", err)
+		}
+
+		if !isReadOnly {
+			return nil, fmt.Errorf(
+				"in cloud mode, only read-only database users are allowed (user has permissions: %v)",
+				permissions,
+			)
+		}
+	}
+
	if err := database.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
		return nil, fmt.Errorf("failed to encrypt sensitive fields: %w", err)
	}
@@ -153,6 +171,29 @@ func (s *DatabaseService) UpdateDatabase(
		return fmt.Errorf("failed to auto-detect database data: %w", err)
	}

+	if config.GetEnv().IsCloud {
+		ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
+		defer cancel()
+
+		isReadOnly, permissions, err := existingDatabase.IsUserReadOnly(
+			ctx,
+			s.logger,
+			s.fieldEncryptor,
+		)
+		if err != nil {
+			return fmt.Errorf("failed to verify user permissions: %w", err)
+		}
+
+		if !isReadOnly {
+			return fmt.Errorf(
+				"in cloud mode, only read-only database users are allowed (user has permissions: %v)",
+				permissions,
+			)
+		}
+	}
+
+	oldName := existingDatabase.Name
+
	if err := existingDatabase.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
		return fmt.Errorf("failed to encrypt sensitive fields: %w", err)
	}
@@ -162,11 +203,23 @@ func (s *DatabaseService) UpdateDatabase(
		return err
	}

-	s.auditLogService.WriteAuditLog(
-		fmt.Sprintf("Database updated: %s", existingDatabase.Name),
-		&user.ID,
-		existingDatabase.WorkspaceID,
-	)
+	if oldName != existingDatabase.Name {
+		s.auditLogService.WriteAuditLog(
+			fmt.Sprintf(
+				"Database updated and renamed from '%s' to '%s'",
+				oldName,
+				existingDatabase.Name,
+			),
+			&user.ID,
+			existingDatabase.WorkspaceID,
+		)
+	} else {
+		s.auditLogService.WriteAuditLog(
+			fmt.Sprintf("Database updated: %s", existingDatabase.Name),
+			&user.ID,
+			existingDatabase.WorkspaceID,
+		)
+	}

	return nil
}
@@ -532,9 +585,19 @@ func (s *DatabaseService) TransferDatabaseToWorkspace(
		return err
	}

+	sourceWorkspace, err := s.workspaceService.GetWorkspaceByID(*sourceWorkspaceID)
+	if err != nil {
+		return fmt.Errorf("failed to get source workspace: %w", err)
+	}
+
+	targetWorkspace, err := s.workspaceService.GetWorkspaceByID(targetWorkspaceID)
+	if err != nil {
+		return fmt.Errorf("failed to get target workspace: %w", err)
+	}
+
	s.auditLogService.WriteAuditLog(
-		fmt.Sprintf("Database transferred: %s from workspace %s to workspace %s",
-			database.Name, sourceWorkspaceID, targetWorkspaceID),
+		fmt.Sprintf("Database transferred: %s from workspace '%s' to workspace '%s'",
+			database.Name, sourceWorkspace.Name, targetWorkspace.Name),
		nil,
		&targetWorkspaceID,
	)
@@ -649,38 +712,7 @@ func (s *DatabaseService) IsUserReadOnly(
	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
	defer cancel()

-	switch usingDatabase.Type {
-	case DatabaseTypePostgres:
-		return usingDatabase.Postgresql.IsUserReadOnly(
-			ctx,
-			s.logger,
-			s.fieldEncryptor,
-			usingDatabase.ID,
-		)
-	case DatabaseTypeMysql:
-		return usingDatabase.Mysql.IsUserReadOnly(
-			ctx,
-			s.logger,
-			s.fieldEncryptor,
-			usingDatabase.ID,
-		)
-	case DatabaseTypeMariadb:
-		return usingDatabase.Mariadb.IsUserReadOnly(
-			ctx,
-			s.logger,
-			s.fieldEncryptor,
-			usingDatabase.ID,
-		)
-	case DatabaseTypeMongodb:
-		return usingDatabase.Mongodb.IsUserReadOnly(
-			ctx,
-			s.logger,
-			s.fieldEncryptor,
-			usingDatabase.ID,
-		)
-	default:
-		return false, nil, errors.New("read-only check not supported for this database type")
-	}
+	return usingDatabase.IsUserReadOnly(ctx, s.logger, s.fieldEncryptor)
}

func (s *DatabaseService) CreateReadOnlyUser(
@@ -10,6 +10,7 @@ import (
 	"databasus-backend/internal/features/databases/databases/postgresql"
 	"databasus-backend/internal/features/notifiers"
 	"databasus-backend/internal/features/storages"
+	"databasus-backend/internal/storage"
 	"databasus-backend/internal/util/tools"

 	"github.com/google/uuid"
@@ -25,7 +26,7 @@ func GetTestPostgresConfig() *postgresql.PostgresqlDatabase {
	testDbName := "testdb"
	return &postgresql.PostgresqlDatabase{
		Version:  tools.PostgresqlVersion16,
-		Host:     "localhost",
+		Host:     config.GetEnv().TestLocalhost,
		Port:     port,
		Username: "testuser",
		Password: "testpassword",
@@ -48,7 +49,7 @@ func GetTestMariadbConfig() *mariadb.MariadbDatabase {
	testDbName := "testdb"
	return &mariadb.MariadbDatabase{
		Version:  tools.MariadbVersion1011,
-		Host:     "localhost",
+		Host:     config.GetEnv().TestLocalhost,
		Port:     port,
		Username: "testuser",
		Password: "testpassword",
@@ -69,13 +70,14 @@ func GetTestMongodbConfig() *mongodb.MongodbDatabase {

	return &mongodb.MongodbDatabase{
		Version:      tools.MongodbVersion7,
-		Host:         "localhost",
-		Port:         port,
+		Host:         config.GetEnv().TestLocalhost,
+		Port:         &port,
		Username:     "root",
		Password:     "rootpassword",
		Database:     "testdb",
		AuthDatabase: "admin",
		IsHttps:      false,
		IsSrv:        false,
		CpuCount:     1,
	}
}
@@ -104,6 +106,19 @@ func CreateTestDatabase(
}

func RemoveTestDatabase(database *Database) {
+	// Delete backups and backup configs associated with this database
+	// We hardcode SQL here because we cannot call backups feature due to DI inversion
+	// (databases package cannot import backups package as backups already imports databases)
+	db := storage.GetDb()
+
+	if err := db.Exec("DELETE FROM backups WHERE database_id = ?", database.ID).Error; err != nil {
+		panic(fmt.Sprintf("failed to delete backups: %v", err))
+	}
+
+	if err := db.Exec("DELETE FROM backup_configs WHERE database_id = ?", database.ID).Error; err != nil {
+		panic(fmt.Sprintf("failed to delete backup config: %v", err))
+	}
+
	err := databaseRepository.Delete(database.ID)
	if err != nil {
		panic(err)
@@ -12,6 +12,15 @@ import (
type DiskService struct{}

func (s *DiskService) GetDiskUsage() (*DiskUsage, error) {
+	if config.GetEnv().IsCloud {
+		return &DiskUsage{
+			Platform:        PlatformLinux,
+			TotalSpaceBytes: 100,
+			UsedSpaceBytes:  0,
+			FreeSpaceBytes:  100,
+		}, nil
+	}
+
	platform := s.detectPlatform()

	var path string
backend/internal/features/email/di.go (new file, 22 lines)
@@ -0,0 +1,22 @@
package email

import (
	"databasus-backend/internal/config"
	"databasus-backend/internal/util/logger"
)

var env = config.GetEnv()
var log = logger.GetLogger()

var emailSMTPSender = &EmailSMTPSender{
	log,
	env.SMTPHost,
	env.SMTPPort,
	env.SMTPUser,
	env.SMTPPassword,
	env.SMTPHost != "" && env.SMTPPort != 0,
}

func GetEmailSMTPSender() *EmailSMTPSender {
	return emailSMTPSender
}
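Note: the last positional field derives isConfigured from the presence of host and port, so sending degrades to a logged no-op when SMTP is not set up. A sketch of sending through the singleton (recipient and body are arbitrary examples):

sender := GetEmailSMTPSender()
// skips with a warning when the SMTP env vars are absent
if err := sender.SendEmail("ops@example.com", "Backup finished", "<p>done</p>"); err != nil {
	log.Error("send failed", "error", err)
}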
backend/internal/features/email/email.go (new file, 245 lines)
@@ -0,0 +1,245 @@
package email

import (
	"crypto/tls"
	"fmt"
	"log/slog"
	"mime"
	"net"
	"net/smtp"
	"time"
)

const (
	ImplicitTLSPort  = 465
	DefaultTimeout   = 5 * time.Second
	DefaultHelloName = "localhost"
	MIMETypeHTML     = "text/html"
	MIMECharsetUTF8  = "UTF-8"
)

type EmailSMTPSender struct {
	logger       *slog.Logger
	smtpHost     string
	smtpPort     int
	smtpUser     string
	smtpPassword string
	isConfigured bool
}

func (s *EmailSMTPSender) SendEmail(to, subject, body string) error {
	if !s.isConfigured {
		s.logger.Warn("Skipping email send, SMTP not initialized", "to", to, "subject", subject)
		return nil
	}

	from := s.smtpUser
	if from == "" {
		from = "noreply@" + s.smtpHost
	}

	emailContent := s.buildEmailContent(to, subject, body, from)
	isAuthRequired := s.smtpUser != "" && s.smtpPassword != ""

	if s.smtpPort == ImplicitTLSPort {
		return s.sendImplicitTLS(to, from, emailContent, isAuthRequired)
	}

	return s.sendStartTLS(to, from, emailContent, isAuthRequired)
}

func (s *EmailSMTPSender) buildEmailContent(to, subject, body, from string) []byte {
	// Encode Subject header using RFC 2047 to avoid SMTPUTF8 requirement
	encodedSubject := encodeRFC2047(subject)
	subjectHeader := fmt.Sprintf("Subject: %s\r\n", encodedSubject)
	dateHeader := fmt.Sprintf("Date: %s\r\n", time.Now().UTC().Format(time.RFC1123Z))

	mimeHeaders := fmt.Sprintf(
		"MIME-version: 1.0;\nContent-Type: %s; charset=\"%s\";\n\n",
		MIMETypeHTML,
		MIMECharsetUTF8,
	)

	// Encode From header display name if it contains non-ASCII
	encodedFrom := encodeRFC2047(from)
	fromHeader := fmt.Sprintf("From: %s\r\n", encodedFrom)

	toHeader := fmt.Sprintf("To: %s\r\n", to)

	return []byte(fromHeader + toHeader + subjectHeader + dateHeader + mimeHeaders + body)
}

func (s *EmailSMTPSender) sendImplicitTLS(
	to, from string,
	emailContent []byte,
	isAuthRequired bool,
) error {
	createClient := func() (*smtp.Client, func(), error) {
		return s.createImplicitTLSClient()
	}

	client, cleanup, err := s.authenticateWithRetry(createClient, isAuthRequired)
	if err != nil {
		return err
	}
	defer cleanup()

	return s.sendEmail(client, to, from, emailContent)
}

func (s *EmailSMTPSender) sendStartTLS(
	to, from string,
	emailContent []byte,
	isAuthRequired bool,
) error {
	createClient := func() (*smtp.Client, func(), error) {
		return s.createStartTLSClient()
	}

	client, cleanup, err := s.authenticateWithRetry(createClient, isAuthRequired)
	if err != nil {
		return err
	}
	defer cleanup()

	return s.sendEmail(client, to, from, emailContent)
}

func (s *EmailSMTPSender) createImplicitTLSClient() (*smtp.Client, func(), error) {
	addr := net.JoinHostPort(s.smtpHost, fmt.Sprintf("%d", s.smtpPort))
	tlsConfig := &tls.Config{ServerName: s.smtpHost}
	dialer := &net.Dialer{Timeout: DefaultTimeout}

	conn, err := tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
	}

	client, err := smtp.NewClient(conn, s.smtpHost)
	if err != nil {
		_ = conn.Close()
		return nil, nil, fmt.Errorf("failed to create SMTP client: %w", err)
	}

	return client, func() { _ = client.Quit() }, nil
}

func (s *EmailSMTPSender) createStartTLSClient() (*smtp.Client, func(), error) {
	addr := net.JoinHostPort(s.smtpHost, fmt.Sprintf("%d", s.smtpPort))
	dialer := &net.Dialer{Timeout: DefaultTimeout}

	conn, err := dialer.Dial("tcp", addr)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
	}

	client, err := smtp.NewClient(conn, s.smtpHost)
	if err != nil {
		_ = conn.Close()
		return nil, nil, fmt.Errorf("failed to create SMTP client: %w", err)
	}

	if err := client.Hello(DefaultHelloName); err != nil {
		_ = client.Quit()
		_ = conn.Close()
		return nil, nil, fmt.Errorf("SMTP hello failed: %w", err)
	}

	if ok, _ := client.Extension("STARTTLS"); ok {
		if err := client.StartTLS(&tls.Config{ServerName: s.smtpHost}); err != nil {
			_ = client.Quit()
			_ = conn.Close()
			return nil, nil, fmt.Errorf("STARTTLS failed: %w", err)
		}
	}

	return client, func() { _ = client.Quit() }, nil
}

func (s *EmailSMTPSender) authenticateWithRetry(
	createClient func() (*smtp.Client, func(), error),
	isAuthRequired bool,
) (*smtp.Client, func(), error) {
	client, cleanup, err := createClient()
	if err != nil {
		return nil, nil, err
	}

	if !isAuthRequired {
		return client, cleanup, nil
	}

	// Try PLAIN auth first
	plainAuth := smtp.PlainAuth("", s.smtpUser, s.smtpPassword, s.smtpHost)
	if err := client.Auth(plainAuth); err == nil {
		return client, cleanup, nil
	}

	// PLAIN auth failed, connection may be closed - recreate and try LOGIN auth
	cleanup()

	client, cleanup, err = createClient()
	if err != nil {
		return nil, nil, err
	}

	loginAuth := &loginAuth{username: s.smtpUser, password: s.smtpPassword}
	if err := client.Auth(loginAuth); err != nil {
		cleanup()
		return nil, nil, fmt.Errorf("SMTP authentication failed: %w", err)
	}

	return client, cleanup, nil
}

func (s *EmailSMTPSender) sendEmail(client *smtp.Client, to, from string, content []byte) error {
	if err := client.Mail(from); err != nil {
		return fmt.Errorf("failed to set sender: %w", err)
	}

	if err := client.Rcpt(to); err != nil {
		return fmt.Errorf("failed to set recipient: %w", err)
	}

	writer, err := client.Data()
	if err != nil {
		return fmt.Errorf("failed to get data writer: %w", err)
	}

	if _, err = writer.Write(content); err != nil {
		return fmt.Errorf("failed to write email content: %w", err)
	}

	if err = writer.Close(); err != nil {
		return fmt.Errorf("failed to close data writer: %w", err)
	}

	return nil
}

func encodeRFC2047(s string) string {
	return mime.QEncoding.Encode("UTF-8", s)
}

type loginAuth struct {
	username string
	password string
}

func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) {
	return "LOGIN", []byte{}, nil
}

func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
	if more {
		switch string(fromServer) {
		case "Username:", "User Name\x00":
			return []byte(a.username), nil
		case "Password:", "Password\x00":
			return []byte(a.password), nil
		default:
			return []byte(a.username), nil
		}
	}
	return nil, nil
}
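Note: mime.QEncoding produces the =?UTF-8?Q?...?= form that servers without SMTPUTF8 accept, and passes plain ASCII through untouched. A runnable illustration (the encoded output shown in the comment is approximate):

package main

import (
	"fmt"
	"mime"
)

func main() {
	// ASCII-only input is returned unchanged
	fmt.Println(mime.QEncoding.Encode("UTF-8", "Backup finished"))
	// non-ASCII input is Q-encoded, roughly =?UTF-8?q?Sicherung_fertig_=E2=9C=85?=
	fmt.Println(mime.QEncoding.Encode("UTF-8", "Sicherung fertig ✅"))
}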
@@ -2,30 +2,47 @@ package healthcheck_attempt

 import (
 	"context"
-	healthcheck_config "databasus-backend/internal/features/healthcheck/config"
+	"fmt"
 	"log/slog"
+	"sync"
+	"sync/atomic"
 	"time"
+
+	healthcheck_config "databasus-backend/internal/features/healthcheck/config"
 )

 type HealthcheckAttemptBackgroundService struct {
 	healthcheckConfigService   *healthcheck_config.HealthcheckConfigService
 	checkDatabaseHealthUseCase *CheckDatabaseHealthUseCase
 	logger                     *slog.Logger
+
+	runOnce sync.Once
+	hasRun  atomic.Bool
 }

 func (s *HealthcheckAttemptBackgroundService) Run(ctx context.Context) {
-	// first healthcheck immediately
-	s.checkDatabases()
+	wasAlreadyRun := s.hasRun.Load()

-	ticker := time.NewTicker(time.Minute)
-	defer ticker.Stop()
-	for {
-		select {
-		case <-ctx.Done():
-			return
-		case <-ticker.C:
-			s.checkDatabases()
+	s.runOnce.Do(func() {
+		s.hasRun.Store(true)
+
+		// first healthcheck immediately
+		s.checkDatabases()
+
+		ticker := time.NewTicker(time.Minute)
+		defer ticker.Stop()
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case <-ticker.C:
+				s.checkDatabases()
+			}
 		}
-	}
+	})
+
+	if wasAlreadyRun {
+		panic(fmt.Sprintf("%T.Run() called multiple times", s))
+	}
 }
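Note: Run is now single-use by construction: the first caller owns the ticker loop, and any later caller panics. A sketch of the intended call site at startup (the accessor name is an assumption):

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

svc := GetHealthcheckAttemptBackgroundService() // assumed accessor
go svc.Run(ctx) // first call: checks immediately, then every minute
// calling svc.Run(ctx) again would panic by design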
@@ -144,6 +144,10 @@ func Test_GetAttemptsByDatabase_PermissionsEnforced(t *testing.T) {
			)
			assert.Contains(t, string(testResp.Body), "forbidden")
		}
+
+			// Cleanup
+			databases.RemoveTestDatabase(database)
+			workspaces_testing.RemoveTestWorkspace(workspace, router)
		})
	}
}
@@ -181,6 +185,10 @@ func Test_GetAttemptsByDatabase_FiltersByAfterDate(t *testing.T) {
	for _, attempt := range response {
		assert.True(t, attempt.CreatedAt.After(afterDate) || attempt.CreatedAt.Equal(afterDate))
	}
+
+	// Cleanup
+	databases.RemoveTestDatabase(database)
+	workspaces_testing.RemoveTestWorkspace(workspace, router)
}

func Test_GetAttemptsByDatabase_ReturnsEmptyListForNewDatabase(t *testing.T) {
@@ -201,6 +209,10 @@ func Test_GetAttemptsByDatabase_ReturnsEmptyListForNewDatabase(t *testing.T) {
	)

	assert.Equal(t, 0, len(response))
+
+	// Cleanup
+	databases.RemoveTestDatabase(database)
+	workspaces_testing.RemoveTestWorkspace(workspace, router)
}

func createTestDatabaseViaAPI(
@@ -1,6 +1,9 @@
 package healthcheck_attempt

 import (
+	"sync"
+	"sync/atomic"
+
 	"databasus-backend/internal/features/databases"
 	healthcheck_config "databasus-backend/internal/features/healthcheck/config"
 	"databasus-backend/internal/features/notifiers"
@@ -22,9 +25,11 @@ var checkDatabaseHealthUseCase = &CheckDatabaseHealthUseCase{
}

var healthcheckAttemptBackgroundService = &HealthcheckAttemptBackgroundService{
-	healthcheck_config.GetHealthcheckConfigService(),
-	checkDatabaseHealthUseCase,
-	logger.GetLogger(),
+	healthcheckConfigService:   healthcheck_config.GetHealthcheckConfigService(),
+	checkDatabaseHealthUseCase: checkDatabaseHealthUseCase,
+	logger:                     logger.GetLogger(),
+	runOnce:                    sync.Once{},
+	hasRun:                     atomic.Bool{},
}
var healthcheckAttemptController = &HealthcheckAttemptController{
	healthcheckAttemptService,
@@ -130,6 +130,10 @@ func Test_SaveHealthcheckConfig_PermissionsEnforced(t *testing.T) {
			)
			assert.Contains(t, string(testResp.Body), "insufficient permissions")
		}
+
+			// Cleanup
+			databases.RemoveTestDatabase(database)
+			workspaces_testing.RemoveTestWorkspace(workspace, router)
		})
	}
}
@@ -162,6 +166,10 @@ func Test_SaveHealthcheckConfig_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t
	)

	assert.Contains(t, string(testResp.Body), "insufficient permissions")
+
+	// Cleanup
+	databases.RemoveTestDatabase(database)
+	workspaces_testing.RemoveTestWorkspace(workspace, router)
}

func Test_GetHealthcheckConfig_PermissionsEnforced(t *testing.T) {
@@ -268,6 +276,10 @@ func Test_GetHealthcheckConfig_PermissionsEnforced(t *testing.T) {
			)
			assert.Contains(t, string(testResp.Body), "insufficient permissions")
		}
+
+			// Cleanup
+			databases.RemoveTestDatabase(database)
+			workspaces_testing.RemoveTestWorkspace(workspace, router)
		})
	}
}
@@ -295,6 +307,10 @@ func Test_GetHealthcheckConfig_ReturnsDefaultConfigForNewDatabase(t *testing.T)
	assert.Equal(t, 1, response.IntervalMinutes)
	assert.Equal(t, 3, response.AttemptsBeforeConcideredAsDown)
	assert.Equal(t, 7, response.StoreAttemptsDays)
+
+	// Cleanup
+	databases.RemoveTestDatabase(database)
+	workspaces_testing.RemoveTestWorkspace(workspace, router)
}

func createTestDatabaseViaAPI(
@@ -1,6 +1,9 @@
 package healthcheck_config

 import (
+	"sync"
+	"sync/atomic"
+
 	"databasus-backend/internal/features/audit_logs"
 	"databasus-backend/internal/features/databases"
 	workspaces_services "databasus-backend/internal/features/workspaces/services"
@@ -27,8 +30,23 @@ func GetHealthcheckConfigController() *HealthcheckConfigController {
	return healthcheckConfigController
}

+var (
+	setupOnce sync.Once
+	isSetup   atomic.Bool
+)
+
func SetupDependencies() {
-	databases.
-		GetDatabaseService().
-		AddDbCreationListener(healthcheckConfigService)
+	wasAlreadySetup := isSetup.Load()
+
+	setupOnce.Do(func() {
+		databases.
+			GetDatabaseService().
+			AddDbCreationListener(healthcheckConfigService)
+
+		isSetup.Store(true)
+	})
+
+	if wasAlreadySetup {
+		logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
+	}
}
@@ -1,6 +1,9 @@
 package notifiers

 import (
+	"sync"
+	"sync/atomic"
+
 	audit_logs "databasus-backend/internal/features/audit_logs"
 	workspaces_services "databasus-backend/internal/features/workspaces/services"
 	"databasus-backend/internal/util/encryption"
@@ -32,6 +35,22 @@ func GetNotifierService() *NotifierService {
func GetNotifierRepository() *NotifierRepository {
	return notifierRepository
}

+var (
+	setupOnce sync.Once
+	isSetup   atomic.Bool
+)
+
func SetupDependencies() {
-	workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService)
+	wasAlreadySetup := isSetup.Load()
+
+	setupOnce.Do(func() {
+		workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService)
+
+		isSetup.Store(true)
+	})
+
+	if wasAlreadySetup {
+		logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
+	}
}
@@ -6,6 +6,7 @@ import (
 	"errors"
 	"fmt"
 	"log/slog"
+	"mime"
 	"net"
 	"net/smtp"
 	"time"
@@ -115,16 +116,35 @@ func (e *EmailNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor
	return nil
}

+// encodeRFC2047 encodes a string using RFC 2047 MIME encoding for email headers
+// This ensures compatibility with SMTP servers that don't support SMTPUTF8
+func encodeRFC2047(s string) string {
+	// mime.QEncoding handles UTF-8 → =?UTF-8?Q?...?= encoding
+	// This allows non-ASCII characters (emojis, accents, etc.) in email headers
+	// while maintaining compatibility with all SMTP servers
+	return mime.QEncoding.Encode("UTF-8", s)
+}
+
func (e *EmailNotifier) buildEmailContent(heading, message, from string) []byte {
-	subject := fmt.Sprintf("Subject: %s\r\n", heading)
-	mime := fmt.Sprintf(
+	// Encode Subject header using RFC 2047 to avoid SMTPUTF8 requirement
+	// This ensures compatibility with SMTP servers that don't support SMTPUTF8
+	encodedSubject := encodeRFC2047(heading)
+	subject := fmt.Sprintf("Subject: %s\r\n", encodedSubject)
+	dateHeader := fmt.Sprintf("Date: %s\r\n", time.Now().UTC().Format(time.RFC1123Z))
+
+	mimeHeaders := fmt.Sprintf(
		"MIME-version: 1.0;\nContent-Type: %s; charset=\"%s\";\n\n",
		MIMETypeHTML,
		MIMECharsetUTF8,
	)
-	fromHeader := fmt.Sprintf("From: %s\r\n", from)
+
+	// Encode From header display name if it contains non-ASCII
+	encodedFrom := encodeRFC2047(from)
+	fromHeader := fmt.Sprintf("From: %s\r\n", encodedFrom)
+
	toHeader := fmt.Sprintf("To: %s\r\n", e.TargetEmail)
-	return []byte(fromHeader + toHeader + subject + mime + message)
+
+	return []byte(fromHeader + toHeader + subject + dateHeader + mimeHeaders + message)
}

func (e *EmailNotifier) sendImplicitTLS(
@@ -58,6 +58,8 @@ func (s *NotifierService) SaveNotifier(
			return err
		}

+		oldName := existingNotifier.Name
+
		if err := existingNotifier.Validate(s.fieldEncryptor); err != nil {
			return err
		}
@@ -67,11 +69,23 @@ func (s *NotifierService) SaveNotifier(
			return err
		}

-		s.auditLogService.WriteAuditLog(
-			fmt.Sprintf("Notifier updated: %s", existingNotifier.Name),
-			&user.ID,
-			&workspaceID,
-		)
+		if oldName != existingNotifier.Name {
+			s.auditLogService.WriteAuditLog(
+				fmt.Sprintf(
+					"Notifier updated and renamed from '%s' to '%s'",
+					oldName,
+					existingNotifier.Name,
+				),
+				&user.ID,
+				&workspaceID,
+			)
+		} else {
+			s.auditLogService.WriteAuditLog(
+				fmt.Sprintf("Notifier updated: %s", existingNotifier.Name),
+				&user.ID,
+				&workspaceID,
+			)
+		}
	} else {
		notifier.WorkspaceID = workspaceID
@@ -343,9 +357,19 @@ func (s *NotifierService) TransferNotifierToWorkspace(
		return err
	}

+	sourceWorkspace, err := s.workspaceService.GetWorkspaceByID(sourceWorkspaceID)
+	if err != nil {
+		return fmt.Errorf("failed to get source workspace: %w", err)
+	}
+
+	targetWorkspace, err := s.workspaceService.GetWorkspaceByID(targetWorkspaceID)
+	if err != nil {
+		return fmt.Errorf("failed to get target workspace: %w", err)
+	}
+
	s.auditLogService.WriteAuditLog(
-		fmt.Sprintf("Notifier transferred: %s from workspace %s to workspace %s",
-			existingNotifier.Name, sourceWorkspaceID, targetWorkspaceID),
+		fmt.Sprintf("Notifier transferred: %s from workspace '%s' to workspace '%s'",
+			existingNotifier.Name, sourceWorkspace.Name, targetWorkspace.Name),
		&user.ID,
		&targetWorkspaceID,
	)
backend/internal/features/plan/di.go (new file, 20 lines)
@@ -0,0 +1,20 @@
package plans

import (
	"databasus-backend/internal/util/logger"
)

var databasePlanRepository = &DatabasePlanRepository{}

var databasePlanService = &DatabasePlanService{
	databasePlanRepository,
	logger.GetLogger(),
}

func GetDatabasePlanService() *DatabasePlanService {
	return databasePlanService
}

func GetDatabasePlanRepository() *DatabasePlanRepository {
	return databasePlanRepository
}
backend/internal/features/plan/model.go (new file, 19 lines)
@@ -0,0 +1,19 @@
package plans

import (
	"databasus-backend/internal/util/period"

	"github.com/google/uuid"
)

type DatabasePlan struct {
	DatabaseID uuid.UUID `json:"databaseId" gorm:"column:database_id;type:uuid;primaryKey;not null"`

	MaxBackupSizeMB       int64             `json:"maxBackupSizeMb" gorm:"column:max_backup_size_mb;type:int;not null"`
	MaxBackupsTotalSizeMB int64             `json:"maxBackupsTotalSizeMb" gorm:"column:max_backups_total_size_mb;type:int;not null"`
	MaxStoragePeriod      period.TimePeriod `json:"maxStoragePeriod" gorm:"column:max_storage_period;type:text;not null"`
}

func (p *DatabasePlan) TableName() string {
	return "database_plans"
}
backend/internal/features/plan/repository.go (new file, 27 lines)
@@ -0,0 +1,27 @@
package plans

import (
	"databasus-backend/internal/storage"

	"github.com/google/uuid"
)

type DatabasePlanRepository struct{}

func (r *DatabasePlanRepository) GetDatabasePlan(databaseID uuid.UUID) (*DatabasePlan, error) {
	var databasePlan DatabasePlan

	if err := storage.GetDb().Where("database_id = ?", databaseID).First(&databasePlan).Error; err != nil {
		if err.Error() == "record not found" {
			return nil, nil
		}

		return nil, err
	}

	return &databasePlan, nil
}

func (r *DatabasePlanRepository) CreateDatabasePlan(databasePlan *DatabasePlan) error {
	return storage.GetDb().Create(&databasePlan).Error
}
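Note: GetDatabasePlan maps a missing row to (nil, nil) by matching the error string. gorm exposes a sentinel for exactly this; a sketch of the equivalent errors.Is form (assuming gorm.io/gorm, with the same behaviour):

// sketch: the same lookup using gorm's sentinel instead of string matching
var plan DatabasePlan
err := storage.GetDb().Where("database_id = ?", databaseID).First(&plan).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
	return nil, nil
}
if err != nil {
	return nil, err
}
return &plan, nil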
backend/internal/features/plan/service.go (new file, 67 lines)
@@ -0,0 +1,67 @@
package plans

import (
	"databasus-backend/internal/config"
	"databasus-backend/internal/util/period"
	"log/slog"

	"github.com/google/uuid"
)

type DatabasePlanService struct {
	databasePlanRepository *DatabasePlanRepository
	logger                 *slog.Logger
}

func (s *DatabasePlanService) GetDatabasePlan(databaseID uuid.UUID) (*DatabasePlan, error) {
	plan, err := s.databasePlanRepository.GetDatabasePlan(databaseID)
	if err != nil {
		return nil, err
	}

	if plan == nil {
		s.logger.Info("no database plan found, creating default plan", "databaseID", databaseID)

		defaultPlan := s.createDefaultDatabasePlan(databaseID)

		err := s.databasePlanRepository.CreateDatabasePlan(defaultPlan)
		if err != nil {
			s.logger.Error("failed to create default database plan", "error", err)
			return nil, err
		}

		return defaultPlan, nil
	}

	return plan, nil
}

func (s *DatabasePlanService) createDefaultDatabasePlan(databaseID uuid.UUID) *DatabasePlan {
	var plan DatabasePlan

	isCloud := config.GetEnv().IsCloud
	if isCloud {
		s.logger.Info("creating default database plan for cloud", "databaseID", databaseID)

		// for playground we set limited storages enough to test,
		// but not too expensive to provide it for Databasus
		plan = DatabasePlan{
			DatabaseID:            databaseID,
			MaxBackupSizeMB:       100,  // ~ 1.5GB database
			MaxBackupsTotalSizeMB: 4000, // ~ 30 daily backups + 10 manual backups
			MaxStoragePeriod:      period.PeriodWeek,
		}
	} else {
		s.logger.Info("creating default database plan for self hosted", "databaseID", databaseID)

		// by default - everything is unlimited in self hosted mode
		plan = DatabasePlan{
			DatabaseID:            databaseID,
			MaxBackupSizeMB:       0,
			MaxBackupsTotalSizeMB: 0,
			MaxStoragePeriod:      period.PeriodForever,
		}
	}

	return &plan
}
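Note: for the cloud defaults above, 4000 MB total at up to 100 MB per backup allows roughly 4000 / 100 = 40 retained backups, which lines up with the "~ 30 daily backups + 10 manual backups" comment.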
@@ -1,38 +0,0 @@
package restores

import (
	"context"
	"databasus-backend/internal/features/restores/enums"
	"log/slog"
)

type RestoreBackgroundService struct {
	restoreRepository *RestoreRepository
	logger            *slog.Logger
}

func (s *RestoreBackgroundService) Run(ctx context.Context) {
	if err := s.failRestoresInProgress(); err != nil {
		s.logger.Error("Failed to fail restores in progress", "error", err)
		panic(err)
	}
}

func (s *RestoreBackgroundService) failRestoresInProgress() error {
	restoresInProgress, err := s.restoreRepository.FindByStatus(enums.RestoreStatusInProgress)
	if err != nil {
		return err
	}

	for _, restore := range restoresInProgress {
		failMessage := "Restore failed due to application restart"
		restore.Status = enums.RestoreStatusFailed
		restore.FailMessage = &failMessage

		if err := s.restoreRepository.Save(restore); err != nil {
			return err
		}
	}

	return nil
}
@@ -1,6 +1,7 @@
 package restores

 import (
+	restores_core "databasus-backend/internal/features/restores/core"
 	users_middleware "databasus-backend/internal/features/users/middleware"
 	"net/http"

@@ -15,6 +16,7 @@ type RestoreController struct {
func (c *RestoreController) RegisterRoutes(router *gin.RouterGroup) {
	router.GET("/restores/:backupId", c.GetRestores)
	router.POST("/restores/:backupId/restore", c.RestoreBackup)
+	router.POST("/restores/cancel/:restoreId", c.CancelRestore)
}

// GetRestores
@@ -23,7 +25,7 @@ func (c *RestoreController) RegisterRoutes(router *gin.RouterGroup) {
// @Tags restores
// @Produce json
// @Param backupId path string true "Backup ID"
-// @Success 200 {array} models.Restore
+// @Success 200 {array} restores_core.Restore
// @Failure 400
// @Failure 401
// @Router /restores/{backupId} [get]
@@ -71,7 +73,7 @@ func (c *RestoreController) RestoreBackup(ctx *gin.Context) {
		return
	}

-	var requestDTO RestoreBackupRequest
+	var requestDTO restores_core.RestoreBackupRequest
	if err := ctx.ShouldBindJSON(&requestDTO); err != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
@@ -84,3 +86,33 @@ func (c *RestoreController) RestoreBackup(ctx *gin.Context) {

	ctx.JSON(http.StatusOK, gin.H{"message": "restore started successfully"})
}
+
+// CancelRestore
+// @Summary Cancel an in-progress restore
+// @Description Cancel a restore that is currently in progress
+// @Tags restores
+// @Param restoreId path string true "Restore ID"
+// @Success 204
+// @Failure 400
+// @Failure 401
+// @Router /restores/cancel/{restoreId} [post]
+func (c *RestoreController) CancelRestore(ctx *gin.Context) {
+	user, ok := users_middleware.GetUserFromContext(ctx)
+	if !ok {
+		ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
+		return
+	}
+
+	restoreID, err := uuid.Parse(ctx.Param("restoreId"))
+	if err != nil {
+		ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid restore ID"})
+		return
+	}
+
+	if err := c.restoreService.CancelRestore(user, restoreID); err != nil {
+		ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	ctx.Status(http.StatusNoContent)
+}
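Note: the new cancel route takes the restore ID in the path and returns 204 on success. A sketch of calling it directly (base URL and bearer token are assumptions about the deployment, not taken from this diff):

// hypothetical manual call against a local instance
req, _ := http.NewRequest(http.MethodPost,
	"http://localhost:8080/restores/cancel/"+restoreID.String(), nil)
req.Header.Set("Authorization", "Bearer "+token)
resp, err := http.DefaultClient.Do(req)
// expect resp.StatusCode == http.StatusNoContent on success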
@@ -16,7 +16,7 @@ import (
 	"github.com/google/uuid"
 	"github.com/stretchr/testify/assert"

-	"databasus-backend/internal/config"
+	env_config "databasus-backend/internal/config"
 	audit_logs "databasus-backend/internal/features/audit_logs"
 	"databasus-backend/internal/features/backups/backups"
 	backups_core "databasus-backend/internal/features/backups/backups/core"
@@ -24,16 +24,18 @@ import (
 	"databasus-backend/internal/features/databases"
 	"databasus-backend/internal/features/databases/databases/mysql"
 	"databasus-backend/internal/features/databases/databases/postgresql"
-	"databasus-backend/internal/features/restores/models"
+	"databasus-backend/internal/features/notifiers"
+	restores_core "databasus-backend/internal/features/restores/core"
+	"databasus-backend/internal/features/restores/restoring"
 	"databasus-backend/internal/features/storages"
 	local_storage "databasus-backend/internal/features/storages/models/local"
 	tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
 	users_dto "databasus-backend/internal/features/users/dto"
 	users_enums "databasus-backend/internal/features/users/enums"
 	users_services "databasus-backend/internal/features/users/services"
 	users_testing "databasus-backend/internal/features/users/testing"
 	workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
 	workspaces_models "databasus-backend/internal/features/workspaces/models"
 	workspaces_testing "databasus-backend/internal/features/workspaces/testing"
 	cache_utils "databasus-backend/internal/util/cache"
 	util_encryption "databasus-backend/internal/util/encryption"
 	test_utils "databasus-backend/internal/util/testing"
 	"databasus-backend/internal/util/tools"
@@ -43,10 +45,12 @@ func Test_GetRestores_WhenUserIsWorkspaceMember_RestoresReturned(t *testing.T) {
	router := createTestRouter()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
+	defer workspaces_testing.RemoveTestWorkspace(workspace, router)

	database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
+	defer cleanupDatabaseWithBackup(database, backup)

-	var restores []*models.Restore
+	var restores []*restores_core.Restore
	test_utils.MakeGetRequestAndUnmarshal(
		t,
		router,
@@ -65,8 +69,10 @@ func Test_GetRestores_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing.T
	router := createTestRouter()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
+	defer workspaces_testing.RemoveTestWorkspace(workspace, router)

-	_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
+	database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
+	defer cleanupDatabaseWithBackup(database, backup)

	nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)

@@ -85,12 +91,14 @@ func Test_GetRestores_WhenUserIsGlobalAdmin_RestoresReturned(t *testing.T) {
	router := createTestRouter()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
+	defer workspaces_testing.RemoveTestWorkspace(workspace, router)

-	_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
+	database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
+	defer cleanupDatabaseWithBackup(database, backup)

	admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)

-	var restores []*models.Restore
+	var restores []*restores_core.Restore
	test_utils.MakeGetRequestAndUnmarshal(
		t,
		router,
@@ -105,15 +113,21 @@ func Test_GetRestores_WhenUserIsGlobalAdmin_RestoresReturned(t *testing.T) {

func Test_RestoreBackup_WhenUserIsWorkspaceMember_RestoreInitiated(t *testing.T) {
	router := createTestRouter()
+
+	_, cleanup := SetupMockRestoreNode(t)
+	defer cleanup()
+
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
+	defer workspaces_testing.RemoveTestWorkspace(workspace, router)

-	_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
+	database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
+	defer cleanupDatabaseWithBackup(database, backup)

-	request := RestoreBackupRequest{
+	request := restores_core.RestoreBackupRequest{
		PostgresqlDatabase: &postgresql.PostgresqlDatabase{
			Version:  tools.PostgresqlVersion16,
-			Host:     "localhost",
+			Host:     env_config.GetEnv().TestLocalhost,
			Port:     5432,
			Username: "postgres",
			Password: "postgres",
@@ -136,15 +150,17 @@ func Test_RestoreBackup_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing
	router := createTestRouter()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
+	defer workspaces_testing.RemoveTestWorkspace(workspace, router)

-	_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
+	database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
+	defer cleanupDatabaseWithBackup(database, backup)

	nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)

-	request := RestoreBackupRequest{
+	request := restores_core.RestoreBackupRequest{
		PostgresqlDatabase: &postgresql.PostgresqlDatabase{
			Version:  tools.PostgresqlVersion16,
-			Host:     "localhost",
+			Host:     env_config.GetEnv().TestLocalhost,
			Port:     5432,
			Username: "postgres",
			Password: "postgres",
@@ -165,15 +181,21 @@ func Test_RestoreBackup_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing

func Test_RestoreBackup_WithIsExcludeExtensions_FlagPassedCorrectly(t *testing.T) {
	router := createTestRouter()
+
+	_, cleanup := SetupMockRestoreNode(t)
+	defer cleanup()
+
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
+	defer workspaces_testing.RemoveTestWorkspace(workspace, router)

-	_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
+	database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
+	defer cleanupDatabaseWithBackup(database, backup)

-	request := RestoreBackupRequest{
+	request := restores_core.RestoreBackupRequest{
		PostgresqlDatabase: &postgresql.PostgresqlDatabase{
			Version:  tools.PostgresqlVersion16,
-			Host:     "localhost",
+			Host:     env_config.GetEnv().TestLocalhost,
			Port:     5432,
			Username: "postgres",
			Password: "postgres",
@@ -195,15 +217,21 @@ func Test_RestoreBackup_WithIsExcludeExtensions_FlagPassedCorrectly(t *testing.T

func Test_RestoreBackup_AuditLogWritten(t *testing.T) {
	router := createTestRouter()
+
+	_, cleanup := SetupMockRestoreNode(t)
+	defer cleanup()
+
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
+	defer workspaces_testing.RemoveTestWorkspace(workspace, router)

	database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
+	defer cleanupDatabaseWithBackup(database, backup)

-	request := RestoreBackupRequest{
+	request := restores_core.RestoreBackupRequest{
		PostgresqlDatabase: &postgresql.PostgresqlDatabase{
			Version:  tools.PostgresqlVersion16,
-			Host:     "localhost",
+			Host:     env_config.GetEnv().TestLocalhost,
			Port:     5432,
			Username: "postgres",
			Password: "postgres",
@@ -233,7 +261,7 @@ func Test_RestoreBackup_AuditLogWritten(t *testing.T) {

	found := false
	for _, log := range auditLogs.AuditLogs {
-		if strings.Contains(log.Message, "Database restored from backup") &&
+		if strings.Contains(log.Message, "Database restored for database") &&
			strings.Contains(log.Message, database.Name) {
			found = true
			break
@@ -272,18 +300,29 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			router := createTestRouter()
+
+			// Setup mock node for tests that skip disk validation and reach scheduler
+			if !tc.expectDiskValidated {
+				_, cleanup := SetupMockRestoreNode(t)
+				defer cleanup()
+			}
+
			owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
			workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
+			defer workspaces_testing.RemoveTestWorkspace(workspace, router)

			var database *databases.Database
			var backup *backups_core.Backup
-			var request RestoreBackupRequest
+			var storage *storages.Storage
+			var request restores_core.RestoreBackupRequest

			if tc.dbType == databases.DatabaseTypePostgres {
-				_, backup = createTestDatabaseWithBackupForRestore(workspace, owner, router)
-				request = RestoreBackupRequest{
+				database, backup = createTestDatabaseWithBackupForRestore(workspace, owner, router)
+				defer cleanupDatabaseWithBackup(database, backup)
+				request = restores_core.RestoreBackupRequest{
					PostgresqlDatabase: &postgresql.PostgresqlDatabase{
						Version:  tools.PostgresqlVersion16,
-						Host:     "localhost",
+						Host:     env_config.GetEnv().TestLocalhost,
						Port:     5432,
						Username: "postgres",
						Password: "postgres",
@@ -297,7 +336,16 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
					owner.Token,
					router,
				)
-				storage := createTestStorage(workspace.ID)
+				database = mysqlDB
+				storage = createTestStorage(workspace.ID)
||||
defer func() {
|
||||
// Cleanup in dependency order: backup -> database -> storage
|
||||
cleanupBackup(backup)
|
||||
databases.RemoveTestDatabase(mysqlDB)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
}()
|
||||
|
||||
configService := backups_config.GetBackupConfigService()
|
||||
config, err := configService.GetBackupConfigByDbId(mysqlDB.ID)
|
||||
@@ -309,11 +357,12 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
|
||||
_, err = configService.SaveBackupConfig(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backup = createTestBackup(mysqlDB, owner)
|
||||
request = RestoreBackupRequest{
|
||||
backup = createTestBackup(mysqlDB, storage)
|
||||
|
||||
request = restores_core.RestoreBackupRequest{
|
||||
MysqlDatabase: &mysql.MysqlDatabase{
|
||||
Version: tools.MysqlVersion80,
|
||||
Host: "localhost",
|
||||
Host: env_config.GetEnv().TestLocalhost,
|
||||
Port: 3306,
|
||||
Username: "root",
|
||||
Password: "password",
|
||||
@@ -353,16 +402,189 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func createTestRouter() *gin.Engine {
|
||||
router := workspaces_testing.CreateTestRouter(
|
||||
workspaces_controllers.GetWorkspaceController(),
|
||||
workspaces_controllers.GetMembershipController(),
|
||||
databases.GetDatabaseController(),
|
||||
backups_config.GetBackupConfigController(),
|
||||
backups.GetBackupController(),
|
||||
GetRestoreController(),
|
||||
func Test_CancelRestore_InProgressRestore_SuccessfullyCancelled(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
tasks_cancellation.SetupDependencies()
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := createTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backupRepo := backups_core.BackupRepository{}
|
||||
backups, _ := backupRepo.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepo.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
restoreRepo := restores_core.RestoreRepository{}
|
||||
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
|
||||
for _, restore := range restores {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
restores, _ = restoreRepo.FindByStatus(restores_core.RestoreStatusCanceled)
|
||||
for _, restore := range restores {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
||||
|
||||
mockUsecase := &restoring.MockBlockingRestoreUsecase{
|
||||
StartedChan: make(chan bool, 1),
|
||||
}
|
||||
restorerNode := restoring.CreateTestRestorerNodeWithUsecase(mockUsecase)
|
||||
|
||||
cancelNode := restoring.StartRestorerNodeForTest(t, restorerNode)
|
||||
defer cancelNode()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
restoreRequest := restores_core.RestoreBackupRequest{
|
||||
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: env_config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
},
|
||||
}
|
||||
|
||||
var restoreResponse map[string]interface{}
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
|
||||
"Bearer "+user.Token,
|
||||
restoreRequest,
|
||||
http.StatusOK,
|
||||
&restoreResponse,
|
||||
)
|
||||
return router
|
||||
|
||||
select {
|
||||
case <-mockUsecase.StartedChan:
|
||||
t.Log("Restore started and is blocking")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Restore did not start within timeout")
|
||||
}
|
||||
|
||||
restoreRepo := &restores_core.RestoreRepository{}
|
||||
restores, err := restoreRepo.FindByBackupID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, len(restores), 0, "At least one restore should exist")
|
||||
|
||||
var restoreID uuid.UUID
|
||||
for _, r := range restores {
|
||||
if r.Status == restores_core.RestoreStatusInProgress {
|
||||
restoreID = r.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.NotEqual(t, uuid.Nil, restoreID, "Should find an in-progress restore")
|
||||
|
||||
resp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/restores/cancel/%s", restoreID.String()),
|
||||
"Bearer "+user.Token,
|
||||
nil,
|
||||
http.StatusNoContent,
|
||||
)
|
||||
|
||||
assert.Equal(t, http.StatusNoContent, resp.StatusCode)
|
||||
|
||||
deadline := time.Now().UTC().Add(3 * time.Second)
|
||||
var restore *restores_core.Restore
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
restore, err = restoreRepo.FindByID(restoreID)
|
||||
assert.NoError(t, err)
|
||||
if restore.Status == restores_core.RestoreStatusCanceled {
|
||||
break
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
assert.Equal(t, restores_core.RestoreStatusCanceled, restore.Status)
|
||||
|
||||
auditLogService := audit_logs.GetAuditLogService()
|
||||
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
|
||||
workspace.ID,
|
||||
&audit_logs.GetAuditLogsRequest{Limit: 100, Offset: 0},
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
foundCancelLog := false
|
||||
for _, log := range auditLogs.AuditLogs {
|
||||
if strings.Contains(log.Message, "Restore cancelled") &&
|
||||
strings.Contains(log.Message, database.Name) {
|
||||
foundCancelLog = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, foundCancelLog, "Cancel audit log should be created")
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_RestoreBackup_WithParallelRestoreInProgress_ReturnsError(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
|
||||
_, cleanup := SetupMockRestoreNode(t)
|
||||
defer cleanup()
|
||||
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
|
||||
defer cleanupDatabaseWithBackup(database, backup)
|
||||
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: env_config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
},
|
||||
}
|
||||
|
||||
testResp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusOK,
|
||||
)
|
||||
assert.Contains(t, string(testResp.Body), "restore started successfully")
|
||||
|
||||
testResp2 := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(testResp2.Body), "another restore is already in progress")
|
||||
}
|
||||
|
||||
func createTestRouter() *gin.Engine {
|
||||
return CreateTestRouter()
|
||||
}
|
||||
|
||||
func createTestDatabaseWithBackupForRestore(
|
||||
@@ -387,7 +609,7 @@ func createTestDatabaseWithBackupForRestore(
|
||||
panic(err)
|
||||
}
|
||||
|
||||
backup := createTestBackup(database, owner)
|
||||
backup := createTestBackup(database, storage)
|
||||
|
||||
return database, backup
|
||||
}
|
||||
@@ -433,7 +655,7 @@ func createTestMySQLDatabase(
|
||||
token string,
|
||||
router *gin.Engine,
|
||||
) *databases.Database {
|
||||
env := config.GetEnv()
|
||||
env := env_config.GetEnv()
|
||||
portStr := env.TestMysql80Port
|
||||
if portStr == "" {
|
||||
portStr = "33080"
|
||||
@@ -451,7 +673,7 @@ func createTestMySQLDatabase(
|
||||
Type: databases.DatabaseTypeMysql,
|
||||
Mysql: &mysql.MysqlDatabase{
|
||||
Version: tools.MysqlVersion80,
|
||||
Host: "localhost",
|
||||
Host: env_config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
@@ -504,24 +726,14 @@ func createTestStorage(workspaceID uuid.UUID) *storages.Storage {
|
||||
|
||||
func createTestBackup(
|
||||
database *databases.Database,
|
||||
owner *users_dto.SignInResponseDTO,
|
||||
storage *storages.Storage,
|
||||
) *backups_core.Backup {
|
||||
fieldEncryptor := util_encryption.GetFieldEncryptor()
|
||||
userService := users_services.GetUserService()
|
||||
user, err := userService.GetUserFromToken(owner.Token)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
storages, err := storages.GetStorageService().GetStorages(user, *database.WorkspaceID)
|
||||
if err != nil || len(storages) == 0 {
|
||||
panic("No storage found for workspace")
|
||||
}
|
||||
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storages[0].ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10.5,
|
||||
BackupDurationMs: 1000,
|
||||
@@ -536,11 +748,11 @@ func createTestBackup(
|
||||
dummyContent := []byte("dummy backup content for testing")
|
||||
reader := strings.NewReader(string(dummyContent))
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
if err := storages[0].SaveFile(
|
||||
if err := storage.SaveFile(
|
||||
context.Background(),
|
||||
fieldEncryptor,
|
||||
logger,
|
||||
backup.ID,
|
||||
backup.ID.String(),
|
||||
reader,
|
||||
); err != nil {
|
||||
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
|
||||
@@ -548,3 +760,22 @@ func createTestBackup(
|
||||
|
||||
return backup
|
||||
}
|
||||
|
||||
func cleanupDatabaseWithBackup(database *databases.Database, backup *backups_core.Backup) {
|
||||
// Clean up in reverse dependency order
|
||||
cleanupBackup(backup)
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Clean up storage last (after database and backup are removed)
|
||||
configService := backups_config.GetBackupConfigService()
|
||||
config, err := configService.GetBackupConfigByDbId(database.ID)
|
||||
if err == nil && config.StorageID != nil {
|
||||
storages.RemoveTestStorage(*config.StorageID)
|
||||
}
|
||||
}
|
||||
|
||||
func cleanupBackup(backup *backups_core.Backup) {
|
||||
repo := &backups_core.BackupRepository{}
|
||||
repo.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package restores
|
||||
package restores_core
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/features/databases/databases/mariadb"
|
||||
@@ -1,4 +1,4 @@
|
||||
package enums
|
||||
package restores_core
|
||||
|
||||
type RestoreStatus string
|
||||
|
||||
@@ -6,4 +6,5 @@ const (
|
||||
RestoreStatusInProgress RestoreStatus = "IN_PROGRESS"
|
||||
RestoreStatusCompleted RestoreStatus = "COMPLETED"
|
||||
RestoreStatusFailed RestoreStatus = "FAILED"
|
||||
RestoreStatusCanceled RestoreStatus = "CANCELED"
|
||||
)
|
||||
23
backend/internal/features/restores/core/interfaces.go
Normal file
23
backend/internal/features/restores/core/interfaces.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package restores_core
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/storages"
|
||||
)
|
||||
|
||||
type RestoreBackupUsecase interface {
|
||||
Execute(
|
||||
ctx context.Context,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore Restore,
|
||||
originalDB *databases.Database,
|
||||
restoringToDB *databases.Database,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
isExcludeExtensions bool,
|
||||
) error
|
||||
}
|
||||
30
backend/internal/features/restores/core/model.go
Normal file
30
backend/internal/features/restores/core/model.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package restores_core
|
||||
|
||||
import (
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/databases/databases/mariadb"
|
||||
"databasus-backend/internal/features/databases/databases/mongodb"
|
||||
"databasus-backend/internal/features/databases/databases/mysql"
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type Restore struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
|
||||
Status RestoreStatus `json:"status" gorm:"column:status;type:text;not null"`
|
||||
|
||||
BackupID uuid.UUID `json:"backupId" gorm:"column:backup_id;type:uuid;not null"`
|
||||
Backup *backups_core.Backup
|
||||
|
||||
PostgresqlDatabase *postgresql.PostgresqlDatabase `json:"postgresqlDatabase" gorm:"-"`
|
||||
MysqlDatabase *mysql.MysqlDatabase `json:"mysqlDatabase" gorm:"-"`
|
||||
MariadbDatabase *mariadb.MariadbDatabase `json:"mariadbDatabase" gorm:"-"`
|
||||
MongodbDatabase *mongodb.MongodbDatabase `json:"mongodbDatabase" gorm:"-"`
|
||||
|
||||
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`
|
||||
|
||||
RestoreDurationMs int64 `json:"restoreDurationMs" gorm:"column:restore_duration_ms;default:0"`
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at;default:now()"`
|
||||
}
|
||||
91 backend/internal/features/restores/core/repository.go Normal file
@@ -0,0 +1,91 @@
package restores_core

import (
    "databasus-backend/internal/storage"

    "github.com/google/uuid"
)

type RestoreRepository struct{}

func (r *RestoreRepository) Save(restore *Restore) error {
    db := storage.GetDb()

    isNew := restore.ID == uuid.Nil
    if isNew {
        restore.ID = uuid.New()
        return db.Create(restore).
            Omit("Backup", "PostgresqlDatabase", "MysqlDatabase", "MariadbDatabase", "MongodbDatabase").
            Error
    }

    return db.Save(restore).
        Omit("Backup", "PostgresqlDatabase", "MysqlDatabase", "MariadbDatabase", "MongodbDatabase").
        Error
}

func (r *RestoreRepository) FindByBackupID(backupID uuid.UUID) ([]*Restore, error) {
    var restores []*Restore

    if err := storage.
        GetDb().
        Preload("Backup").
        Where("backup_id = ?", backupID).
        Order("created_at DESC").
        Find(&restores).Error; err != nil {
        return nil, err
    }

    return restores, nil
}

func (r *RestoreRepository) FindByID(id uuid.UUID) (*Restore, error) {
    var restore Restore

    if err := storage.
        GetDb().
        Preload("Backup").
        Where("id = ?", id).
        First(&restore).Error; err != nil {
        return nil, err
    }

    return &restore, nil
}

func (r *RestoreRepository) FindByStatus(status RestoreStatus) ([]*Restore, error) {
    var restores []*Restore

    if err := storage.
        GetDb().
        Preload("Backup").
        Where("status = ?", status).
        Order("created_at DESC").
        Find(&restores).Error; err != nil {
        return nil, err
    }

    return restores, nil
}

func (r *RestoreRepository) FindInProgressRestoresByDatabaseID(
    databaseID uuid.UUID,
) ([]*Restore, error) {
    var restores []*Restore

    if err := storage.
        GetDb().
        Preload("Backup").
        Joins("JOIN backups ON backups.id = restores.backup_id").
        Where("backups.database_id = ? AND restores.status = ?", databaseID, RestoreStatusInProgress).
        Order("restores.created_at DESC").
        Find(&restores).Error; err != nil {
        return nil, err
    }

    return restores, nil
}

func (r *RestoreRepository) DeleteByID(id uuid.UUID) error {
    return storage.GetDb().Delete(&Restore{}, "id = ?", id).Error
}
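Note: Save above doubles as an upsert — uuid.Nil marks a brand-new record, which gets a freshly generated ID and an INSERT, while a non-nil ID goes through gorm's Save (an UPDATE), with the association and transient fields omitted on both paths. A minimal sketch of calling code that leans on that convention; persistAndComplete is a hypothetical helper written for illustration, not part of this diff:

package example

import (
    "github.com/google/uuid"

    restores_core "databasus-backend/internal/features/restores/core"
)

// persistAndComplete first inserts a restore (ID is uuid.Nil, so Save
// generates one and issues a Create), then updates the same row by
// calling Save again with the ID now set.
func persistAndComplete(repo *restores_core.RestoreRepository, backupID uuid.UUID) error {
    restore := &restores_core.Restore{
        Status:   restores_core.RestoreStatusInProgress,
        BackupID: backupID,
    }
    if err := repo.Save(restore); err != nil { // INSERT path
        return err
    }

    restore.Status = restores_core.RestoreStatusCompleted
    return repo.Save(restore) // UPDATE path: ID is no longer uuid.Nil
}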
@@ -1,19 +1,25 @@
package restores

import (
    "sync"
    "sync/atomic"

    audit_logs "databasus-backend/internal/features/audit_logs"
    "databasus-backend/internal/features/backups/backups"
    "databasus-backend/internal/features/backups/backups/backuping"
    backups_config "databasus-backend/internal/features/backups/config"
    "databasus-backend/internal/features/databases"
    "databasus-backend/internal/features/disk"
    restores_core "databasus-backend/internal/features/restores/core"
    "databasus-backend/internal/features/restores/usecases"
    "databasus-backend/internal/features/storages"
    tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
    workspaces_services "databasus-backend/internal/features/workspaces/services"
    "databasus-backend/internal/util/encryption"
    "databasus-backend/internal/util/logger"
)

var restoreRepository = &RestoreRepository{}
var restoreRepository = &restores_core.RestoreRepository{}
var restoreService = &RestoreService{
    backups.GetBackupService(),
    restoreRepository,
@@ -26,24 +32,32 @@ var restoreService = &RestoreService{
    audit_logs.GetAuditLogService(),
    encryption.GetFieldEncryptor(),
    disk.GetDiskService(),
    tasks_cancellation.GetTaskCancelManager(),
}
var restoreController = &RestoreController{
    restoreService,
}

var restoreBackgroundService = &RestoreBackgroundService{
    restoreRepository,
    logger.GetLogger(),
}

func GetRestoreController() *RestoreController {
    return restoreController
}

func GetRestoreBackgroundService() *RestoreBackgroundService {
    return restoreBackgroundService
}
var (
    setupOnce sync.Once
    isSetup   atomic.Bool
)

func SetupDependencies() {
    backups.GetBackupService().AddBackupRemoveListener(restoreService)
    wasAlreadySetup := isSetup.Load()

    setupOnce.Do(func() {
        backups.GetBackupService().AddBackupRemoveListener(restoreService)
        backuping.GetBackupCleaner().AddBackupRemoveListener(restoreService)

        isSetup.Store(true)
    })

    if wasAlreadySetup {
        logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
    }
}
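Note: the new SetupDependencies pairs sync.Once with an atomic flag so that listener registration happens exactly once and any later call is merely logged. A standalone sketch of the same idempotent-setup pattern, with illustrative names that do not come from this repository:

package main

import (
    "fmt"
    "sync"
    "sync/atomic"
)

var (
    setupOnce sync.Once
    isSetup   atomic.Bool
)

// Setup mirrors di.go above: the flag is read before the Do, so a second
// caller observes true and reports the duplicate call instead of
// registering the listeners again.
func Setup(register func()) {
    wasAlreadySetup := isSetup.Load()

    setupOnce.Do(func() {
        register()
        isSetup.Store(true)
    })

    if wasAlreadySetup {
        fmt.Println("Setup called multiple times, ignoring subsequent call")
    }
}

func main() {
    Setup(func() { fmt.Println("registered") }) // runs the body
    Setup(func() { fmt.Println("never runs") }) // only warns
}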
@@ -1,22 +0,0 @@
package models

import (
    backups_core "databasus-backend/internal/features/backups/backups/core"
    "databasus-backend/internal/features/restores/enums"
    "time"

    "github.com/google/uuid"
)

type Restore struct {
    ID     uuid.UUID           `json:"id" gorm:"column:id;type:uuid;primaryKey"`
    Status enums.RestoreStatus `json:"status" gorm:"column:status;type:text;not null"`

    BackupID uuid.UUID `json:"backupId" gorm:"column:backup_id;type:uuid;not null"`
    Backup   *backups_core.Backup

    FailMessage *string `json:"failMessage" gorm:"column:fail_message"`

    RestoreDurationMs int64     `json:"restoreDurationMs" gorm:"column:restore_duration_ms;default:0"`
    CreatedAt         time.Time `json:"createdAt" gorm:"column:created_at;default:now()"`
}
@@ -1,75 +0,0 @@
package restores

import (
    "databasus-backend/internal/features/restores/enums"
    "databasus-backend/internal/features/restores/models"
    "databasus-backend/internal/storage"

    "github.com/google/uuid"
)

type RestoreRepository struct{}

func (r *RestoreRepository) Save(restore *models.Restore) error {
    db := storage.GetDb()

    isNew := restore.ID == uuid.Nil
    if isNew {
        restore.ID = uuid.New()
        return db.Create(restore).
            Omit("Backup").
            Error
    }

    return db.Save(restore).
        Omit("Backup").
        Error
}

func (r *RestoreRepository) FindByBackupID(backupID uuid.UUID) ([]*models.Restore, error) {
    var restores []*models.Restore

    if err := storage.
        GetDb().
        Preload("Backup").
        Where("backup_id = ?", backupID).
        Order("created_at DESC").
        Find(&restores).Error; err != nil {
        return nil, err
    }

    return restores, nil
}

func (r *RestoreRepository) FindByID(id uuid.UUID) (*models.Restore, error) {
    var restore models.Restore

    if err := storage.
        GetDb().
        Preload("Backup").
        Where("id = ?", id).
        First(&restore).Error; err != nil {
        return nil, err
    }

    return &restore, nil
}

func (r *RestoreRepository) FindByStatus(status enums.RestoreStatus) ([]*models.Restore, error) {
    var restores []*models.Restore

    if err := storage.
        GetDb().
        Preload("Backup").
        Where("status = ?", status).
        Order("created_at DESC").
        Find(&restores).Error; err != nil {
        return nil, err
    }

    return restores, nil
}

func (r *RestoreRepository) DeleteByID(id uuid.UUID) error {
    return storage.GetDb().Delete(&models.Restore{}, "id = ?", id).Error
}

85 backend/internal/features/restores/restoring/di.go Normal file
@@ -0,0 +1,85 @@
package restoring

import (
    "sync"
    "sync/atomic"
    "time"

    "github.com/google/uuid"

    "databasus-backend/internal/features/backups/backups"
    backups_config "databasus-backend/internal/features/backups/config"
    "databasus-backend/internal/features/databases"
    restores_core "databasus-backend/internal/features/restores/core"
    "databasus-backend/internal/features/restores/usecases"
    "databasus-backend/internal/features/storages"
    tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
    cache_utils "databasus-backend/internal/util/cache"
    "databasus-backend/internal/util/encryption"
    "databasus-backend/internal/util/logger"
)

var restoreRepository = &restores_core.RestoreRepository{}

var restoreNodesRegistry = &RestoreNodesRegistry{
    client:            cache_utils.GetValkeyClient(),
    logger:            logger.GetLogger(),
    timeout:           cache_utils.DefaultCacheTimeout,
    pubsubRestores:    cache_utils.NewPubSubManager(),
    pubsubCompletions: cache_utils.NewPubSubManager(),
    runOnce:           sync.Once{},
    hasRun:            atomic.Bool{},
}

var restoreDatabaseCache = cache_utils.NewCacheUtil[RestoreDatabaseCache](
    cache_utils.GetValkeyClient(),
    "restore_db:",
)

var restoreCancelManager = tasks_cancellation.GetTaskCancelManager()

var restorerNode = &RestorerNode{
    nodeID:               uuid.New(),
    databaseService:      databases.GetDatabaseService(),
    backupService:        backups.GetBackupService(),
    fieldEncryptor:       encryption.GetFieldEncryptor(),
    restoreRepository:    restoreRepository,
    backupConfigService:  backups_config.GetBackupConfigService(),
    storageService:       storages.GetStorageService(),
    restoreNodesRegistry: restoreNodesRegistry,
    logger:               logger.GetLogger(),
    restoreBackupUsecase: usecases.GetRestoreBackupUsecase(),
    cacheUtil:            restoreDatabaseCache,
    restoreCancelManager: restoreCancelManager,
    lastHeartbeat:        time.Time{},
    runOnce:              sync.Once{},
    hasRun:               atomic.Bool{},
}

var restoresScheduler = &RestoresScheduler{
    restoreRepository:        restoreRepository,
    backupService:            backups.GetBackupService(),
    storageService:           storages.GetStorageService(),
    backupConfigService:      backups_config.GetBackupConfigService(),
    restoreNodesRegistry:     restoreNodesRegistry,
    lastCheckTime:            time.Now().UTC(),
    logger:                   logger.GetLogger(),
    restoreToNodeRelations:   make(map[uuid.UUID]RestoreToNodeRelation),
    restorerNode:             restorerNode,
    cacheUtil:                restoreDatabaseCache,
    completionSubscriptionID: uuid.Nil,
    runOnce:                  sync.Once{},
    hasRun:                   atomic.Bool{},
}

func GetRestoresScheduler() *RestoresScheduler {
    return restoresScheduler
}

func GetRestorerNode() *RestorerNode {
    return restorerNode
}

func GetRestoreNodesRegistry() *RestoreNodesRegistry {
    return restoreNodesRegistry
}

45 backend/internal/features/restores/restoring/dto.go Normal file
@@ -0,0 +1,45 @@
package restoring

import (
    "databasus-backend/internal/features/databases/databases/mariadb"
    "databasus-backend/internal/features/databases/databases/mongodb"
    "databasus-backend/internal/features/databases/databases/mysql"
    "databasus-backend/internal/features/databases/databases/postgresql"
    "time"

    "github.com/google/uuid"
)

type RestoreDatabaseCache struct {
    PostgresqlDatabase *postgresql.PostgresqlDatabase `json:"postgresqlDatabase,omitempty"`
    MysqlDatabase      *mysql.MysqlDatabase           `json:"mysqlDatabase,omitempty"`
    MariadbDatabase    *mariadb.MariadbDatabase       `json:"mariadbDatabase,omitempty"`
    MongodbDatabase    *mongodb.MongodbDatabase       `json:"mongodbDatabase,omitempty"`
}

type RestoreToNodeRelation struct {
    NodeID     uuid.UUID   `json:"nodeId"`
    RestoreIDs []uuid.UUID `json:"restoreIds"`
}

type RestoreNode struct {
    ID            uuid.UUID `json:"id"`
    ThroughputMBs int       `json:"throughputMBs"`
    LastHeartbeat time.Time `json:"lastHeartbeat"`
}

type RestoreNodeStats struct {
    ID             uuid.UUID `json:"id"`
    ActiveRestores int       `json:"activeRestores"`
}

type RestoreSubmitMessage struct {
    NodeID         uuid.UUID `json:"nodeId"`
    RestoreID      uuid.UUID `json:"restoreId"`
    IsCallNotifier bool      `json:"isCallNotifier"`
}

type RestoreCompletionMessage struct {
    NodeID    uuid.UUID `json:"nodeId"`
    RestoreID uuid.UUID `json:"restoreId"`
}
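Note: RestoreDatabaseCache carries the target connection credentials for one pending restore, keyed by restore ID, and restorer.go below consumes it with a single GetAndDelete so credentials never outlive their restore. A rough sketch of that hand-off; NewCacheUtil, GetValkeyClient, and GetAndDelete all appear in this diff, but the Set call, restoreID, and targetCredentials here are assumptions shown only for illustration:

// Scheduler side (sketch): stash credentials before assigning the restore.
cache := cache_utils.NewCacheUtil[RestoreDatabaseCache](
    cache_utils.GetValkeyClient(),
    "restore_db:",
)
cache.Set(restoreID.String(), RestoreDatabaseCache{ // hypothetical Set counterpart
    PostgresqlDatabase: targetCredentials,
})

// Node side (as in restorer.go): read and evict atomically; a nil result
// means the credentials expired or the instance restarted.
if dbCache := cache.GetAndDelete(restoreID.String()); dbCache == nil {
    // mark the restore failed instead of retrying with stale credentials
}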
88 backend/internal/features/restores/restoring/mocks.go Normal file
@@ -0,0 +1,88 @@
package restoring

import (
    "context"
    "errors"

    backups_core "databasus-backend/internal/features/backups/backups/core"
    backups_config "databasus-backend/internal/features/backups/config"
    "databasus-backend/internal/features/databases"
    restores_core "databasus-backend/internal/features/restores/core"
    "databasus-backend/internal/features/storages"
)

type MockSuccessRestoreUsecase struct{}

func (uc *MockSuccessRestoreUsecase) Execute(
    ctx context.Context,
    backupConfig *backups_config.BackupConfig,
    restore restores_core.Restore,
    originalDB *databases.Database,
    restoringToDB *databases.Database,
    backup *backups_core.Backup,
    storage *storages.Storage,
    isExcludeExtensions bool,
) error {
    return nil
}

type MockFailedRestoreUsecase struct{}

func (uc *MockFailedRestoreUsecase) Execute(
    ctx context.Context,
    backupConfig *backups_config.BackupConfig,
    restore restores_core.Restore,
    originalDB *databases.Database,
    restoringToDB *databases.Database,
    backup *backups_core.Backup,
    storage *storages.Storage,
    isExcludeExtensions bool,
) error {
    return errors.New("restore failed")
}

type MockCaptureCredentialsRestoreUsecase struct {
    CalledChan    chan *databases.Database
    ShouldSucceed bool
}

func (uc *MockCaptureCredentialsRestoreUsecase) Execute(
    ctx context.Context,
    backupConfig *backups_config.BackupConfig,
    restore restores_core.Restore,
    originalDB *databases.Database,
    restoringToDB *databases.Database,
    backup *backups_core.Backup,
    storage *storages.Storage,
    isExcludeExtensions bool,
) error {
    uc.CalledChan <- restoringToDB

    if uc.ShouldSucceed {
        return nil
    }
    return errors.New("mock restore failed")
}

type MockBlockingRestoreUsecase struct {
    StartedChan chan bool
}

func (uc *MockBlockingRestoreUsecase) Execute(
    ctx context.Context,
    backupConfig *backups_config.BackupConfig,
    restore restores_core.Restore,
    originalDB *databases.Database,
    restoringToDB *databases.Database,
    backup *backups_core.Backup,
    storage *storages.Storage,
    isExcludeExtensions bool,
) error {
    if uc.StartedChan != nil {
        uc.StartedChan <- true
    }

    <-ctx.Done()

    return ctx.Err()
}
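Note: MockBlockingRestoreUsecase signals on StartedChan and then parks on <-ctx.Done(), which is what lets Test_CancelRestore_InProgressRestore_SuccessfullyCancelled observe a restore that is genuinely in progress before cancelling it. A compact sketch of driving the mock directly, assuming only the types from this file (the helper name is illustrative):

func demoBlockingMock() bool {
    uc := &MockBlockingRestoreUsecase{StartedChan: make(chan bool, 1)}
    ctx, cancel := context.WithCancel(context.Background())

    done := make(chan error, 1)
    go func() {
        // Every argument except ctx is ignored by the blocking mock.
        done <- uc.Execute(ctx, nil, restores_core.Restore{}, nil, nil, nil, nil, false)
    }()

    <-uc.StartedChan // the "restore" is now in progress
    cancel()         // stands in for the cancel endpoint

    return errors.Is(<-done, context.Canceled) // true
}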
649 backend/internal/features/restores/restoring/registry.go Normal file
@@ -0,0 +1,649 @@
package restoring

import (
    "context"
    "encoding/json"
    "fmt"
    "log/slog"
    "strings"
    "sync"
    "sync/atomic"
    "time"

    cache_utils "databasus-backend/internal/util/cache"

    "github.com/google/uuid"
    "github.com/valkey-io/valkey-go"
)

const (
    nodeInfoKeyPrefix        = "restore:node:"
    nodeInfoKeySuffix        = ":info"
    nodeActiveRestoresPrefix = "restore:node:"
    nodeActiveRestoresSuffix = ":active_restores"
    restoreSubmitChannel     = "restore:submit"
    restoreCompletionChannel = "restore:completion"

    deadNodeThreshold     = 2 * time.Minute
    cleanupTickerInterval = 1 * time.Second
)

// RestoreNodesRegistry keeps the restores scheduler and the restore nodes in sync.
//
// Features:
// - Tracks node availability and load level
// - Assigns restores from the scheduler to the node that should process them
// - Notifies the scheduler when a node completes a restore
//
// Important things to remember:
// - Nodes without a heartbeat for more than 2 minutes are excluded from the
//   available-nodes list and stats
//
// Dead nodes are cleaned up on 2 levels:
// - The list and stats functions do not return dead nodes
// - Dead nodes are periodically removed from the cache (so the cache does
//   not accumulate too many of them)
type RestoreNodesRegistry struct {
    client            valkey.Client
    logger            *slog.Logger
    timeout           time.Duration
    pubsubRestores    *cache_utils.PubSubManager
    pubsubCompletions *cache_utils.PubSubManager

    runOnce sync.Once
    hasRun  atomic.Bool
}

func (r *RestoreNodesRegistry) Run(ctx context.Context) {
    wasAlreadyRun := r.hasRun.Load()

    r.runOnce.Do(func() {
        r.hasRun.Store(true)

        if err := r.cleanupDeadNodes(); err != nil {
            r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
        }

        ticker := time.NewTicker(cleanupTickerInterval)
        defer ticker.Stop()
        for {
            select {
            case <-ctx.Done():
                return
            case <-ticker.C:
                if err := r.cleanupDeadNodes(); err != nil {
                    r.logger.Error("Failed to cleanup dead nodes", "error", err)
                }
            }
        }
    })

    if wasAlreadyRun {
        panic(fmt.Sprintf("%T.Run() called multiple times", r))
    }
}

func (r *RestoreNodesRegistry) GetAvailableNodes() ([]RestoreNode, error) {
    ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
    defer cancel()

    var allKeys []string
    cursor := uint64(0)
    pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix

    for {
        result := r.client.Do(
            ctx,
            r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
        )

        if result.Error() != nil {
            return nil, fmt.Errorf("failed to scan node keys: %w", result.Error())
        }

        scanResult, err := result.AsScanEntry()
        if err != nil {
            return nil, fmt.Errorf("failed to parse scan result: %w", err)
        }

        allKeys = append(allKeys, scanResult.Elements...)

        cursor = scanResult.Cursor
        if cursor == 0 {
            break
        }
    }

    if len(allKeys) == 0 {
        return []RestoreNode{}, nil
    }

    keyDataMap, err := r.pipelineGetKeys(allKeys)
    if err != nil {
        return nil, fmt.Errorf("failed to pipeline get node keys: %w", err)
    }

    threshold := time.Now().UTC().Add(-deadNodeThreshold)
    var nodes []RestoreNode

    for key, data := range keyDataMap {
        // Skip if the key doesn't exist (data is empty)
        if len(data) == 0 {
            continue
        }

        var node RestoreNode
        if err := json.Unmarshal(data, &node); err != nil {
            r.logger.Warn("Failed to unmarshal node data", "key", key, "error", err)
            continue
        }

        // Skip nodes with zero/uninitialized heartbeat
        if node.LastHeartbeat.IsZero() {
            continue
        }

        if node.LastHeartbeat.Before(threshold) {
            continue
        }

        nodes = append(nodes, node)
    }

    return nodes, nil
}

func (r *RestoreNodesRegistry) GetRestoreNodesStats() ([]RestoreNodeStats, error) {
    ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
    defer cancel()

    var allKeys []string
    cursor := uint64(0)
    pattern := nodeActiveRestoresPrefix + "*" + nodeActiveRestoresSuffix

    for {
        result := r.client.Do(
            ctx,
            r.client.B().Scan().Cursor(cursor).Match(pattern).Count(100).Build(),
        )

        if result.Error() != nil {
            return nil, fmt.Errorf("failed to scan active restores keys: %w", result.Error())
        }

        scanResult, err := result.AsScanEntry()
        if err != nil {
            return nil, fmt.Errorf("failed to parse scan result: %w", err)
        }

        allKeys = append(allKeys, scanResult.Elements...)

        cursor = scanResult.Cursor
        if cursor == 0 {
            break
        }
    }

    if len(allKeys) == 0 {
        return []RestoreNodeStats{}, nil
    }

    keyDataMap, err := r.pipelineGetKeys(allKeys)
    if err != nil {
        return nil, fmt.Errorf("failed to pipeline get active restores keys: %w", err)
    }

    var nodeInfoKeys []string
    nodeIDToStatsKey := make(map[string]string)
    for key := range keyDataMap {
        nodeID := r.extractNodeIDFromKey(key, nodeActiveRestoresPrefix, nodeActiveRestoresSuffix)
        nodeIDStr := nodeID.String()
        infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeIDStr, nodeInfoKeySuffix)
        nodeInfoKeys = append(nodeInfoKeys, infoKey)
        nodeIDToStatsKey[infoKey] = key
    }

    nodeInfoMap, err := r.pipelineGetKeys(nodeInfoKeys)
    if err != nil {
        return nil, fmt.Errorf("failed to pipeline get node info keys: %w", err)
    }

    threshold := time.Now().UTC().Add(-deadNodeThreshold)
    var stats []RestoreNodeStats
    for infoKey, nodeData := range nodeInfoMap {
        // Skip if the info key doesn't exist (nodeData is empty)
        if len(nodeData) == 0 {
            continue
        }

        var node RestoreNode
        if err := json.Unmarshal(nodeData, &node); err != nil {
            r.logger.Warn("Failed to unmarshal node data", "key", infoKey, "error", err)
            continue
        }

        // Skip nodes with zero/uninitialized heartbeat
        if node.LastHeartbeat.IsZero() {
            continue
        }

        if node.LastHeartbeat.Before(threshold) {
            continue
        }

        statsKey := nodeIDToStatsKey[infoKey]
        tasksData := keyDataMap[statsKey]
        count, err := r.parseIntFromBytes(tasksData)
        if err != nil {
            r.logger.Warn("Failed to parse active restores count", "key", statsKey, "error", err)
            continue
        }

        stat := RestoreNodeStats{
            ID:             node.ID,
            ActiveRestores: int(count),
        }
        stats = append(stats, stat)
    }

    return stats, nil
}

func (r *RestoreNodesRegistry) IncrementRestoresInProgress(nodeID uuid.UUID) error {
    ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
    defer cancel()

    key := fmt.Sprintf(
        "%s%s%s",
        nodeActiveRestoresPrefix,
        nodeID.String(),
        nodeActiveRestoresSuffix,
    )
    result := r.client.Do(ctx, r.client.B().Incr().Key(key).Build())

    if result.Error() != nil {
        return fmt.Errorf(
            "failed to increment restores in progress for node %s: %w",
            nodeID,
            result.Error(),
        )
    }

    return nil
}

func (r *RestoreNodesRegistry) DecrementRestoresInProgress(nodeID uuid.UUID) error {
    ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
    defer cancel()

    key := fmt.Sprintf(
        "%s%s%s",
        nodeActiveRestoresPrefix,
        nodeID.String(),
        nodeActiveRestoresSuffix,
    )
    result := r.client.Do(ctx, r.client.B().Decr().Key(key).Build())

    if result.Error() != nil {
        return fmt.Errorf(
            "failed to decrement restores in progress for node %s: %w",
            nodeID,
            result.Error(),
        )
    }

    newValue, err := result.AsInt64()
    if err != nil {
        return fmt.Errorf("failed to parse decremented value for node %s: %w", nodeID, err)
    }

    if newValue < 0 {
        setCtx, setCancel := context.WithTimeout(context.Background(), r.timeout)
        r.client.Do(setCtx, r.client.B().Set().Key(key).Value("0").Build())
        setCancel()
        r.logger.Warn("Active restores counter went below 0, reset to 0", "nodeID", nodeID)
    }

    return nil
}

func (r *RestoreNodesRegistry) HearthbeatNodeInRegistry(
    now time.Time,
    restoreNode RestoreNode,
) error {
    if now.IsZero() {
        return fmt.Errorf("cannot register node with zero heartbeat timestamp")
    }

    ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
    defer cancel()

    restoreNode.LastHeartbeat = now

    data, err := json.Marshal(restoreNode)
    if err != nil {
        return fmt.Errorf("failed to marshal restore node: %w", err)
    }

    key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, restoreNode.ID.String(), nodeInfoKeySuffix)
    result := r.client.Do(
        ctx,
        r.client.B().Set().Key(key).Value(string(data)).Build(),
    )

    if result.Error() != nil {
        return fmt.Errorf("failed to register node %s: %w", restoreNode.ID, result.Error())
    }

    return nil
}

func (r *RestoreNodesRegistry) UnregisterNodeFromRegistry(restoreNode RestoreNode) error {
    ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
    defer cancel()

    infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, restoreNode.ID.String(), nodeInfoKeySuffix)
    counterKey := fmt.Sprintf(
        "%s%s%s",
        nodeActiveRestoresPrefix,
        restoreNode.ID.String(),
        nodeActiveRestoresSuffix,
    )

    result := r.client.Do(
        ctx,
        r.client.B().Del().Key(infoKey, counterKey).Build(),
    )

    if result.Error() != nil {
        return fmt.Errorf("failed to unregister node %s: %w", restoreNode.ID, result.Error())
    }

    r.logger.Info("Unregistered node from registry", "nodeID", restoreNode.ID)
    return nil
}

func (r *RestoreNodesRegistry) AssignRestoreToNode(
    targetNodeID uuid.UUID,
    restoreID uuid.UUID,
    isCallNotifier bool,
) error {
    ctx := context.Background()

    message := RestoreSubmitMessage{
        NodeID:         targetNodeID,
        RestoreID:      restoreID,
        IsCallNotifier: isCallNotifier,
    }

    messageJSON, err := json.Marshal(message)
    if err != nil {
        return fmt.Errorf("failed to marshal restore submit message: %w", err)
    }

    err = r.pubsubRestores.Publish(ctx, restoreSubmitChannel, string(messageJSON))
    if err != nil {
        return fmt.Errorf("failed to publish restore submit message: %w", err)
    }

    return nil
}

func (r *RestoreNodesRegistry) SubscribeNodeForRestoresAssignment(
    nodeID uuid.UUID,
    handler func(restoreID uuid.UUID, isCallNotifier bool),
) error {
    ctx := context.Background()

    wrappedHandler := func(message string) {
        var msg RestoreSubmitMessage
        if err := json.Unmarshal([]byte(message), &msg); err != nil {
            r.logger.Warn("Failed to unmarshal restore submit message", "error", err)
            return
        }

        if msg.NodeID != nodeID {
            return
        }

        handler(msg.RestoreID, msg.IsCallNotifier)
    }

    err := r.pubsubRestores.Subscribe(ctx, restoreSubmitChannel, wrappedHandler)
    if err != nil {
        return fmt.Errorf("failed to subscribe to restore submit channel: %w", err)
    }

    r.logger.Info("Subscribed to restore submit channel", "nodeID", nodeID)
    return nil
}

func (r *RestoreNodesRegistry) UnsubscribeNodeForRestoresAssignments() error {
    err := r.pubsubRestores.Close()
    if err != nil {
        return fmt.Errorf("failed to unsubscribe from restore submit channel: %w", err)
    }

    r.logger.Info("Unsubscribed from restore submit channel")
    return nil
}

func (r *RestoreNodesRegistry) PublishRestoreCompletion(
    nodeID uuid.UUID,
    restoreID uuid.UUID,
) error {
    ctx := context.Background()

    message := RestoreCompletionMessage{
        NodeID:    nodeID,
        RestoreID: restoreID,
    }

    messageJSON, err := json.Marshal(message)
    if err != nil {
        return fmt.Errorf("failed to marshal restore completion message: %w", err)
    }

    err = r.pubsubCompletions.Publish(ctx, restoreCompletionChannel, string(messageJSON))
    if err != nil {
        return fmt.Errorf("failed to publish restore completion message: %w", err)
    }

    return nil
}

func (r *RestoreNodesRegistry) SubscribeForRestoresCompletions(
    handler func(nodeID uuid.UUID, restoreID uuid.UUID),
) error {
    ctx := context.Background()

    wrappedHandler := func(message string) {
        var msg RestoreCompletionMessage
        if err := json.Unmarshal([]byte(message), &msg); err != nil {
            r.logger.Warn("Failed to unmarshal restore completion message", "error", err)
            return
        }

        handler(msg.NodeID, msg.RestoreID)
    }

    err := r.pubsubCompletions.Subscribe(ctx, restoreCompletionChannel, wrappedHandler)
    if err != nil {
        return fmt.Errorf("failed to subscribe to restore completion channel: %w", err)
    }

    r.logger.Info("Subscribed to restore completion channel")
    return nil
}

func (r *RestoreNodesRegistry) UnsubscribeForRestoresCompletions() error {
    err := r.pubsubCompletions.Close()
    if err != nil {
        return fmt.Errorf("failed to unsubscribe from restore completion channel: %w", err)
    }

    r.logger.Info("Unsubscribed from restore completion channel")
    return nil
}

func (r *RestoreNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uuid.UUID {
    nodeIDStr := strings.TrimPrefix(key, prefix)
    nodeIDStr = strings.TrimSuffix(nodeIDStr, suffix)

    nodeID, err := uuid.Parse(nodeIDStr)
    if err != nil {
        r.logger.Warn("Failed to parse node ID from key", "key", key, "error", err)
        return uuid.Nil
    }

    return nodeID
}

func (r *RestoreNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, error) {
    if len(keys) == 0 {
        return make(map[string][]byte), nil
    }

    ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
    defer cancel()

    commands := make([]valkey.Completed, 0, len(keys))
    for _, key := range keys {
        commands = append(commands, r.client.B().Get().Key(key).Build())
    }

    results := r.client.DoMulti(ctx, commands...)

    keyDataMap := make(map[string][]byte, len(keys))
    for i, result := range results {
        if result.Error() != nil {
            r.logger.Warn("Failed to get key in pipeline", "key", keys[i], "error", result.Error())
            continue
        }

        data, err := result.AsBytes()
        if err != nil {
            r.logger.Warn("Failed to parse key data in pipeline", "key", keys[i], "error", err)
            continue
        }

        keyDataMap[keys[i]] = data
    }

    return keyDataMap, nil
}

func (r *RestoreNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
    str := string(data)
    var count int64
    _, err := fmt.Sscanf(str, "%d", &count)
    if err != nil {
        return 0, fmt.Errorf("failed to parse integer from bytes: %w", err)
    }
    return count, nil
}

func (r *RestoreNodesRegistry) cleanupDeadNodes() error {
    ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
    defer cancel()

    var allKeys []string
    cursor := uint64(0)
    pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix

    for {
        result := r.client.Do(
            ctx,
            r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
        )

        if result.Error() != nil {
            return fmt.Errorf("failed to scan node keys: %w", result.Error())
        }

        scanResult, err := result.AsScanEntry()
        if err != nil {
            return fmt.Errorf("failed to parse scan result: %w", err)
        }

        allKeys = append(allKeys, scanResult.Elements...)

        cursor = scanResult.Cursor
        if cursor == 0 {
            break
        }
    }

    if len(allKeys) == 0 {
        return nil
    }

    keyDataMap, err := r.pipelineGetKeys(allKeys)
    if err != nil {
        return fmt.Errorf("failed to pipeline get node keys: %w", err)
    }

    threshold := time.Now().UTC().Add(-deadNodeThreshold)
    var deadNodeKeys []string

    for key, data := range keyDataMap {
        // Skip if the key doesn't exist (data is empty)
        if len(data) == 0 {
            continue
        }

        var node RestoreNode
        if err := json.Unmarshal(data, &node); err != nil {
            r.logger.Warn("Failed to unmarshal node data during cleanup", "key", key, "error", err)
            continue
        }

        // Skip nodes with zero/uninitialized heartbeat
        if node.LastHeartbeat.IsZero() {
            continue
        }

        if node.LastHeartbeat.Before(threshold) {
            nodeID := node.ID.String()
            infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeID, nodeInfoKeySuffix)
            statsKey := fmt.Sprintf(
                "%s%s%s",
                nodeActiveRestoresPrefix,
                nodeID,
                nodeActiveRestoresSuffix,
            )

            deadNodeKeys = append(deadNodeKeys, infoKey, statsKey)
            r.logger.Info(
                "Marking node for cleanup",
                "nodeID", nodeID,
                "lastHeartbeat", node.LastHeartbeat,
                "threshold", threshold,
            )
        }
    }

    if len(deadNodeKeys) == 0 {
        return nil
    }

    delCtx, delCancel := context.WithTimeout(context.Background(), r.timeout)
    defer delCancel()

    result := r.client.Do(
        delCtx,
        r.client.B().Del().Key(deadNodeKeys...).Build(),
    )

    if result.Error() != nil {
        return fmt.Errorf("failed to delete dead node keys: %w", result.Error())
    }

    deletedCount, err := result.AsInt64()
    if err != nil {
        return fmt.Errorf("failed to parse deleted count: %w", err)
    }

    r.logger.Info("Cleaned up dead nodes", "deletedKeysCount", deletedCount)
    return nil
}
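Note: the registry's liveness rule reduces to one comparison — a node counts as alive only if its LastHeartbeat is non-zero and falls within deadNodeThreshold (2 minutes) of the current time — and GetAvailableNodes, GetRestoreNodesStats, and cleanupDeadNodes all apply exactly this check. A standalone restatement of that predicate (the helper name is illustrative):

// isNodeAlive restates the filter used throughout registry.go: zero
// heartbeats are treated as dead, and anything older than
// deadNodeThreshold is excluded.
func isNodeAlive(node RestoreNode, now time.Time) bool {
    if node.LastHeartbeat.IsZero() {
        return false
    }
    threshold := now.UTC().Add(-deadNodeThreshold)
    return !node.LastHeartbeat.Before(threshold)
}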
File diff suppressed because it is too large
310
backend/internal/features/restores/restoring/restorer.go
Normal file
310
backend/internal/features/restores/restoring/restorer.go
Normal file
@@ -0,0 +1,310 @@
|
||||
package restoring

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/google/uuid"

	"databasus-backend/internal/config"
	"databasus-backend/internal/features/backups/backups"
	backups_config "databasus-backend/internal/features/backups/config"
	"databasus-backend/internal/features/databases"
	restores_core "databasus-backend/internal/features/restores/core"
	"databasus-backend/internal/features/storages"
	tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
	cache_utils "databasus-backend/internal/util/cache"
	util_encryption "databasus-backend/internal/util/encryption"
)

const (
	heartbeatTickerInterval      = 15 * time.Second
	restorerHealthcheckThreshold = 5 * time.Minute
)

type RestorerNode struct {
	nodeID uuid.UUID

	databaseService      *databases.DatabaseService
	backupService        *backups.BackupService
	fieldEncryptor       util_encryption.FieldEncryptor
	restoreRepository    *restores_core.RestoreRepository
	backupConfigService  *backups_config.BackupConfigService
	storageService       *storages.StorageService
	restoreNodesRegistry *RestoreNodesRegistry
	logger               *slog.Logger
	restoreBackupUsecase restores_core.RestoreBackupUsecase
	cacheUtil            *cache_utils.CacheUtil[RestoreDatabaseCache]
	restoreCancelManager *tasks_cancellation.TaskCancelManager

	lastHeartbeat time.Time

	runOnce sync.Once
	hasRun  atomic.Bool
}

func (n *RestorerNode) Run(ctx context.Context) {
	wasAlreadyRun := n.hasRun.Load()

	n.runOnce.Do(func() {
		n.hasRun.Store(true)

		n.lastHeartbeat = time.Now().UTC()

		throughputMBs := config.GetEnv().NodeNetworkThroughputMBs

		restoreNode := RestoreNode{
			ID:            n.nodeID,
			ThroughputMBs: throughputMBs,
		}

		if err := n.restoreNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), restoreNode); err != nil {
			n.logger.Error("Failed to register node in registry", "error", err)
			panic(err)
		}

		restoreHandler := func(restoreID uuid.UUID, isCallNotifier bool) {
			n.MakeRestore(restoreID)
			if err := n.restoreNodesRegistry.PublishRestoreCompletion(n.nodeID, restoreID); err != nil {
				n.logger.Error(
					"Failed to publish restore completion",
					"error", err,
					"restoreID", restoreID,
				)
			}
		}

		err := n.restoreNodesRegistry.SubscribeNodeForRestoresAssignment(
			n.nodeID,
			restoreHandler,
		)
		if err != nil {
			n.logger.Error("Failed to subscribe to restore assignments", "error", err)
			panic(err)
		}
		defer func() {
			if err := n.restoreNodesRegistry.UnsubscribeNodeForRestoresAssignments(); err != nil {
				n.logger.Error("Failed to unsubscribe from restore assignments", "error", err)
			}
		}()

		ticker := time.NewTicker(heartbeatTickerInterval)
		defer ticker.Stop()

		n.logger.Info("Restore node started", "nodeID", n.nodeID, "throughput", throughputMBs)

		for {
			select {
			case <-ctx.Done():
				n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)

				if err := n.restoreNodesRegistry.UnregisterNodeFromRegistry(restoreNode); err != nil {
					n.logger.Error("Failed to unregister node from registry", "error", err)
				}

				return
			case <-ticker.C:
				n.sendHeartbeat(&restoreNode)
			}
		}
	})

	if wasAlreadyRun {
		panic(fmt.Sprintf("%T.Run() called multiple times", n))
	}
}

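// IsRestorerRunning reports node liveness from the local heartbeat timestamp:
// the node counts as healthy while its last heartbeat is newer than
// restorerHealthcheckThreshold (5 minutes).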
func (n *RestorerNode) IsRestorerRunning() bool {
	return n.lastHeartbeat.After(time.Now().UTC().Add(-restorerHealthcheckThreshold))
}

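// MakeRestore consumes the cached database credentials for the given restore
// with a single GetAndDelete, so the credentials can be read at most once; a
// cache miss (e.g. after an instance restart) fails the restore immediately
// rather than retrying with stale data.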
func (n *RestorerNode) MakeRestore(restoreID uuid.UUID) {
	// Get and delete cached DB credentials atomically
	dbCache := n.cacheUtil.GetAndDelete(restoreID.String())

	if dbCache == nil {
		// Cache miss - fail immediately
		restore, err := n.restoreRepository.FindByID(restoreID)
		if err != nil {
			n.logger.Error(
				"Failed to get restore by ID after cache miss",
				"restoreId", restoreID,
				"error", err,
			)
			return
		}

		errMsg := "Database credentials expired or missing from cache (most likely due to instance restart)"
		restore.FailMessage = &errMsg
		restore.Status = restores_core.RestoreStatusFailed

		if err := n.restoreRepository.Save(restore); err != nil {
			n.logger.Error("Failed to save restore after cache miss", "error", err)
		}

		n.logger.Error("Restore failed: cache miss", "restoreId", restoreID)
		return
	}

	restore, err := n.restoreRepository.FindByID(restoreID)
	if err != nil {
		n.logger.Error("Failed to get restore by ID", "restoreId", restoreID, "error", err)
		return
	}

	backup, err := n.backupService.GetBackup(restore.BackupID)
	if err != nil {
		n.logger.Error("Failed to get backup by ID", "backupId", restore.BackupID, "error", err)
		return
	}

	databaseID := backup.DatabaseID

	database, err := n.databaseService.GetDatabaseByID(databaseID)
	if err != nil {
		n.logger.Error("Failed to get database by ID", "databaseId", databaseID, "error", err)
		return
	}

	backupConfig, err := n.backupConfigService.GetBackupConfigByDbId(databaseID)
	if err != nil {
		n.logger.Error("Failed to get backup config by database ID", "error", err)
		return
	}

	if backupConfig.StorageID == nil {
		n.logger.Error("Backup config storage ID is not defined")
		return
	}

	storage, err := n.storageService.GetStorageByID(*backupConfig.StorageID)
	if err != nil {
		n.logger.Error("Failed to get storage by ID", "error", err)
		return
	}

	start := time.Now().UTC()

	// Create cancellable context
	ctx, cancel := context.WithCancel(context.Background())
	n.restoreCancelManager.RegisterTask(restore.ID, cancel)
	defer n.restoreCancelManager.UnregisterTask(restore.ID)

	// Create restoring database from cached credentials
	restoringToDB := &databases.Database{
		Type:       database.Type,
		Postgresql: dbCache.PostgresqlDatabase,
		Mysql:      dbCache.MysqlDatabase,
		Mariadb:    dbCache.MariadbDatabase,
		Mongodb:    dbCache.MongodbDatabase,
	}

	if err := restoringToDB.PopulateDbData(n.logger, n.fieldEncryptor); err != nil {
		errMsg := fmt.Sprintf("failed to auto-detect database data: %v", err)
		restore.FailMessage = &errMsg
		restore.Status = restores_core.RestoreStatusFailed
		restore.RestoreDurationMs = time.Since(start).Milliseconds()

		if err := n.restoreRepository.Save(restore); err != nil {
			n.logger.Error("Failed to save restore", "error", err)
		}

		return
	}

	isExcludeExtensions := false
	if dbCache.PostgresqlDatabase != nil {
		isExcludeExtensions = dbCache.PostgresqlDatabase.IsExcludeExtensions
	}

	err = n.restoreBackupUsecase.Execute(
		ctx,
		backupConfig,
		*restore,
		database,
		restoringToDB,
		backup,
		storage,
		isExcludeExtensions,
	)

	if err != nil {
		errMsg := err.Error()

		// Check if restore was cancelled
		isCancelled := strings.Contains(errMsg, "restore cancelled") ||
			strings.Contains(errMsg, "context canceled") ||
			errors.Is(err, context.Canceled)
		isShutdown := strings.Contains(errMsg, "shutdown")

		if isCancelled && !isShutdown {
			n.logger.Warn("Restore was cancelled by user or system",
				"restoreId", restore.ID,
				"isCancelled", isCancelled,
				"isShutdown", isShutdown,
			)

			restore.Status = restores_core.RestoreStatusCanceled
			restore.RestoreDurationMs = time.Since(start).Milliseconds()

			if err := n.restoreRepository.Save(restore); err != nil {
				n.logger.Error("Failed to save cancelled restore", "error", err)
			}

			return
		}

		n.logger.Error("Restore execution failed",
			"restoreId", restore.ID,
			"backupId", backup.ID,
			"databaseId", databaseID,
			"databaseType", database.Type,
			"storageId", storage.ID,
			"storageType", storage.Type,
			"error", err,
			"errorMessage", errMsg,
		)

		restore.FailMessage = &errMsg
		restore.Status = restores_core.RestoreStatusFailed
		restore.RestoreDurationMs = time.Since(start).Milliseconds()

		if err := n.restoreRepository.Save(restore); err != nil {
			n.logger.Error("Failed to save restore", "error", err)
		}

		return
	}

	restore.Status = restores_core.RestoreStatusCompleted
	restore.RestoreDurationMs = time.Since(start).Milliseconds()

	if err := n.restoreRepository.Save(restore); err != nil {
		n.logger.Error("Failed to save restore", "error", err)
		return
	}

	n.logger.Info(
		"Restore completed successfully",
		"restoreId", restore.ID,
		"backupId", backup.ID,
		"durationMs", restore.RestoreDurationMs,
	)
}

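// sendHeartbeat refreshes both the local liveness timestamp and the shared
// registry entry; a stale registry heartbeat is what the scheduler uses to
// treat a node as dead.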
func (n *RestorerNode) sendHeartbeat(restoreNode *RestoreNode) {
	n.lastHeartbeat = time.Now().UTC()
	if err := n.restoreNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *restoreNode); err != nil {
		n.logger.Error("Failed to send heartbeat", "error", err)
	}
}
164
backend/internal/features/restores/restoring/restorer_test.go
Normal file
@@ -0,0 +1,164 @@
package restoring

import (
	"testing"
	"time"

	"github.com/stretchr/testify/assert"

	"databasus-backend/internal/config"
	"databasus-backend/internal/features/backups/backups"
	backups_core "databasus-backend/internal/features/backups/backups/core"
	backups_config "databasus-backend/internal/features/backups/config"
	"databasus-backend/internal/features/databases"
	"databasus-backend/internal/features/databases/databases/postgresql"
	"databasus-backend/internal/features/notifiers"
	restores_core "databasus-backend/internal/features/restores/core"
	"databasus-backend/internal/features/storages"
	users_enums "databasus-backend/internal/features/users/enums"
	users_testing "databasus-backend/internal/features/users/testing"
	workspaces_testing "databasus-backend/internal/features/workspaces/testing"
	cache_utils "databasus-backend/internal/util/cache"
)

func Test_MakeRestore_WhenCacheMissed_RestoreFails(t *testing.T) {
	cache_utils.ClearAllCache()

	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
	backups_config.EnableBackupsForTestDatabase(database.ID, storage)

	defer func() {
		backupRepo := backups_core.BackupRepository{}
		backupsList, _ := backupRepo.FindByDatabaseID(database.ID)
		for _, backup := range backupsList {
			backupRepo.DeleteByID(backup.ID)
		}

		restoreRepo := restores_core.RestoreRepository{}
		restoresInProgress, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
		for _, restore := range restoresInProgress {
			restoreRepo.DeleteByID(restore.ID)
		}
		restoresFailed, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusFailed)
		for _, restore := range restoresFailed {
			restoreRepo.DeleteByID(restore.ID)
		}

		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		notifiers.RemoveTestNotifier(notifier)
		storages.RemoveTestStorage(storage.ID)
		workspaces_testing.RemoveTestWorkspace(workspace, router)

		cache_utils.ClearAllCache()
	}()

	backup := backups.CreateTestBackup(database.ID, storage.ID)

	// Create restore but DON'T cache DB credentials
	// Also don't set embedded DB fields to avoid schema issues
	restore := &restores_core.Restore{
		BackupID: backup.ID,
		Status:   restores_core.RestoreStatusInProgress,
	}
	err := restoreRepository.Save(restore)
	assert.NoError(t, err)

	// Create restorer and execute restore (should fail due to cache miss)
	restorerNode := CreateTestRestorerNode()
	restorerNode.MakeRestore(restore.ID)

	// Verify restore failed with appropriate error message
	updatedRestore, err := restoreRepository.FindByID(restore.ID)
	assert.NoError(t, err)
	assert.Equal(t, restores_core.RestoreStatusFailed, updatedRestore.Status)
	assert.NotNil(t, updatedRestore.FailMessage)
	assert.Contains(
		t,
		*updatedRestore.FailMessage,
		"Database credentials expired or missing from cache",
	)
}

func Test_MakeRestore_WhenTaskStarts_CacheDeletedImmediately(t *testing.T) {
	cache_utils.ClearAllCache()

	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
	backups_config.EnableBackupsForTestDatabase(database.ID, storage)

	defer func() {
		backupRepo := backups_core.BackupRepository{}
		backupsList, _ := backupRepo.FindByDatabaseID(database.ID)
		for _, backup := range backupsList {
			backupRepo.DeleteByID(backup.ID)
		}

		restoreRepo := restores_core.RestoreRepository{}
		restoresInProgress, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
		for _, restore := range restoresInProgress {
			restoreRepo.DeleteByID(restore.ID)
		}
		restoresFailed, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusFailed)
		for _, restore := range restoresFailed {
			restoreRepo.DeleteByID(restore.ID)
		}
		restoresCompleted, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusCompleted)
		for _, restore := range restoresCompleted {
			restoreRepo.DeleteByID(restore.ID)
		}

		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		notifiers.RemoveTestNotifier(notifier)
		storages.RemoveTestStorage(storage.ID)
		workspaces_testing.RemoveTestWorkspace(workspace, router)

		cache_utils.ClearAllCache()
	}()

	backup := backups.CreateTestBackup(database.ID, storage.ID)

	// Create restore with cached DB credentials
	// Don't set embedded DB fields in the restore model itself
	restore := &restores_core.Restore{
		BackupID: backup.ID,
		Status:   restores_core.RestoreStatusInProgress,
	}
	err := restoreRepository.Save(restore)
	assert.NoError(t, err)

	// Cache DB credentials separately
	dbCache := &RestoreDatabaseCache{
		PostgresqlDatabase: &postgresql.PostgresqlDatabase{
			Host:     config.GetEnv().TestLocalhost,
			Port:     5432,
			Username: "test",
			Password: "test",
			Database: stringPtr("testdb"),
			Version:  "16",
		},
	}
	restoreDatabaseCache.SetWithExpiration(restore.ID.String(), dbCache, 1*time.Hour)

	// Verify cache exists before restore starts
	cachedDB := restoreDatabaseCache.Get(restore.ID.String())
	assert.NotNil(t, cachedDB, "Cache should exist before restore starts")

	// Start restore (this will call GetAndDelete)
	restorerNode := CreateTestRestorerNode()
	restorerNode.MakeRestore(restore.ID)

	// Verify cache was deleted immediately
	cachedDBAfter := restoreDatabaseCache.Get(restore.ID.String())
	assert.Nil(t, cachedDBAfter, "Cache should be deleted immediately when task starts")
}
410
backend/internal/features/restores/restoring/scheduler.go
Normal file
@@ -0,0 +1,410 @@
package restoring

import (
	"context"
	"fmt"
	"log/slog"
	"sync"
	"sync/atomic"
	"time"

	"github.com/google/uuid"

	"databasus-backend/internal/config"
	"databasus-backend/internal/features/backups/backups"
	backups_config "databasus-backend/internal/features/backups/config"
	restores_core "databasus-backend/internal/features/restores/core"
	"databasus-backend/internal/features/storages"
	cache_utils "databasus-backend/internal/util/cache"
)

const (
	schedulerStartupDelay         = 1 * time.Minute
	schedulerTickerInterval       = 1 * time.Minute
	schedulerHealthcheckThreshold = 5 * time.Minute
)

type RestoresScheduler struct {
	restoreRepository        *restores_core.RestoreRepository
	backupService            *backups.BackupService
	storageService           *storages.StorageService
	backupConfigService      *backups_config.BackupConfigService
	restoreNodesRegistry     *RestoreNodesRegistry
	lastCheckTime            time.Time
	logger                   *slog.Logger
	restoreToNodeRelations   map[uuid.UUID]RestoreToNodeRelation
	restorerNode             *RestorerNode
	cacheUtil                *cache_utils.CacheUtil[RestoreDatabaseCache]
	completionSubscriptionID uuid.UUID

	runOnce sync.Once
	hasRun  atomic.Bool
}

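// Run is guarded by runOnce/hasRun so a second call panics loudly instead of
// silently spawning a duplicate scheduler. On startup it fails any restores
// left in progress by a previous process, then subscribes for completion
// events and periodically sweeps for dead nodes.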
func (s *RestoresScheduler) Run(ctx context.Context) {
	wasAlreadyRun := s.hasRun.Load()

	s.runOnce.Do(func() {
		s.hasRun.Store(true)

		s.lastCheckTime = time.Now().UTC()

		if config.GetEnv().IsManyNodesMode {
			// wait for other nodes to start
			time.Sleep(schedulerStartupDelay)
		}

		if err := s.failRestoresInProgress(); err != nil {
			s.logger.Error("Failed to fail restores in progress", "error", err)
			panic(err)
		}

		err := s.restoreNodesRegistry.SubscribeForRestoresCompletions(s.onRestoreCompleted)
		if err != nil {
			s.logger.Error("Failed to subscribe to restore completions", "error", err)
			panic(err)
		}

		defer func() {
			if err := s.restoreNodesRegistry.UnsubscribeForRestoresCompletions(); err != nil {
				s.logger.Error("Failed to unsubscribe from restore completions", "error", err)
			}
		}()

		if ctx.Err() != nil {
			return
		}

		ticker := time.NewTicker(schedulerTickerInterval)
		defer ticker.Stop()

		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				if err := s.checkDeadNodesAndFailRestores(); err != nil {
					s.logger.Error("Failed to check dead nodes and fail restores", "error", err)
				}

				s.lastCheckTime = time.Now().UTC()
			}
		}
	})

	if wasAlreadyRun {
		panic(fmt.Sprintf("%T.Run() called multiple times", s))
	}
}

func (s *RestoresScheduler) IsSchedulerRunning() bool {
	return s.lastCheckTime.After(time.Now().UTC().Add(-schedulerHealthcheckThreshold))
}

func (s *RestoresScheduler) failRestoresInProgress() error {
	restoresInProgress, err := s.restoreRepository.FindByStatus(
		restores_core.RestoreStatusInProgress,
	)
	if err != nil {
		return err
	}

	for _, restore := range restoresInProgress {
		failMessage := "Restore failed due to application restart"
		restore.FailMessage = &failMessage
		restore.Status = restores_core.RestoreStatusFailed

		if err := s.restoreRepository.Save(restore); err != nil {
			return err
		}
	}

	return nil
}

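// StartRestore caches the (already encrypted) database credentials with a
// 1-hour TTL, picks the least busy node, and reserves capacity before the
// assignment is published; if publishing fails, the in-progress counter is
// rolled back so the node's score is not permanently inflated.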
func (s *RestoresScheduler) StartRestore(restoreID uuid.UUID, dbCache *RestoreDatabaseCache) error {
	// If dbCache not provided, try to fetch from DB (for backward compatibility/testing)
	if dbCache == nil {
		restore, err := s.restoreRepository.FindByID(restoreID)
		if err != nil {
			s.logger.Error(
				"Failed to find restore by ID",
				"restoreId", restoreID,
				"error", err,
			)
			return err
		}

		// Create cache DTO from restore (may be nil if not in DB)
		dbCache = &RestoreDatabaseCache{
			PostgresqlDatabase: restore.PostgresqlDatabase,
			MysqlDatabase:      restore.MysqlDatabase,
			MariadbDatabase:    restore.MariadbDatabase,
			MongodbDatabase:    restore.MongodbDatabase,
		}
	}

	// Cache database credentials with 1-hour expiration
	s.cacheUtil.SetWithExpiration(restoreID.String(), dbCache, 1*time.Hour)

	leastBusyNodeID, err := s.calculateLeastBusyNode()
	if err != nil {
		s.logger.Error(
			"Failed to calculate least busy node",
			"restoreId", restoreID,
			"error", err,
		)
		return err
	}

	if err := s.restoreNodesRegistry.IncrementRestoresInProgress(*leastBusyNodeID); err != nil {
		s.logger.Error(
			"Failed to increment restores in progress",
			"nodeId", leastBusyNodeID,
			"restoreId", restoreID,
			"error", err,
		)
		return err
	}

	if err := s.restoreNodesRegistry.AssignRestoreToNode(*leastBusyNodeID, restoreID, false); err != nil {
		s.logger.Error(
			"Failed to submit restore",
			"nodeId", leastBusyNodeID,
			"restoreId", restoreID,
			"error", err,
		)
		if decrementErr := s.restoreNodesRegistry.DecrementRestoresInProgress(*leastBusyNodeID); decrementErr != nil {
			s.logger.Error(
				"Failed to decrement restores in progress after submit failure",
				"nodeId", leastBusyNodeID,
				"error", decrementErr,
			)
		}
		return err
	}

	if relation, exists := s.restoreToNodeRelations[*leastBusyNodeID]; exists {
		relation.RestoreIDs = append(relation.RestoreIDs, restoreID)
		s.restoreToNodeRelations[*leastBusyNodeID] = relation
	} else {
		s.restoreToNodeRelations[*leastBusyNodeID] = RestoreToNodeRelation{
			NodeID:     *leastBusyNodeID,
			RestoreIDs: []uuid.UUID{restoreID},
		}
	}

	s.logger.Info(
		"Successfully triggered restore",
		"restoreId", restoreID,
		"nodeId", leastBusyNodeID,
	)

	return nil
}

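// calculateLeastBusyNode scores each node as activeRestores / throughputMBs
// (lower is better), so a faster link absorbs proportionally more work. For
// example: 5 active restores at 100 MB/s scores 0.05, while 1 active restore
// at 50 MB/s scores 0.02, so the slower but idler node wins. Nodes reporting
// zero throughput are heavily penalized (activeRestores * 1000) rather than
// excluded outright.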
func (s *RestoresScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
	nodes, err := s.restoreNodesRegistry.GetAvailableNodes()
	if err != nil {
		return nil, fmt.Errorf("failed to get available nodes: %w", err)
	}

	if len(nodes) == 0 {
		return nil, fmt.Errorf("no nodes available")
	}

	stats, err := s.restoreNodesRegistry.GetRestoreNodesStats()
	if err != nil {
		return nil, fmt.Errorf("failed to get restore nodes stats: %w", err)
	}

	statsMap := make(map[uuid.UUID]int)
	for _, stat := range stats {
		statsMap[stat.ID] = stat.ActiveRestores
	}

	var bestNode *RestoreNode
	var bestScore float64 = -1

	for i := range nodes {
		node := &nodes[i]

		activeRestores := statsMap[node.ID]

		var score float64
		if node.ThroughputMBs > 0 {
			score = float64(activeRestores) / float64(node.ThroughputMBs)
		} else {
			score = float64(activeRestores) * 1000
		}

		if bestNode == nil || score < bestScore {
			bestNode = node
			bestScore = score
		}
	}

	if bestNode == nil {
		return nil, fmt.Errorf("no suitable nodes available")
	}

	return &bestNode.ID, nil
}

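// onRestoreCompleted handles completion events from the shared registry. The
// registry carries completions for multiple task types, so the ID is first
// checked against the restore repository and silently ignored when it is not
// a restore.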
func (s *RestoresScheduler) onRestoreCompleted(nodeID uuid.UUID, restoreID uuid.UUID) {
	// Verify this task is actually a restore (registry contains multiple task types)
	_, err := s.restoreRepository.FindByID(restoreID)
	if err != nil {
		// Not a restore task, ignore it
		return
	}

	relation, exists := s.restoreToNodeRelations[nodeID]
	if !exists {
		s.logger.Warn(
			"Received completion for unknown node",
			"nodeId", nodeID,
			"restoreId", restoreID,
		)
		return
	}

	newRestoreIDs := make([]uuid.UUID, 0)
	found := false
	for _, id := range relation.RestoreIDs {
		if id == restoreID {
			found = true
			continue
		}
		newRestoreIDs = append(newRestoreIDs, id)
	}

	if !found {
		s.logger.Warn(
			"Restore not found in node's restore list",
			"nodeId", nodeID,
			"restoreId", restoreID,
		)
		return
	}

	if len(newRestoreIDs) == 0 {
		delete(s.restoreToNodeRelations, nodeID)
	} else {
		relation.RestoreIDs = newRestoreIDs
		s.restoreToNodeRelations[nodeID] = relation
	}

	if err := s.restoreNodesRegistry.DecrementRestoresInProgress(nodeID); err != nil {
		s.logger.Error(
			"Failed to decrement restores in progress",
			"nodeId", nodeID,
			"restoreId", restoreID,
			"error", err,
		)
	}
}

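// checkDeadNodesAndFailRestores runs on the scheduler ticker: any tracked node
// that has dropped out of the available-nodes set has its restores marked
// failed and its in-progress counters decremented, then the relation entry is
// removed.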
func (s *RestoresScheduler) checkDeadNodesAndFailRestores() error {
	nodes, err := s.restoreNodesRegistry.GetAvailableNodes()
	if err != nil {
		return fmt.Errorf("failed to get available nodes: %w", err)
	}

	aliveNodeIDs := make(map[uuid.UUID]bool)
	for _, node := range nodes {
		aliveNodeIDs[node.ID] = true
	}

	for nodeID, relation := range s.restoreToNodeRelations {
		if aliveNodeIDs[nodeID] {
			continue
		}

		s.logger.Warn(
			"Node is dead, failing its restores",
			"nodeId", nodeID,
			"restoreCount", len(relation.RestoreIDs),
		)

		for _, restoreID := range relation.RestoreIDs {
			restore, err := s.restoreRepository.FindByID(restoreID)
			if err != nil {
				s.logger.Error(
					"Failed to find restore for dead node",
					"nodeId", nodeID,
					"restoreId", restoreID,
					"error", err,
				)
				continue
			}

			failMessage := "Restore failed due to node unavailability"
			restore.FailMessage = &failMessage
			restore.Status = restores_core.RestoreStatusFailed

			if err := s.restoreRepository.Save(restore); err != nil {
				s.logger.Error(
					"Failed to save failed restore for dead node",
					"nodeId", nodeID,
					"restoreId", restoreID,
					"error", err,
				)
				continue
			}

			if err := s.restoreNodesRegistry.DecrementRestoresInProgress(nodeID); err != nil {
				s.logger.Error(
					"Failed to decrement restores in progress for dead node",
					"nodeId", nodeID,
					"restoreId", restoreID,
					"error", err,
				)
			}

			s.logger.Info(
				"Failed restore due to dead node",
				"nodeId", nodeID,
				"restoreId", restoreID,
			)
		}

		delete(s.restoreToNodeRelations, nodeID)
	}

	return nil
}
856
backend/internal/features/restores/restoring/scheduler_test.go
Normal file
@@ -0,0 +1,856 @@
package restoring

import (
	"testing"
	"time"

	"databasus-backend/internal/config"
	"databasus-backend/internal/features/backups/backups"
	backups_core "databasus-backend/internal/features/backups/backups/core"
	backups_config "databasus-backend/internal/features/backups/config"
	"databasus-backend/internal/features/databases"
	"databasus-backend/internal/features/databases/databases/postgresql"
	"databasus-backend/internal/features/notifiers"
	restores_core "databasus-backend/internal/features/restores/core"
	"databasus-backend/internal/features/storages"
	users_enums "databasus-backend/internal/features/users/enums"
	users_testing "databasus-backend/internal/features/users/testing"
	workspaces_testing "databasus-backend/internal/features/workspaces/testing"
	cache_utils "databasus-backend/internal/util/cache"
	"databasus-backend/internal/util/encryption"

	"github.com/google/uuid"
	"github.com/stretchr/testify/assert"
)

func Test_CheckDeadNodesAndFailRestores_NodeDies_FailsRestoreAndCleansUpRegistry(t *testing.T) {
	cache_utils.ClearAllCache()

	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)

	var mockNodeID uuid.UUID

	defer func() {
		backupRepo := backups_core.BackupRepository{}
		backups, _ := backupRepo.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepo.DeleteByID(backup.ID)
		}

		restoreRepo := restores_core.RestoreRepository{}
		restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
		for _, restore := range restores {
			restoreRepo.DeleteByID(restore.ID)
		}
		restores, _ = restoreRepo.FindByStatus(restores_core.RestoreStatusFailed)
		for _, restore := range restores {
			restoreRepo.DeleteByID(restore.ID)
		}

		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		storages.RemoveTestStorage(storage.ID)
		notifiers.RemoveTestNotifier(notifier)
		workspaces_testing.RemoveTestWorkspace(workspace, router)

		// Clean up mock node
		if mockNodeID != uuid.Nil {
			restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: mockNodeID})
		}
		cache_utils.ClearAllCache()
	}()

	backups_config.EnableBackupsForTestDatabase(database.ID, storage)

	// Create a test backup
	backup := backups.CreateTestBackup(database.ID, storage.ID)

	var err error
	// Register mock node without subscribing to restores (simulates node crash after registration)
	mockNodeID = uuid.New()
	err = CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
	assert.NoError(t, err)

	// Create restore and assign to mock node
	restore := &restores_core.Restore{
		BackupID: backup.ID,
		Status:   restores_core.RestoreStatusInProgress,
	}
	err = restoreRepository.Save(restore)
	assert.NoError(t, err)

	// Scheduler assigns restore to mock node
	err = GetRestoresScheduler().StartRestore(restore.ID, nil)
	assert.NoError(t, err)
	time.Sleep(100 * time.Millisecond)

	// Verify Valkey counter was incremented when restore was assigned
	stats, err := restoreNodesRegistry.GetRestoreNodesStats()
	assert.NoError(t, err)
	foundStat := false
	for _, stat := range stats {
		if stat.ID == mockNodeID {
			assert.Equal(t, 1, stat.ActiveRestores)
			foundStat = true
			break
		}
	}
	assert.True(t, foundStat, "Node stats should be present")

	// Simulate node death by setting heartbeat older than 2-minute threshold
	oldHeartbeat := time.Now().UTC().Add(-3 * time.Minute)
	err = UpdateNodeHeartbeatDirectly(mockNodeID, 100, oldHeartbeat)
	assert.NoError(t, err)

	// Trigger dead node detection
	err = GetRestoresScheduler().checkDeadNodesAndFailRestores()
	assert.NoError(t, err)

	// Verify restore was failed with appropriate error message
	failedRestore, err := restoreRepository.FindByID(restore.ID)
	assert.NoError(t, err)
	assert.Equal(t, restores_core.RestoreStatusFailed, failedRestore.Status)
	assert.NotNil(t, failedRestore.FailMessage)
	assert.Contains(t, *failedRestore.FailMessage, "node unavailability")

	// Verify Valkey counter was decremented after restore failed
	stats, err = restoreNodesRegistry.GetRestoreNodesStats()
	assert.NoError(t, err)
	for _, stat := range stats {
		if stat.ID == mockNodeID {
			assert.Equal(t, 0, stat.ActiveRestores)
		}
	}

	time.Sleep(200 * time.Millisecond)
}

func Test_OnRestoreCompleted_TaskIsNotRestore_SkipsProcessing(t *testing.T) {
	cache_utils.ClearAllCache()

	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)

	var mockNodeID uuid.UUID

	defer func() {
		backupRepo := backups_core.BackupRepository{}
		backups, _ := backupRepo.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepo.DeleteByID(backup.ID)
		}

		restoreRepo := restores_core.RestoreRepository{}
		restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
		for _, restore := range restores {
			restoreRepo.DeleteByID(restore.ID)
		}

		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		storages.RemoveTestStorage(storage.ID)
		notifiers.RemoveTestNotifier(notifier)
		workspaces_testing.RemoveTestWorkspace(workspace, router)

		// Clean up mock node
		if mockNodeID != uuid.Nil {
			restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: mockNodeID})
		}
		cache_utils.ClearAllCache()
	}()

	backups_config.EnableBackupsForTestDatabase(database.ID, storage)

	// Create a test backup
	backup := backups.CreateTestBackup(database.ID, storage.ID)

	// Register mock node
	mockNodeID = uuid.New()
	err := CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
	assert.NoError(t, err)

	// Create restore and assign to the node
	restore := &restores_core.Restore{
		BackupID: backup.ID,
		Status:   restores_core.RestoreStatusInProgress,
	}
	err = restoreRepository.Save(restore)
	assert.NoError(t, err)

	err = GetRestoresScheduler().StartRestore(restore.ID, nil)
	assert.NoError(t, err)
	time.Sleep(100 * time.Millisecond)

	// Get initial state of the registry
	initialStats, err := restoreNodesRegistry.GetRestoreNodesStats()
	assert.NoError(t, err)
	var initialActiveTasks int
	for _, stat := range initialStats {
		if stat.ID == mockNodeID {
			initialActiveTasks = stat.ActiveRestores
			break
		}
	}
	assert.Equal(t, 1, initialActiveTasks, "Should have 1 active task")

	// Call onRestoreCompleted with a random UUID (not a restore ID)
	nonRestoreTaskID := uuid.New()
	GetRestoresScheduler().onRestoreCompleted(mockNodeID, nonRestoreTaskID)

	time.Sleep(100 * time.Millisecond)

	// Verify: Active tasks counter should remain the same (not decremented)
	stats, err := restoreNodesRegistry.GetRestoreNodesStats()
	assert.NoError(t, err)
	for _, stat := range stats {
		if stat.ID == mockNodeID {
			assert.Equal(t, initialActiveTasks, stat.ActiveRestores,
				"Active tasks should not change for non-restore task")
		}
	}

	// Verify: restore should still be in progress (not modified)
	unchangedRestore, err := restoreRepository.FindByID(restore.ID)
	assert.NoError(t, err)
	assert.Equal(t, restores_core.RestoreStatusInProgress, unchangedRestore.Status,
		"Restore status should not change for non-restore task completion")

	// Verify: restoreToNodeRelations should still contain the node
	scheduler := GetRestoresScheduler()
	_, exists := scheduler.restoreToNodeRelations[mockNodeID]
	assert.True(t, exists, "Node should still be in restoreToNodeRelations")

	time.Sleep(200 * time.Millisecond)
}

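// Expected scores below: the same-throughput case gives node1 5/100 = 0.05,
// node2 2/100 = 0.02 and node3 8/100 = 0.08, so node2 wins; the mixed case
// gives 10/100 = 0.10 versus 1/50 = 0.02, so the 50 MB/s node wins.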
func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
	t.Run("Nodes with same throughput", func(t *testing.T) {
		cache_utils.ClearAllCache()

		node1ID := uuid.New()
		node2ID := uuid.New()
		node3ID := uuid.New()
		now := time.Now().UTC()

		defer func() {
			// Clean up all mock nodes
			restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node1ID})
			restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node2ID})
			restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node3ID})
			cache_utils.ClearAllCache()
		}()

		err := CreateMockNodeInRegistry(node1ID, 100, now)
		assert.NoError(t, err)
		err = CreateMockNodeInRegistry(node2ID, 100, now)
		assert.NoError(t, err)
		err = CreateMockNodeInRegistry(node3ID, 100, now)
		assert.NoError(t, err)

		for range 5 {
			err = restoreNodesRegistry.IncrementRestoresInProgress(node1ID)
			assert.NoError(t, err)
		}

		for range 2 {
			err = restoreNodesRegistry.IncrementRestoresInProgress(node2ID)
			assert.NoError(t, err)
		}

		for range 8 {
			err = restoreNodesRegistry.IncrementRestoresInProgress(node3ID)
			assert.NoError(t, err)
		}

		leastBusyNodeID, err := GetRestoresScheduler().calculateLeastBusyNode()
		assert.NoError(t, err)
		assert.NotNil(t, leastBusyNodeID)
		assert.Equal(t, node2ID, *leastBusyNodeID)
	})

	t.Run("Nodes with different throughput", func(t *testing.T) {
		cache_utils.ClearAllCache()

		node100MBsID := uuid.New()
		node50MBsID := uuid.New()
		now := time.Now().UTC()

		defer func() {
			// Clean up all mock nodes
			restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node100MBsID})
			restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node50MBsID})
			cache_utils.ClearAllCache()
		}()

		err := CreateMockNodeInRegistry(node100MBsID, 100, now)
		assert.NoError(t, err)
		err = CreateMockNodeInRegistry(node50MBsID, 50, now)
		assert.NoError(t, err)

		for range 10 {
			err = restoreNodesRegistry.IncrementRestoresInProgress(node100MBsID)
			assert.NoError(t, err)
		}

		err = restoreNodesRegistry.IncrementRestoresInProgress(node50MBsID)
		assert.NoError(t, err)

		leastBusyNodeID, err := GetRestoresScheduler().calculateLeastBusyNode()
		assert.NoError(t, err)
		assert.NotNil(t, leastBusyNodeID)
		assert.Equal(t, node50MBsID, *leastBusyNodeID)
	})
}

func Test_FailRestoresInProgress_SchedulerStarts_UpdatesStatus(t *testing.T) {
	cache_utils.ClearAllCache()

	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)

	defer func() {
		backupRepo := backups_core.BackupRepository{}
		backups, _ := backupRepo.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepo.DeleteByID(backup.ID)
		}

		restoreRepo := restores_core.RestoreRepository{}

		restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
		for _, restore := range restores {
			restoreRepo.DeleteByID(restore.ID)
		}

		restores, _ = restoreRepo.FindByStatus(restores_core.RestoreStatusFailed)
		for _, restore := range restores {
			restoreRepo.DeleteByID(restore.ID)
		}

		restores, _ = restoreRepo.FindByStatus(restores_core.RestoreStatusCompleted)
		for _, restore := range restores {
			restoreRepo.DeleteByID(restore.ID)
		}

		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		storages.RemoveTestStorage(storage.ID)
		notifiers.RemoveTestNotifier(notifier)
		workspaces_testing.RemoveTestWorkspace(workspace, router)

		cache_utils.ClearAllCache()
	}()

	backups_config.EnableBackupsForTestDatabase(database.ID, storage)

	// Create a test backup
	backup := backups.CreateTestBackup(database.ID, storage.ID)

	// Create two in-progress restores that should be failed on scheduler restart
	restore1 := &restores_core.Restore{
		BackupID:  backup.ID,
		Status:    restores_core.RestoreStatusInProgress,
		CreatedAt: time.Now().UTC().Add(-30 * time.Minute),
	}
	err := restoreRepository.Save(restore1)
	assert.NoError(t, err)

	restore2 := &restores_core.Restore{
		BackupID:  backup.ID,
		Status:    restores_core.RestoreStatusInProgress,
		CreatedAt: time.Now().UTC().Add(-15 * time.Minute),
	}
	err = restoreRepository.Save(restore2)
	assert.NoError(t, err)

	// Create a completed restore to verify it's not affected by failRestoresInProgress
	completedRestore := &restores_core.Restore{
		BackupID:  backup.ID,
		Status:    restores_core.RestoreStatusCompleted,
		CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
	}
	err = restoreRepository.Save(completedRestore)
	assert.NoError(t, err)

	// Trigger the scheduler's failRestoresInProgress logic
	// This should mark in-progress restores as failed
	err = GetRestoresScheduler().failRestoresInProgress()
	assert.NoError(t, err)

	// Verify all restores exist and were processed correctly
	allRestores1, err := restoreRepository.FindByID(restore1.ID)
	assert.NoError(t, err)
	allRestores2, err := restoreRepository.FindByID(restore2.ID)
	assert.NoError(t, err)
	allRestores3, err := restoreRepository.FindByID(completedRestore.ID)
	assert.NoError(t, err)

	var failedCount int
	var completedCount int

	restoresToCheck := []*restores_core.Restore{allRestores1, allRestores2, allRestores3}
	for _, restore := range restoresToCheck {
		switch restore.Status {
		case restores_core.RestoreStatusFailed:
			failedCount++
			// Verify fail message indicates application restart
			assert.NotNil(t, restore.FailMessage)
			assert.Equal(t, "Restore failed due to application restart", *restore.FailMessage)
		case restores_core.RestoreStatusCompleted:
			completedCount++
		}
	}

	// Verify correct number of restores in each state
	assert.Equal(t, 2, failedCount, "Should have 2 failed restores (originally in progress)")
	assert.Equal(t, 1, completedCount, "Should have 1 completed restore (unchanged)")

	time.Sleep(200 * time.Millisecond)
}

func Test_StartRestore_RestoreCompletes_DecrementsActiveTaskCount(t *testing.T) {
	cache_utils.ClearAllCache()

	// Start scheduler so it can handle task completions
	scheduler := CreateTestRestoresScheduler()
	schedulerCancel := StartSchedulerForTest(t, scheduler)
	defer schedulerCancel()

	restorerNode := CreateTestRestorerNode()
	restorerNode.restoreBackupUsecase = &MockSuccessRestoreUsecase{}

	cancel := StartRestorerNodeForTest(t, restorerNode)
	defer StopRestorerNodeForTest(t, cancel, restorerNode)

	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)

	defer func() {
		backupRepo := backups_core.BackupRepository{}
		backups, _ := backupRepo.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepo.DeleteByID(backup.ID)
		}

		restoreRepo := restores_core.RestoreRepository{}
		restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusCompleted)
		for _, restore := range restores {
			restoreRepo.DeleteByID(restore.ID)
		}

		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		storages.RemoveTestStorage(storage.ID)
		notifiers.RemoveTestNotifier(notifier)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()

	backups_config.EnableBackupsForTestDatabase(database.ID, storage)

	// Create a test backup
	backup := backups.CreateTestBackup(database.ID, storage.ID)

	// Get initial active task count
	stats, err := restoreNodesRegistry.GetRestoreNodesStats()
	assert.NoError(t, err)
	var initialActiveTasks int
	for _, stat := range stats {
		if stat.ID == restorerNode.nodeID {
			initialActiveTasks = stat.ActiveRestores
			break
		}
	}
	t.Logf("Initial active tasks: %d", initialActiveTasks)

	// Create and start restore
	restore := &restores_core.Restore{
		BackupID: backup.ID,
		Status:   restores_core.RestoreStatusInProgress,
	}
	err = restoreRepository.Save(restore)
	assert.NoError(t, err)

	err = scheduler.StartRestore(restore.ID, nil)
	assert.NoError(t, err)

	// Wait for restore to complete
	WaitForRestoreCompletion(t, restore.ID, 10*time.Second)

	// Verify restore was completed
	completedRestore, err := restoreRepository.FindByID(restore.ID)
	assert.NoError(t, err)
	assert.Equal(t, restores_core.RestoreStatusCompleted, completedRestore.Status)

	// Wait for active task count to decrease
	decreased := WaitForActiveTasksDecrease(
		t,
		restorerNode.nodeID,
		initialActiveTasks+1,
		10*time.Second,
	)
	assert.True(t, decreased, "Active task count should have decreased after restore completion")

	// Verify final active task count equals initial count
	finalStats, err := restoreNodesRegistry.GetRestoreNodesStats()
	assert.NoError(t, err)
	for _, stat := range finalStats {
		if stat.ID == restorerNode.nodeID {
			t.Logf("Final active tasks: %d", stat.ActiveRestores)
			assert.Equal(t, initialActiveTasks, stat.ActiveRestores,
				"Active task count should return to initial value after restore completion")
			break
		}
	}

	time.Sleep(200 * time.Millisecond)
}

func Test_StartRestore_RestoreFails_DecrementsActiveTaskCount(t *testing.T) {
	cache_utils.ClearAllCache()

	// Start scheduler so it can handle task completions
	scheduler := CreateTestRestoresScheduler()
	schedulerCancel := StartSchedulerForTest(t, scheduler)
	defer schedulerCancel()

	restorerNode := CreateTestRestorerNode()
	restorerNode.restoreBackupUsecase = &MockFailedRestoreUsecase{}

	cancel := StartRestorerNodeForTest(t, restorerNode)
	defer StopRestorerNodeForTest(t, cancel, restorerNode)

	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)

	defer func() {
		backupRepo := backups_core.BackupRepository{}
		backups, _ := backupRepo.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepo.DeleteByID(backup.ID)
		}

		restoreRepo := restores_core.RestoreRepository{}
		restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusFailed)
		for _, restore := range restores {
			restoreRepo.DeleteByID(restore.ID)
		}

		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		storages.RemoveTestStorage(storage.ID)
		notifiers.RemoveTestNotifier(notifier)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()

	backups_config.EnableBackupsForTestDatabase(database.ID, storage)

	// Create a test backup
	backup := backups.CreateTestBackup(database.ID, storage.ID)

	// Get initial active task count
	stats, err := restoreNodesRegistry.GetRestoreNodesStats()
	assert.NoError(t, err)
	var initialActiveTasks int
	for _, stat := range stats {
		if stat.ID == restorerNode.nodeID {
			initialActiveTasks = stat.ActiveRestores
			break
		}
	}
	t.Logf("Initial active tasks: %d", initialActiveTasks)

	// Create and start restore
	restore := &restores_core.Restore{
		BackupID: backup.ID,
		Status:   restores_core.RestoreStatusInProgress,
	}
	err = restoreRepository.Save(restore)
	assert.NoError(t, err)

	err = scheduler.StartRestore(restore.ID, nil)
	assert.NoError(t, err)

	// Wait for restore to fail
	WaitForRestoreCompletion(t, restore.ID, 10*time.Second)

	// Verify restore failed
	failedRestore, err := restoreRepository.FindByID(restore.ID)
	assert.NoError(t, err)
	assert.Equal(t, restores_core.RestoreStatusFailed, failedRestore.Status)

	// Wait for active task count to decrease
	decreased := WaitForActiveTasksDecrease(
		t,
		restorerNode.nodeID,
		initialActiveTasks+1,
		10*time.Second,
	)
	assert.True(t, decreased, "Active task count should have decreased after restore failure")

	// Verify final active task count equals initial count
	finalStats, err := restoreNodesRegistry.GetRestoreNodesStats()
	assert.NoError(t, err)
	for _, stat := range finalStats {
		if stat.ID == restorerNode.nodeID {
			t.Logf("Final active tasks: %d", stat.ActiveRestores)
			assert.Equal(t, initialActiveTasks, stat.ActiveRestores,
				"Active task count should return to initial value after restore failure")
			break
		}
	}

	time.Sleep(200 * time.Millisecond)
}

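// The two tests below pin down the credential-handling contract: credentials
// enter the scheduler cache still encrypted by the FieldEncryptor, and the
// cache entry is consumed (deleted) as soon as the restorer picks the task up.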
func Test_StartRestore_CredentialsStoredEncryptedInCache(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
var mockNodeID uuid.UUID
|
||||
|
||||
defer func() {
|
||||
backupRepo := backups_core.BackupRepository{}
|
||||
backups, _ := backupRepo.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepo.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
restoreRepo := restores_core.RestoreRepository{}
|
||||
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
|
||||
for _, restore := range restores {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
// Clean up mock node
|
||||
if mockNodeID != uuid.Nil {
|
||||
restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: mockNodeID})
|
||||
}
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
|
||||
// Create a test backup
|
||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
||||
|
||||
// Register mock node so scheduler can assign restore to it
|
||||
mockNodeID = uuid.New()
|
||||
err := CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create restore with plaintext credentials
|
||||
plaintextPassword := "test_password_123"
|
||||
restore := &restores_core.Restore{
|
||||
BackupID: backup.ID,
|
||||
Status: restores_core.RestoreStatusInProgress,
|
||||
}
|
||||
err = restoreRepository.Save(restore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create PostgreSQL database credentials with plaintext password
|
||||
postgresDB := &postgresql.PostgresqlDatabase{
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "testuser",
|
||||
Password: plaintextPassword,
|
||||
Database: stringPtr("testdb"),
|
||||
Version: "16",
|
||||
}
|
||||
|
||||
// Encrypt password using FieldEncryptor (same as production flow)
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
err = postgresDB.EncryptSensitiveFields(database.ID, encryptor)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify password was encrypted (different from plaintext)
|
||||
assert.NotEqual(t, plaintextPassword, postgresDB.Password,
|
||||
"Password should be encrypted, not plaintext")
|
||||
|
||||
// Create cache with encrypted credentials
|
||||
dbCache := &RestoreDatabaseCache{
|
||||
PostgresqlDatabase: postgresDB,
|
||||
}
|
||||
|
||||
// Call StartRestore to cache credentials (do NOT start restore node)
|
||||
err = GetRestoresScheduler().StartRestore(restore.ID, dbCache)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Directly read from cache
|
||||
cachedData := restoreDatabaseCache.Get(restore.ID.String())
|
||||
assert.NotNil(t, cachedData, "Cache entry should exist")
|
||||
assert.NotNil(t, cachedData.PostgresqlDatabase, "PostgreSQL credentials should be cached")
|
||||
|
||||
// Verify password in cache is encrypted (not plaintext)
|
||||
assert.NotEqual(t, plaintextPassword, cachedData.PostgresqlDatabase.Password,
|
||||
"Cached password should be encrypted, not plaintext")
|
||||
assert.Equal(t, postgresDB.Password, cachedData.PostgresqlDatabase.Password,
|
||||
"Cached password should match the encrypted version")
|
||||
|
||||
// Verify other fields are present
|
||||
assert.Equal(t, config.GetEnv().TestLocalhost, cachedData.PostgresqlDatabase.Host)
|
||||
assert.Equal(t, 5432, cachedData.PostgresqlDatabase.Port)
|
||||
assert.Equal(t, "testuser", cachedData.PostgresqlDatabase.Username)
|
||||
assert.Equal(t, "testdb", *cachedData.PostgresqlDatabase.Database)
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}

func Test_StartRestore_CredentialsRemovedAfterRestoreStarts(t *testing.T) {
    cache_utils.ClearAllCache()

    // Start scheduler so it can handle task assignments
    scheduler := CreateTestRestoresScheduler()
    schedulerCancel := StartSchedulerForTest(t, scheduler)
    defer schedulerCancel()

    // Create mock restorer node with credential capture usecase
    restorerNode := CreateTestRestorerNode()
    calledChan := make(chan *databases.Database, 1)
    restorerNode.restoreBackupUsecase = &MockCaptureCredentialsRestoreUsecase{
        CalledChan:    calledChan,
        ShouldSucceed: true,
    }

    cancel := StartRestorerNodeForTest(t, restorerNode)
    defer StopRestorerNodeForTest(t, cancel, restorerNode)

    user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
    router := CreateTestRouter()
    workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
    storage := storages.CreateTestStorage(workspace.ID)
    notifier := notifiers.CreateTestNotifier(workspace.ID)
    database := databases.CreateTestDatabase(workspace.ID, storage, notifier)

    defer func() {
        backupRepo := backups_core.BackupRepository{}
        backups, _ := backupRepo.FindByDatabaseID(database.ID)
        for _, backup := range backups {
            backupRepo.DeleteByID(backup.ID)
        }

        restoreRepo := restores_core.RestoreRepository{}
        restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusCompleted)
        for _, restore := range restores {
            restoreRepo.DeleteByID(restore.ID)
        }

        databases.RemoveTestDatabase(database)
        time.Sleep(50 * time.Millisecond)
        storages.RemoveTestStorage(storage.ID)
        notifiers.RemoveTestNotifier(notifier)
        workspaces_testing.RemoveTestWorkspace(workspace, router)

        cache_utils.ClearAllCache()
    }()

    backups_config.EnableBackupsForTestDatabase(database.ID, storage)

    // Create a test backup
    backup := backups.CreateTestBackup(database.ID, storage.ID)

    // Create restore with credentials
    plaintextPassword := "test_password_456"
    restore := &restores_core.Restore{
        BackupID: backup.ID,
        Status:   restores_core.RestoreStatusInProgress,
    }
    err := restoreRepository.Save(restore)
    assert.NoError(t, err)

    // Create PostgreSQL database credentials.
    // Database field is nil to avoid PopulateDbData trying to connect.
    postgresDB := &postgresql.PostgresqlDatabase{
        Host:     config.GetEnv().TestLocalhost,
        Port:     5432,
        Username: "testuser",
        Password: plaintextPassword,
        Database: nil,
        Version:  "16",
    }

    // Encrypt password (same as production flow)
    encryptor := encryption.GetFieldEncryptor()
    err = postgresDB.EncryptSensitiveFields(database.ID, encryptor)
    assert.NoError(t, err)

    encryptedPassword := postgresDB.Password

    // Create cache with encrypted credentials
    dbCache := &RestoreDatabaseCache{
        PostgresqlDatabase: postgresDB,
    }

    // Call StartRestore to cache credentials and trigger restore
    err = scheduler.StartRestore(restore.ID, dbCache)
    assert.NoError(t, err)

    // Wait for mock usecase to be called (with timeout)
    var capturedDB *databases.Database
    select {
    case capturedDB = <-calledChan:
        t.Log("Mock usecase was called, credentials captured")
    case <-time.After(10 * time.Second):
        t.Fatal("Timeout waiting for mock usecase to be called")
    }

    // Verify cache is empty after restore starts (credentials were deleted)
    cacheAfterExecution := restoreDatabaseCache.Get(restore.ID.String())
    assert.Nil(t, cacheAfterExecution, "Cache should be empty after restore execution starts")

    // Verify mock received valid credentials
    assert.NotNil(t, capturedDB, "Captured database should not be nil")
    assert.NotNil(t, capturedDB.Postgresql, "PostgreSQL credentials should be provided to usecase")
    assert.Equal(t, config.GetEnv().TestLocalhost, capturedDB.Postgresql.Host)
    assert.Equal(t, 5432, capturedDB.Postgresql.Port)
    assert.Equal(t, "testuser", capturedDB.Postgresql.Username)
    assert.NotEmpty(t, capturedDB.Postgresql.Password, "Password should be provided to usecase")

    // Note: the password at this point may still be encrypted, because PopulateDbData
    // is called after the mock captures it. The important thing is that credentials
    // were provided to the usecase despite the cache being deleted.
    t.Logf("Encrypted password in cache: %s", encryptedPassword)
    t.Logf("Password received by usecase: %s", capturedDB.Postgresql.Password)

    // Wait for restore to complete
    WaitForRestoreCompletion(t, restore.ID, 10*time.Second)

    // Verify restore was completed
    completedRestore, err := restoreRepository.FindByID(restore.ID)
    assert.NoError(t, err)
    assert.Equal(t, restores_core.RestoreStatusCompleted, completedRestore.Status)

    time.Sleep(200 * time.Millisecond)
}
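
// The MockCaptureCredentialsRestoreUsecase used above is defined elsewhere in
// this diff. As a hedged sketch of the pattern (the Execute method name and
// signature below are assumptions, since the RestoreBackupUsecase interface is
// not shown in this excerpt), a credential-capturing mock can look like:
//
//    type MockCaptureCredentialsRestoreUsecase struct {
//        CalledChan    chan *databases.Database
//        ShouldSucceed bool
//    }
//
//    func (m *MockCaptureCredentialsRestoreUsecase) Execute(
//        ctx context.Context,
//        restore *restores_core.Restore,
//        database *databases.Database,
//    ) error {
//        // Hand the captured credentials to the test before doing any work
//        m.CalledChan <- database
//        if m.ShouldSucceed {
//            return nil
//        }
//        return fmt.Errorf("mock restore failed")
//    }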

backend/internal/features/restores/restoring/testing.go (new file, 342 lines)
@@ -0,0 +1,342 @@
package restoring

import (
    "context"
    "fmt"
    "sync"
    "sync/atomic"
    "testing"
    "time"

    "github.com/gin-gonic/gin"
    "github.com/google/uuid"

    "databasus-backend/internal/config"
    "databasus-backend/internal/features/backups/backups"
    backups_core "databasus-backend/internal/features/backups/backups/core"
    backups_config "databasus-backend/internal/features/backups/config"
    "databasus-backend/internal/features/databases"
    "databasus-backend/internal/features/databases/databases/postgresql"
    restores_core "databasus-backend/internal/features/restores/core"
    "databasus-backend/internal/features/restores/usecases"
    "databasus-backend/internal/features/storages"
    tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
    workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
    workspaces_testing "databasus-backend/internal/features/workspaces/testing"
    "databasus-backend/internal/util/encryption"
    "databasus-backend/internal/util/logger"
)

func CreateTestRouter() *gin.Engine {
    router := workspaces_testing.CreateTestRouter(
        workspaces_controllers.GetWorkspaceController(),
        workspaces_controllers.GetMembershipController(),
        databases.GetDatabaseController(),
        backups_config.GetBackupConfigController(),
    )

    return router
}

func CreateTestRestorerNode() *RestorerNode {
    return &RestorerNode{
        nodeID:               uuid.New(),
        databaseService:      databases.GetDatabaseService(),
        backupService:        backups.GetBackupService(),
        fieldEncryptor:       encryption.GetFieldEncryptor(),
        restoreRepository:    restoreRepository,
        backupConfigService:  backups_config.GetBackupConfigService(),
        storageService:       storages.GetStorageService(),
        restoreNodesRegistry: restoreNodesRegistry,
        logger:               logger.GetLogger(),
        restoreBackupUsecase: usecases.GetRestoreBackupUsecase(),
        cacheUtil:            restoreDatabaseCache,
        restoreCancelManager: tasks_cancellation.GetTaskCancelManager(),
        lastHeartbeat:        time.Time{},
        runOnce:              sync.Once{},
        hasRun:               atomic.Bool{},
    }
}

func CreateTestRestorerNodeWithUsecase(usecase restores_core.RestoreBackupUsecase) *RestorerNode {
    return &RestorerNode{
        nodeID:               uuid.New(),
        databaseService:      databases.GetDatabaseService(),
        backupService:        backups.GetBackupService(),
        fieldEncryptor:       encryption.GetFieldEncryptor(),
        restoreRepository:    restoreRepository,
        backupConfigService:  backups_config.GetBackupConfigService(),
        storageService:       storages.GetStorageService(),
        restoreNodesRegistry: restoreNodesRegistry,
        logger:               logger.GetLogger(),
        restoreBackupUsecase: usecase,
        cacheUtil:            restoreDatabaseCache,
        restoreCancelManager: tasks_cancellation.GetTaskCancelManager(),
        lastHeartbeat:        time.Time{},
        runOnce:              sync.Once{},
        hasRun:               atomic.Bool{},
    }
}

func CreateTestRestoresScheduler() *RestoresScheduler {
    return &RestoresScheduler{
        restoreRepository,
        backups.GetBackupService(),
        storages.GetStorageService(),
        backups_config.GetBackupConfigService(),
        restoreNodesRegistry,
        time.Now().UTC(),
        logger.GetLogger(),
        make(map[uuid.UUID]RestoreToNodeRelation),
        restorerNode,
        restoreDatabaseCache,
        uuid.Nil,
        sync.Once{},
        atomic.Bool{},
    }
}

// WaitForRestoreCompletion waits for a restore to be completed (or failed)
func WaitForRestoreCompletion(
    t *testing.T,
    restoreID uuid.UUID,
    timeout time.Duration,
) {
    deadline := time.Now().UTC().Add(timeout)

    for time.Now().UTC().Before(deadline) {
        restore, err := restoreRepository.FindByID(restoreID)
        if err != nil {
            t.Logf("WaitForRestoreCompletion: error finding restore: %v", err)
            time.Sleep(50 * time.Millisecond)
            continue
        }

        t.Logf("WaitForRestoreCompletion: restore status: %s", restore.Status)

        if restore.Status == restores_core.RestoreStatusCompleted ||
            restore.Status == restores_core.RestoreStatusFailed {
            t.Logf(
                "WaitForRestoreCompletion: restore finished with status %s",
                restore.Status,
            )
            return
        }

        time.Sleep(50 * time.Millisecond)
    }

    t.Logf("WaitForRestoreCompletion: timeout waiting for restore to complete")
}
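
// Typical call site, taken from the test above. Note that this helper only
// logs on timeout rather than failing the test, so assert the terminal status
// explicitly after it returns:
//
//    WaitForRestoreCompletion(t, restore.ID, 10*time.Second)
//    completedRestore, err := restoreRepository.FindByID(restore.ID)
//    assert.NoError(t, err)
//    assert.Equal(t, restores_core.RestoreStatusCompleted, completedRestore.Status)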

// StartRestorerNodeForTest starts a RestorerNode in a goroutine for testing.
// The node registers itself in the registry and subscribes to restore assignments.
// Returns a context cancel function that should be deferred to stop the node.
func StartRestorerNodeForTest(t *testing.T, restorerNode *RestorerNode) context.CancelFunc {
    ctx, cancel := context.WithCancel(context.Background())

    done := make(chan struct{})

    go func() {
        restorerNode.Run(ctx)
        close(done)
    }()

    // Poll registry for node presence instead of fixed sleep
    deadline := time.Now().UTC().Add(5 * time.Second)
    for time.Now().UTC().Before(deadline) {
        nodes, err := restoreNodesRegistry.GetAvailableNodes()
        if err == nil {
            for _, node := range nodes {
                if node.ID == restorerNode.nodeID {
                    t.Logf("RestorerNode registered in registry: %s", restorerNode.nodeID)

                    return func() {
                        cancel()
                        select {
                        case <-done:
                            t.Log("RestorerNode stopped gracefully")
                        case <-time.After(2 * time.Second):
                            t.Log("RestorerNode stop timeout")
                        }
                    }
                }
            }
        }
        time.Sleep(50 * time.Millisecond)
    }

    t.Fatalf("RestorerNode failed to register in registry within timeout")
    return nil
}
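
// As in the tests above, pair the start call with the matching stop helper so
// the node is torn down and unregistered even if the test fails early:
//
//    restorerNode := CreateTestRestorerNode()
//    cancel := StartRestorerNodeForTest(t, restorerNode)
//    defer StopRestorerNodeForTest(t, cancel, restorerNode)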

// StartSchedulerForTest starts the RestoresScheduler in a goroutine for testing.
// The scheduler subscribes to task completions and manages restore lifecycle.
// Returns a context cancel function that should be deferred to stop the scheduler.
func StartSchedulerForTest(t *testing.T, scheduler *RestoresScheduler) context.CancelFunc {
    ctx, cancel := context.WithCancel(context.Background())

    done := make(chan struct{})

    go func() {
        scheduler.Run(ctx)
        close(done)
    }()

    // Give scheduler time to subscribe to completions
    time.Sleep(100 * time.Millisecond)
    t.Log("RestoresScheduler started")

    return func() {
        cancel()
        select {
        case <-done:
            t.Log("RestoresScheduler stopped gracefully")
        case <-time.After(2 * time.Second):
            t.Log("RestoresScheduler stop timeout")
        }
    }
}
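
// The scheduler equivalent, as used in the tests above:
//
//    scheduler := CreateTestRestoresScheduler()
//    schedulerCancel := StartSchedulerForTest(t, scheduler)
//    defer schedulerCancel()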

// StopRestorerNodeForTest stops the RestorerNode by canceling its context.
// It waits for the node to unregister from the registry.
func StopRestorerNodeForTest(t *testing.T, cancel context.CancelFunc, restorerNode *RestorerNode) {
    cancel()

    // Wait for node to unregister from registry
    deadline := time.Now().UTC().Add(2 * time.Second)
    for time.Now().UTC().Before(deadline) {
        nodes, err := restoreNodesRegistry.GetAvailableNodes()
        if err == nil {
            found := false
            for _, node := range nodes {
                if node.ID == restorerNode.nodeID {
                    found = true
                    break
                }
            }
            if !found {
                t.Logf("RestorerNode unregistered from registry: %s", restorerNode.nodeID)
                return
            }
        }
        time.Sleep(50 * time.Millisecond)
    }

    t.Logf("RestorerNode stop completed for %s", restorerNode.nodeID)
}

func CreateMockNodeInRegistry(nodeID uuid.UUID, throughputMBs int, lastHeartbeat time.Time) error {
    restoreNode := RestoreNode{
        ID:            nodeID,
        ThroughputMBs: throughputMBs,
        LastHeartbeat: lastHeartbeat,
    }

    return restoreNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, restoreNode)
}

func UpdateNodeHeartbeatDirectly(
    nodeID uuid.UUID,
    throughputMBs int,
    lastHeartbeat time.Time,
) error {
    restoreNode := RestoreNode{
        ID:            nodeID,
        ThroughputMBs: throughputMBs,
        LastHeartbeat: lastHeartbeat,
    }

    return restoreNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, restoreNode)
}

func GetNodeFromRegistry(nodeID uuid.UUID) (*RestoreNode, error) {
    nodes, err := restoreNodesRegistry.GetAvailableNodes()
    if err != nil {
        return nil, err
    }

    for _, node := range nodes {
        if node.ID == nodeID {
            return &node, nil
        }
    }

    return nil, fmt.Errorf("node not found")
}
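
// A minimal usage sketch for the registry helpers above; the throughput value
// (100 MB/s) and the current-time heartbeat are arbitrary illustrative values:
//
//    nodeID := uuid.New()
//    if err := CreateMockNodeInRegistry(nodeID, 100, time.Now().UTC()); err != nil {
//        t.Fatalf("failed to register mock node: %v", err)
//    }
//    node, err := GetNodeFromRegistry(nodeID)
//    assert.NoError(t, err)
//    assert.Equal(t, 100, node.ThroughputMBs)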

// WaitForActiveTasksDecrease waits for the active task count to decrease below the initial count.
// It polls the registry every 500ms until the count decreases or the timeout is reached.
// Returns true if the count decreased, false if timeout was reached.
func WaitForActiveTasksDecrease(
    t *testing.T,
    nodeID uuid.UUID,
    initialCount int,
    timeout time.Duration,
) bool {
    deadline := time.Now().UTC().Add(timeout)

    for time.Now().UTC().Before(deadline) {
        stats, err := restoreNodesRegistry.GetRestoreNodesStats()
        if err != nil {
            t.Logf("WaitForActiveTasksDecrease: error getting node stats: %v", err)
            time.Sleep(500 * time.Millisecond)
            continue
        }

        for _, stat := range stats {
            if stat.ID == nodeID {
                t.Logf(
                    "WaitForActiveTasksDecrease: current active tasks = %d (initial = %d)",
                    stat.ActiveRestores,
                    initialCount,
                )
                if stat.ActiveRestores < initialCount {
                    t.Logf(
                        "WaitForActiveTasksDecrease: active tasks decreased from %d to %d",
                        initialCount,
                        stat.ActiveRestores,
                    )
                    return true
                }
                break
            }
        }

        time.Sleep(500 * time.Millisecond)
    }

    t.Logf("WaitForActiveTasksDecrease: timeout waiting for active tasks to decrease")
    return false
}
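
// Hypothetical usage: initialCount would be read from
// restoreNodesRegistry.GetRestoreNodesStats() before triggering the work whose
// completion the test waits on, and the 30-second timeout is an arbitrary
// illustrative choice:
//
//    decreased := WaitForActiveTasksDecrease(t, nodeID, initialCount, 30*time.Second)
//    assert.True(t, decreased, "active restores should decrease once the task finishes")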

// CreateTestRestore creates a test restore with the given backup and status
func CreateTestRestore(
    t *testing.T,
    backup *backups_core.Backup,
    status restores_core.RestoreStatus,
) *restores_core.Restore {
    restore := &restores_core.Restore{
        BackupID: backup.ID,
        Status:   status,
        PostgresqlDatabase: &postgresql.PostgresqlDatabase{
            Host:     config.GetEnv().TestLocalhost,
            Port:     5432,
            Username: "test",
            Password: "test",
            Database: stringPtr("testdb"),
            Version:  "16",
        },
    }

    err := restoreRepository.Save(restore)
    if err != nil {
        t.Fatalf("Failed to create test restore: %v", err)
    }

    return restore
}

func stringPtr(s string) *string {
    return &s
}
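
// Putting the helpers together, a minimal end-to-end sketch. The workspace,
// storage, database, backup, and encrypted postgresDB fixtures are elided here
// and assumed to be set up as in Test_StartRestore_CredentialsRemovedAfterRestoreStarts:
//
//    scheduler := CreateTestRestoresScheduler()
//    stopScheduler := StartSchedulerForTest(t, scheduler)
//    defer stopScheduler()
//
//    restorerNode := CreateTestRestorerNode()
//    cancel := StartRestorerNodeForTest(t, restorerNode)
//    defer StopRestorerNodeForTest(t, cancel, restorerNode)
//
//    restore := CreateTestRestore(t, backup, restores_core.RestoreStatusInProgress)
//    err := scheduler.StartRestore(restore.ID, &RestoreDatabaseCache{PostgresqlDatabase: postgresDB})
//    assert.NoError(t, err)
//    WaitForRestoreCompletion(t, restore.ID, 10*time.Second)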

Some files were not shown because too many files have changed in this diff.