mirror of
https://github.com/databasus/databasus.git
synced 2026-04-06 00:32:03 +02:00
Compare commits
41 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d4763f26b2 | ||
|
|
0e389ba16b | ||
|
|
594a3294c6 | ||
|
|
4e4a323cf1 | ||
|
|
7d9ecf697b | ||
|
|
755c420157 | ||
|
|
ff73627287 | ||
|
|
9c9ab00ace | ||
|
|
7366e21a1a | ||
|
|
a327d1aa57 | ||
|
|
f152b16ea3 | ||
|
|
85dbe80d3d | ||
|
|
edf4028fd1 | ||
|
|
8d85c45a90 | ||
|
|
d9c176d19a | ||
|
|
7a6f72a456 | ||
|
|
9a1471b88b | ||
|
|
386ea1d708 | ||
|
|
a4b23936ee | ||
|
|
b36aa9d48b | ||
|
|
13cb8e5bd2 | ||
|
|
2db4b6e075 | ||
|
|
f2b0b2bf1f | ||
|
|
7142ce295e | ||
|
|
04621b9b2d | ||
|
|
bd329a68cf | ||
|
|
f957abc9db | ||
|
|
c0fd6be1a9 | ||
|
|
c39bd34d5e | ||
|
|
27bec15a29 | ||
|
|
d98baa0656 | ||
|
|
4344f5ea5e | ||
|
|
7c6afa5b88 | ||
|
|
dbac799e1b | ||
|
|
7ee3817089 | ||
|
|
bae6f7f007 | ||
|
|
55dc087ddd | ||
|
|
c94d0db637 | ||
|
|
a1adef2261 | ||
|
|
4602dc3f88 | ||
|
|
cbbfc5ea8f |
310
.github/workflows/ci-release.yml
vendored
310
.github/workflows/ci-release.yml
vendored
@@ -9,25 +9,26 @@ on:
|
||||
|
||||
jobs:
|
||||
lint-backend:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
container:
|
||||
image: golang:1.24.9
|
||||
volumes:
|
||||
- /runner-cache/go-pkg:/go/pkg/mod
|
||||
- /runner-cache/go-build:/root/.cache/go-build
|
||||
- /runner-cache/golangci-lint:/root/.cache/golangci-lint
|
||||
- /runner-cache/apt-archives:/var/cache/apt/archives
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.24.9"
|
||||
- name: Configure Git for container
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
~/.cache/go-build
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('backend/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-
|
||||
- name: Download Go modules
|
||||
run: |
|
||||
cd backend
|
||||
go mod download
|
||||
|
||||
- name: Install golangci-lint
|
||||
run: |
|
||||
@@ -63,8 +64,6 @@ jobs:
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "20"
|
||||
cache: "npm"
|
||||
cache-dependency-path: frontend/package-lock.json
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
@@ -93,8 +92,6 @@ jobs:
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "20"
|
||||
cache: "npm"
|
||||
cache-dependency-path: frontend/package-lock.json
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
@@ -107,44 +104,32 @@ jobs:
|
||||
npm run test
|
||||
|
||||
test-backend:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
needs: [lint-backend]
|
||||
container:
|
||||
image: golang:1.24.9
|
||||
options: --privileged -v /var/run/docker.sock:/var/run/docker.sock --add-host=host.docker.internal:host-gateway
|
||||
volumes:
|
||||
- /runner-cache/go-pkg:/go/pkg/mod
|
||||
- /runner-cache/go-build:/root/.cache/go-build
|
||||
- /runner-cache/apt-archives:/var/cache/apt/archives
|
||||
steps:
|
||||
- name: Free up disk space
|
||||
- name: Install Docker CLI
|
||||
run: |
|
||||
echo "Disk space before cleanup:"
|
||||
df -h
|
||||
# Remove unnecessary pre-installed software
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
sudo rm -rf /usr/local/share/boost
|
||||
sudo rm -rf /usr/share/swift
|
||||
# Clean apt cache
|
||||
sudo apt-get clean
|
||||
# Clean docker images (if any pre-installed)
|
||||
docker system prune -af --volumes || true
|
||||
echo "Disk space after cleanup:"
|
||||
df -h
|
||||
apt-get update -qq
|
||||
apt-get install -y -qq docker.io docker-compose netcat-openbsd wget
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.24.9"
|
||||
- name: Configure Git for container
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
~/.cache/go-build
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('backend/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-
|
||||
- name: Download Go modules
|
||||
run: |
|
||||
cd backend
|
||||
go mod download
|
||||
|
||||
- name: Create .env file for testing
|
||||
run: |
|
||||
@@ -156,14 +141,16 @@ jobs:
|
||||
DEV_DB_PASSWORD=Q1234567
|
||||
#app
|
||||
ENV_MODE=development
|
||||
# db
|
||||
DATABASE_DSN=host=localhost user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
|
||||
DATABASE_URL=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disable
|
||||
# db - using 172.17.0.1 to access host from container
|
||||
DATABASE_DSN=host=172.17.0.1 user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
|
||||
DATABASE_URL=postgres://postgres:Q1234567@172.17.0.1:5437/databasus?sslmode=disable
|
||||
# migrations
|
||||
GOOSE_DRIVER=postgres
|
||||
GOOSE_DBSTRING=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disable
|
||||
GOOSE_DBSTRING=postgres://postgres:Q1234567@172.17.0.1:5437/databasus?sslmode=disable
|
||||
GOOSE_MIGRATION_DIR=./migrations
|
||||
# testing
|
||||
# testing
|
||||
TEST_LOCALHOST=172.17.0.1
|
||||
IS_SKIP_EXTERNAL_RESOURCES_TESTS=true
|
||||
# to get Google Drive env variables: add storage in UI and copy data from added storage here
|
||||
TEST_GOOGLE_DRIVE_CLIENT_ID=${{ secrets.TEST_GOOGLE_DRIVE_CLIENT_ID }}
|
||||
TEST_GOOGLE_DRIVE_CLIENT_SECRET=${{ secrets.TEST_GOOGLE_DRIVE_CLIENT_SECRET }}
|
||||
@@ -221,12 +208,14 @@ jobs:
|
||||
TEST_MONGODB_60_PORT=27060
|
||||
TEST_MONGODB_70_PORT=27070
|
||||
TEST_MONGODB_82_PORT=27082
|
||||
# Valkey (cache)
|
||||
VALKEY_HOST=localhost
|
||||
# Valkey (cache) - using 172.17.0.1
|
||||
VALKEY_HOST=172.17.0.1
|
||||
VALKEY_PORT=6379
|
||||
VALKEY_USERNAME=
|
||||
VALKEY_PASSWORD=
|
||||
VALKEY_IS_SSL=false
|
||||
# Host for test databases (container -> host)
|
||||
TEST_DB_HOST=172.17.0.1
|
||||
EOF
|
||||
|
||||
- name: Start test containers
|
||||
@@ -244,25 +233,25 @@ jobs:
|
||||
timeout 60 bash -c 'until docker exec dev-valkey valkey-cli ping 2>/dev/null | grep -q PONG; do sleep 2; done'
|
||||
echo "Valkey is ready!"
|
||||
|
||||
# Wait for test databases
|
||||
timeout 60 bash -c 'until nc -z localhost 5000; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5001; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5002; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5003; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5004; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5005; do sleep 2; done'
|
||||
# Wait for test databases (using 172.17.0.1 from container)
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5000; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5001; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5002; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5003; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5004; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5005; do sleep 2; done'
|
||||
|
||||
# Wait for MinIO
|
||||
timeout 60 bash -c 'until nc -z localhost 9000; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 9000; do sleep 2; done'
|
||||
|
||||
# Wait for Azurite
|
||||
timeout 60 bash -c 'until nc -z localhost 10000; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 10000; do sleep 2; done'
|
||||
|
||||
# Wait for FTP
|
||||
timeout 60 bash -c 'until nc -z localhost 7007; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 7007; do sleep 2; done'
|
||||
|
||||
# Wait for SFTP
|
||||
timeout 60 bash -c 'until nc -z localhost 7008; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 7008; do sleep 2; done'
|
||||
|
||||
# Wait for MySQL containers
|
||||
echo "Waiting for MySQL 5.7..."
|
||||
@@ -321,67 +310,66 @@ jobs:
|
||||
mkdir -p databasus-data/backups
|
||||
mkdir -p databasus-data/temp
|
||||
|
||||
- name: Cache PostgreSQL client tools
|
||||
id: cache-postgres
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: /usr/lib/postgresql
|
||||
key: postgres-clients-12-18-v1
|
||||
|
||||
- name: Cache MySQL client tools
|
||||
id: cache-mysql
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: backend/tools/mysql
|
||||
key: mysql-clients-57-80-84-9-v1
|
||||
|
||||
- name: Cache MariaDB client tools
|
||||
id: cache-mariadb
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: backend/tools/mariadb
|
||||
key: mariadb-clients-106-121-v1
|
||||
|
||||
- name: Cache MongoDB Database Tools
|
||||
id: cache-mongodb
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: backend/tools/mongodb
|
||||
key: mongodb-database-tools-100.10.0-v1
|
||||
|
||||
- name: Install MySQL dependencies
|
||||
- name: Install database client dependencies
|
||||
run: |
|
||||
sudo apt-get update -qq
|
||||
sudo apt-get install -y -qq libncurses6
|
||||
sudo ln -sf /usr/lib/x86_64-linux-gnu/libncurses.so.6 /usr/lib/x86_64-linux-gnu/libncurses.so.5
|
||||
sudo ln -sf /usr/lib/x86_64-linux-gnu/libtinfo.so.6 /usr/lib/x86_64-linux-gnu/libtinfo.so.5
|
||||
apt-get update -qq
|
||||
apt-get install -y -qq libncurses6 libpq5
|
||||
ln -sf /usr/lib/x86_64-linux-gnu/libncurses.so.6 /usr/lib/x86_64-linux-gnu/libncurses.so.5 || true
|
||||
ln -sf /usr/lib/x86_64-linux-gnu/libtinfo.so.6 /usr/lib/x86_64-linux-gnu/libtinfo.so.5 || true
|
||||
|
||||
- name: Install PostgreSQL, MySQL, MariaDB and MongoDB client tools
|
||||
if: steps.cache-postgres.outputs.cache-hit != 'true' || steps.cache-mysql.outputs.cache-hit != 'true' || steps.cache-mariadb.outputs.cache-hit != 'true' || steps.cache-mongodb.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
chmod +x backend/tools/download_linux.sh
|
||||
cd backend/tools
|
||||
./download_linux.sh
|
||||
|
||||
- name: Setup PostgreSQL symlinks (when using cache)
|
||||
if: steps.cache-postgres.outputs.cache-hit == 'true'
|
||||
- name: Setup PostgreSQL, MySQL and MariaDB client tools from pre-built assets
|
||||
run: |
|
||||
cd backend/tools
|
||||
mkdir -p postgresql
|
||||
|
||||
# Create directory structure
|
||||
mkdir -p postgresql mysql mariadb mongodb/bin
|
||||
|
||||
# Copy PostgreSQL client tools (12-18) from pre-built assets
|
||||
for version in 12 13 14 15 16 17 18; do
|
||||
version_dir="postgresql/postgresql-$version"
|
||||
mkdir -p "$version_dir/bin"
|
||||
pg_bin_dir="/usr/lib/postgresql/$version/bin"
|
||||
if [ -d "$pg_bin_dir" ]; then
|
||||
ln -sf "$pg_bin_dir/pg_dump" "$version_dir/bin/pg_dump"
|
||||
ln -sf "$pg_bin_dir/pg_dumpall" "$version_dir/bin/pg_dumpall"
|
||||
ln -sf "$pg_bin_dir/psql" "$version_dir/bin/psql"
|
||||
ln -sf "$pg_bin_dir/pg_restore" "$version_dir/bin/pg_restore"
|
||||
ln -sf "$pg_bin_dir/createdb" "$version_dir/bin/createdb"
|
||||
ln -sf "$pg_bin_dir/dropdb" "$version_dir/bin/dropdb"
|
||||
fi
|
||||
mkdir -p postgresql/postgresql-$version
|
||||
cp -r ../../assets/tools/x64/postgresql/postgresql-$version/bin postgresql/postgresql-$version/
|
||||
done
|
||||
|
||||
# Copy MySQL client tools (5.7, 8.0, 8.4, 9) from pre-built assets
|
||||
for version in 5.7 8.0 8.4 9; do
|
||||
mkdir -p mysql/mysql-$version
|
||||
cp -r ../../assets/tools/x64/mysql/mysql-$version/bin mysql/mysql-$version/
|
||||
done
|
||||
|
||||
# Copy MariaDB client tools (10.6, 12.1) from pre-built assets
|
||||
for version in 10.6 12.1; do
|
||||
mkdir -p mariadb/mariadb-$version
|
||||
cp -r ../../assets/tools/x64/mariadb/mariadb-$version/bin mariadb/mariadb-$version/
|
||||
done
|
||||
|
||||
# Make all binaries executable
|
||||
chmod +x postgresql/*/bin/*
|
||||
chmod +x mysql/*/bin/*
|
||||
chmod +x mariadb/*/bin/*
|
||||
|
||||
echo "Pre-built client tools setup complete"
|
||||
|
||||
- name: Install MongoDB Database Tools
|
||||
run: |
|
||||
cd backend/tools
|
||||
|
||||
# MongoDB Database Tools must be downloaded (not in pre-built assets)
|
||||
# They are backward compatible - single version supports all servers (4.0-8.0)
|
||||
MONGODB_TOOLS_URL="https://fastdl.mongodb.org/tools/db/mongodb-database-tools-debian12-x86_64-100.10.0.deb"
|
||||
|
||||
echo "Downloading MongoDB Database Tools..."
|
||||
wget -q "$MONGODB_TOOLS_URL" -O /tmp/mongodb-database-tools.deb
|
||||
|
||||
echo "Installing MongoDB Database Tools..."
|
||||
dpkg -i /tmp/mongodb-database-tools.deb || apt-get install -f -y --no-install-recommends
|
||||
|
||||
# Create symlinks to tools directory
|
||||
ln -sf /usr/bin/mongodump mongodb/bin/mongodump
|
||||
ln -sf /usr/bin/mongorestore mongodb/bin/mongorestore
|
||||
|
||||
rm -f /tmp/mongodb-database-tools.deb
|
||||
echo "MongoDB Database Tools installed successfully"
|
||||
|
||||
- name: Verify MariaDB client tools exist
|
||||
run: |
|
||||
cd backend/tools
|
||||
@@ -426,10 +414,28 @@ jobs:
|
||||
if: always()
|
||||
run: |
|
||||
cd backend
|
||||
# Stop and remove containers (keeping images for next run)
|
||||
docker compose -f docker-compose.yml.example down -v
|
||||
|
||||
# Clean up all data directories created by docker-compose
|
||||
echo "Cleaning up data directories..."
|
||||
rm -rf pgdata || true
|
||||
rm -rf valkey-data || true
|
||||
rm -rf mysqldata || true
|
||||
rm -rf mariadbdata || true
|
||||
rm -rf temp/nas || true
|
||||
rm -rf databasus-data || true
|
||||
|
||||
# Also clean root-level databasus-data if exists
|
||||
cd ..
|
||||
rm -rf databasus-data || true
|
||||
|
||||
echo "Cleanup complete"
|
||||
|
||||
determine-version:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
container:
|
||||
image: node:20
|
||||
needs: [test-backend, test-frontend]
|
||||
if: ${{ github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, '[skip-release]') }}
|
||||
outputs:
|
||||
@@ -442,10 +448,9 @@ jobs:
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "20"
|
||||
- name: Configure Git for container
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Install semver
|
||||
run: npm install -g semver
|
||||
@@ -459,6 +464,7 @@ jobs:
|
||||
|
||||
- name: Analyze commits and determine version bump
|
||||
id: version_bump
|
||||
shell: bash
|
||||
run: |
|
||||
CURRENT_VERSION="${{ steps.current_version.outputs.current_version }}"
|
||||
LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
|
||||
@@ -478,7 +484,7 @@ jobs:
|
||||
HAS_FIX=false
|
||||
HAS_BREAKING=false
|
||||
|
||||
# Analyze each commit
|
||||
# Analyze each commit - USE PROCESS SUBSTITUTION to avoid subshell variable scope issues
|
||||
while IFS= read -r commit; do
|
||||
if [[ "$commit" =~ ^FEATURE ]]; then
|
||||
HAS_FEATURE=true
|
||||
@@ -496,7 +502,7 @@ jobs:
|
||||
HAS_BREAKING=true
|
||||
echo "Found BREAKING CHANGE: $commit"
|
||||
fi
|
||||
done <<< "$COMMITS"
|
||||
done < <(printf '%s\n' "$COMMITS")
|
||||
|
||||
# Determine version bump
|
||||
if [ "$HAS_BREAKING" = true ]; then
|
||||
@@ -522,10 +528,15 @@ jobs:
|
||||
fi
|
||||
|
||||
build-only:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
needs: [test-backend, test-frontend]
|
||||
if: ${{ github.ref == 'refs/heads/main' && contains(github.event.head_commit.message, '[skip-release]') }}
|
||||
steps:
|
||||
- name: Clean workspace
|
||||
run: |
|
||||
sudo rm -rf "$GITHUB_WORKSPACE"/* || true
|
||||
sudo rm -rf "$GITHUB_WORKSPACE"/.* || true
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -554,12 +565,17 @@ jobs:
|
||||
databasus/databasus:${{ github.sha }}
|
||||
|
||||
build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
needs: [determine-version]
|
||||
if: ${{ needs.determine-version.outputs.should_release == 'true' }}
|
||||
permissions:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Clean workspace
|
||||
run: |
|
||||
sudo rm -rf "$GITHUB_WORKSPACE"/* || true
|
||||
sudo rm -rf "$GITHUB_WORKSPACE"/.* || true
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -589,21 +605,33 @@ jobs:
|
||||
databasus/databasus:${{ github.sha }}
|
||||
|
||||
release:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
container:
|
||||
image: node:20
|
||||
needs: [determine-version, build-and-push]
|
||||
if: ${{ needs.determine-version.outputs.should_release == 'true' }}
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Clean workspace
|
||||
run: |
|
||||
rm -rf "$GITHUB_WORKSPACE"/* || true
|
||||
rm -rf "$GITHUB_WORKSPACE"/.* || true
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Configure Git for container
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Generate changelog
|
||||
id: changelog
|
||||
shell: bash
|
||||
run: |
|
||||
NEW_VERSION="${{ needs.determine-version.outputs.new_version }}"
|
||||
LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
|
||||
@@ -623,6 +651,7 @@ jobs:
|
||||
FIXES=""
|
||||
REFACTORS=""
|
||||
|
||||
# USE PROCESS SUBSTITUTION to avoid subshell variable scope issues
|
||||
while IFS= read -r line; do
|
||||
if [ -n "$line" ]; then
|
||||
COMMIT_MSG=$(echo "$line" | cut -d'|' -f1)
|
||||
@@ -656,7 +685,7 @@ jobs:
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
done <<< "$COMMITS"
|
||||
done < <(printf '%s\n' "$COMMITS")
|
||||
|
||||
# Build changelog sections
|
||||
if [ -n "$FEATURES" ]; then
|
||||
@@ -695,16 +724,33 @@ jobs:
|
||||
prerelease: false
|
||||
|
||||
publish-helm-chart:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
container:
|
||||
image: alpine:3.19
|
||||
volumes:
|
||||
- /runner-cache/apk-cache:/etc/apk/cache
|
||||
needs: [determine-version, build-and-push]
|
||||
if: ${{ needs.determine-version.outputs.should_release == 'true' }}
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
steps:
|
||||
- name: Clean workspace
|
||||
run: |
|
||||
rm -rf "$GITHUB_WORKSPACE"/* || true
|
||||
rm -rf "$GITHUB_WORKSPACE"/.* || true
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apk add --no-cache git bash curl
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Configure Git for container
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@v4
|
||||
with:
|
||||
|
||||
@@ -1,152 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
Always place private methods to the bottom of file
|
||||
|
||||
**This rule applies to ALL Go files including tests, services, controllers, repositories, etc.**
|
||||
|
||||
In Go, exported (public) functions/methods start with uppercase letters, while unexported (private) ones start with lowercase letters.
|
||||
|
||||
## Structure Order:
|
||||
|
||||
1. Type definitions and constants
|
||||
2. Public methods/functions (uppercase)
|
||||
3. Private methods/functions (lowercase)
|
||||
|
||||
## Examples:
|
||||
|
||||
### Service with methods:
|
||||
|
||||
```go
|
||||
type UserService struct {
|
||||
repository *UserRepository
|
||||
}
|
||||
|
||||
// Public methods first
|
||||
func (s *UserService) CreateUser(user *User) error {
|
||||
if err := s.validateUser(user); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.repository.Save(user)
|
||||
}
|
||||
|
||||
func (s *UserService) GetUser(id uuid.UUID) (*User, error) {
|
||||
return s.repository.FindByID(id)
|
||||
}
|
||||
|
||||
// Private methods at the bottom
|
||||
func (s *UserService) validateUser(user *User) error {
|
||||
if user.Name == "" {
|
||||
return errors.New("name is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
### Package-level functions:
|
||||
|
||||
```go
|
||||
package utils
|
||||
|
||||
// Public functions first
|
||||
func ProcessData(data []byte) (Result, error) {
|
||||
cleaned := sanitizeInput(data)
|
||||
return parseData(cleaned)
|
||||
}
|
||||
|
||||
func ValidateInput(input string) bool {
|
||||
return isValidFormat(input) && checkLength(input)
|
||||
}
|
||||
|
||||
// Private functions at the bottom
|
||||
func sanitizeInput(data []byte) []byte {
|
||||
// implementation
|
||||
}
|
||||
|
||||
func parseData(data []byte) (Result, error) {
|
||||
// implementation
|
||||
}
|
||||
|
||||
func isValidFormat(input string) bool {
|
||||
// implementation
|
||||
}
|
||||
|
||||
func checkLength(input string) bool {
|
||||
// implementation
|
||||
}
|
||||
```
|
||||
|
||||
### Test files:
|
||||
|
||||
```go
|
||||
package user_test
|
||||
|
||||
// Public test functions first
|
||||
func Test_CreateUser_ValidInput_UserCreated(t *testing.T) {
|
||||
user := createTestUser()
|
||||
result, err := service.CreateUser(user)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
|
||||
func Test_GetUser_ExistingUser_ReturnsUser(t *testing.T) {
|
||||
user := createTestUser()
|
||||
// test implementation
|
||||
}
|
||||
|
||||
// Private helper functions at the bottom
|
||||
func createTestUser() *User {
|
||||
return &User{
|
||||
Name: "Test User",
|
||||
Email: "test@example.com",
|
||||
}
|
||||
}
|
||||
|
||||
func setupTestDatabase() *Database {
|
||||
// setup implementation
|
||||
}
|
||||
```
|
||||
|
||||
### Controller example:
|
||||
|
||||
```go
|
||||
type ProjectController struct {
|
||||
service *ProjectService
|
||||
}
|
||||
|
||||
// Public HTTP handlers first
|
||||
func (c *ProjectController) CreateProject(ctx *gin.Context) {
|
||||
var request CreateProjectRequest
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
c.handleError(ctx, err)
|
||||
return
|
||||
}
|
||||
// handler logic
|
||||
}
|
||||
|
||||
func (c *ProjectController) GetProject(ctx *gin.Context) {
|
||||
projectID := c.extractProjectID(ctx)
|
||||
// handler logic
|
||||
}
|
||||
|
||||
// Private helper methods at the bottom
|
||||
func (c *ProjectController) handleError(ctx *gin.Context, err error) {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
}
|
||||
|
||||
func (c *ProjectController) extractProjectID(ctx *gin.Context) uuid.UUID {
|
||||
return uuid.MustParse(ctx.Param("projectId"))
|
||||
}
|
||||
```
|
||||
|
||||
## Key Points:
|
||||
|
||||
- **Exported/Public** = starts with uppercase letter (CreateUser, GetProject)
|
||||
- **Unexported/Private** = starts with lowercase letter (validateUser, handleError)
|
||||
- This improves code readability by showing the public API first
|
||||
- Private helpers are implementation details, so they go at the bottom
|
||||
- Apply this rule consistently across ALL Go files in the project
|
||||
@@ -1,45 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
## Comment Guidelines
|
||||
|
||||
1. **No obvious comments** - Don't state what the code already clearly shows
|
||||
2. **Functions and variables should have meaningful names** - Code should be self-documenting
|
||||
3. **Comments for unclear code only** - Only add comments when code logic isn't immediately clear
|
||||
|
||||
## Key Principles:
|
||||
|
||||
- **Code should tell a story** - Use descriptive variable and function names
|
||||
- **Comments explain WHY, not WHAT** - The code shows what happens, comments explain business logic or complex decisions
|
||||
- **Prefer refactoring over commenting** - If code needs explaining, consider making it clearer instead
|
||||
- **API documentation is required** - Swagger comments for all HTTP endpoints are mandatory
|
||||
- **Complex algorithms deserve comments** - Mathematical formulas, business rules, or non-obvious optimizations
|
||||
|
||||
Example of useless comment:
|
||||
|
||||
1.
|
||||
|
||||
```sql
|
||||
// Create projects table
|
||||
CREATE TABLE projects (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
name TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
```
|
||||
|
||||
2.
|
||||
|
||||
```go
|
||||
// Create test project
|
||||
project := CreateTestProject(projectName, user, router)
|
||||
```
|
||||
|
||||
3.
|
||||
|
||||
```go
|
||||
// CreateValidLogItems creates valid log items for testing
|
||||
func CreateValidLogItems(count int, uniqueID string) []logs_receiving.LogItemRequestDTO {
|
||||
```
|
||||
@@ -1,133 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
1. When we write controller:
|
||||
|
||||
- we combine all routes to single controller
|
||||
- names them as .WhatWeDo (not "handlers") concept
|
||||
|
||||
2. We use gin and \*gin.Context for all routes.
|
||||
Example:
|
||||
|
||||
func (c *TasksController) GetAvailableTasks(ctx *gin.Context) ...
|
||||
|
||||
3. We document all routes with Swagger in the following format:
|
||||
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
user_models "databasus-backend/internal/features/users/models"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
)
|
||||
|
||||
type AuditLogController struct {
|
||||
auditLogService \*AuditLogService
|
||||
}
|
||||
|
||||
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// All audit log endpoints require authentication (handled in main.go)
|
||||
auditRoutes := router.Group("/audit-logs")
|
||||
|
||||
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
|
||||
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
|
||||
|
||||
}
|
||||
|
||||
// GetGlobalAuditLogs
|
||||
// @Summary Get global audit logs (ADMIN only)
|
||||
// @Description Retrieve all audit logs across the system
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/global [get]
|
||||
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(\*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "only administrators can view global audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
|
||||
}
|
||||
|
||||
// GetUserAuditLogs
|
||||
// @Summary Get user audit logs
|
||||
// @Description Retrieve audit logs for a specific user
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param userId path string true "User ID"
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/users/{userId} [get]
|
||||
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(\*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
userIDStr := ctx.Param("userId")
|
||||
targetUserID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "insufficient permissions to view user audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
|
||||
}
|
||||
@@ -1,671 +0,0 @@
|
||||
---
|
||||
alwaysApply: false
|
||||
---
|
||||
|
||||
This is example of CRUD:
|
||||
|
||||
------ backend/internal/features/audit_logs/controller.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
user_models "databasus-backend/internal/features/users/models"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogController struct {
|
||||
auditLogService *AuditLogService
|
||||
}
|
||||
|
||||
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// All audit log endpoints require authentication (handled in main.go)
|
||||
auditRoutes := router.Group("/audit-logs")
|
||||
|
||||
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
|
||||
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
|
||||
}
|
||||
|
||||
// GetGlobalAuditLogs
|
||||
// @Summary Get global audit logs (ADMIN only)
|
||||
// @Description Retrieve all audit logs across the system
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/global [get]
|
||||
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "only administrators can view global audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// GetUserAuditLogs
|
||||
// @Summary Get user audit logs
|
||||
// @Description Retrieve audit logs for a specific user
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param userId path string true "User ID"
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/users/{userId} [get]
|
||||
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
userIDStr := ctx.Param("userId")
|
||||
targetUserID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "insufficient permissions to view user audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/controller_test.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
"databasus-backend/internal/storage"
|
||||
test_utils "databasus-backend/internal/util/testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_GetGlobalAuditLogs_AdminSucceedsAndMemberGetsForbidden(t *testing.T) {
|
||||
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
|
||||
memberUser := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
service := GetAuditLogService()
|
||||
projectID := uuid.New()
|
||||
|
||||
// Create test logs
|
||||
createAuditLog(service, "Test log with user", &adminUser.UserID, nil)
|
||||
createAuditLog(service, "Test log with project", nil, &projectID)
|
||||
createAuditLog(service, "Test log standalone", nil, nil)
|
||||
|
||||
// Test ADMIN can access global logs
|
||||
var response GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
"/api/v1/audit-logs/global?limit=10", "Bearer "+adminUser.Token, http.StatusOK, &response)
|
||||
|
||||
assert.GreaterOrEqual(t, len(response.AuditLogs), 3)
|
||||
assert.GreaterOrEqual(t, response.Total, int64(3))
|
||||
|
||||
messages := extractMessages(response.AuditLogs)
|
||||
assert.Contains(t, messages, "Test log with user")
|
||||
assert.Contains(t, messages, "Test log with project")
|
||||
assert.Contains(t, messages, "Test log standalone")
|
||||
|
||||
// Test MEMBER cannot access global logs
|
||||
resp := test_utils.MakeGetRequest(t, router, "/api/v1/audit-logs/global",
|
||||
"Bearer "+memberUser.Token, http.StatusForbidden)
|
||||
assert.Contains(t, string(resp.Body), "only administrators can view global audit logs")
|
||||
}
|
||||
|
||||
func Test_GetUserAuditLogs_PermissionsEnforcedCorrectly(t *testing.T) {
|
||||
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
|
||||
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
service := GetAuditLogService()
|
||||
projectID := uuid.New()
|
||||
|
||||
// Create test logs for different users
|
||||
createAuditLog(service, "Test log user1 first", &user1.UserID, nil)
|
||||
createAuditLog(service, "Test log user1 second", &user1.UserID, &projectID)
|
||||
createAuditLog(service, "Test log user2 first", &user2.UserID, nil)
|
||||
createAuditLog(service, "Test log user2 second", &user2.UserID, &projectID)
|
||||
createAuditLog(service, "Test project log", nil, &projectID)
|
||||
|
||||
// Test ADMIN can view any user's logs
|
||||
var user1Response GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/users/%s?limit=10", user1.UserID.String()),
|
||||
"Bearer "+adminUser.Token, http.StatusOK, &user1Response)
|
||||
|
||||
assert.Equal(t, 2, len(user1Response.AuditLogs))
|
||||
messages := extractMessages(user1Response.AuditLogs)
|
||||
assert.Contains(t, messages, "Test log user1 first")
|
||||
assert.Contains(t, messages, "Test log user1 second")
|
||||
|
||||
// Test user can view own logs
|
||||
var ownLogsResponse GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/users/%s", user2.UserID.String()),
|
||||
"Bearer "+user2.Token, http.StatusOK, &ownLogsResponse)
|
||||
assert.Equal(t, 2, len(ownLogsResponse.AuditLogs))
|
||||
|
||||
// Test user cannot view other user's logs
|
||||
resp := test_utils.MakeGetRequest(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/users/%s", user1.UserID.String()),
|
||||
"Bearer "+user2.Token, http.StatusForbidden)
|
||||
|
||||
assert.Contains(t, string(resp.Body), "insufficient permissions")
|
||||
}
|
||||
|
||||
func Test_FilterAuditLogsByTime_ReturnsOnlyLogsBeforeDate(t *testing.T) {
|
||||
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
|
||||
router := createRouter()
|
||||
service := GetAuditLogService()
|
||||
db := storage.GetDb()
|
||||
baseTime := time.Now().UTC()
|
||||
|
||||
// Create logs with different timestamps
|
||||
createTimedLog(db, &adminUser.UserID, "Test old log", baseTime.Add(-2*time.Hour))
|
||||
createTimedLog(db, &adminUser.UserID, "Test recent log", baseTime.Add(-30*time.Minute))
|
||||
createAuditLog(service, "Test current log", &adminUser.UserID, nil)
|
||||
|
||||
// Test filtering - get logs before 1 hour ago
|
||||
beforeTime := baseTime.Add(-1 * time.Hour)
|
||||
var filteredResponse GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/global?beforeDate=%s", beforeTime.Format(time.RFC3339)),
|
||||
"Bearer "+adminUser.Token, http.StatusOK, &filteredResponse)
|
||||
|
||||
// Verify only old log is returned
|
||||
messages := extractMessages(filteredResponse.AuditLogs)
|
||||
assert.Contains(t, messages, "Test old log")
|
||||
assert.NotContains(t, messages, "Test recent log")
|
||||
assert.NotContains(t, messages, "Test current log")
|
||||
|
||||
// Test without filter - should get all logs
|
||||
var allResponse GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router, "/api/v1/audit-logs/global",
|
||||
"Bearer "+adminUser.Token, http.StatusOK, &allResponse)
|
||||
assert.GreaterOrEqual(t, len(allResponse.AuditLogs), 3)
|
||||
}
|
||||
|
||||
func createRouter() *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
SetupDependencies()
|
||||
|
||||
v1 := router.Group("/api/v1")
|
||||
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
|
||||
GetAuditLogController().RegisterRoutes(protected.(*gin.RouterGroup))
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/di.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var auditLogRepository = &AuditLogRepository{}
|
||||
var auditLogService = &AuditLogService{
|
||||
auditLogRepository: auditLogRepository,
|
||||
logger: logger.GetLogger(),
|
||||
}
|
||||
var auditLogController = &AuditLogController{
|
||||
auditLogService: auditLogService,
|
||||
}
|
||||
|
||||
func GetAuditLogService() *AuditLogService {
|
||||
return auditLogService
|
||||
}
|
||||
|
||||
func GetAuditLogController() *AuditLogController {
|
||||
return auditLogController
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
users_services.GetUserService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/dto.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import "time"
|
||||
|
||||
type GetAuditLogsRequest struct {
|
||||
Limit int `form:"limit" json:"limit"`
|
||||
Offset int `form:"offset" json:"offset"`
|
||||
BeforeDate *time.Time `form:"beforeDate" json:"beforeDate"`
|
||||
}
|
||||
|
||||
type GetAuditLogsResponse struct {
|
||||
AuditLogs []*AuditLog `json:"auditLogs"`
|
||||
Total int64 `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/models.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLog struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id"`
|
||||
UserID *uuid.UUID `json:"userId" gorm:"column:user_id"`
|
||||
ProjectID *uuid.UUID `json:"projectId" gorm:"column:project_id"`
|
||||
Message string `json:"message" gorm:"column:message"`
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
|
||||
}
|
||||
|
||||
func (AuditLog) TableName() string {
|
||||
return "audit_logs"
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/repository.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/storage"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogRepository struct{}
|
||||
|
||||
func (r *AuditLogRepository) Create(auditLog *AuditLog) error {
|
||||
if auditLog.ID == uuid.Nil {
|
||||
auditLog.ID = uuid.New()
|
||||
}
|
||||
|
||||
return storage.GetDb().Create(auditLog).Error
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) GetGlobal(limit, offset int, beforeDate *time.Time) ([]*AuditLog, error) {
|
||||
var auditLogs []*AuditLog
|
||||
|
||||
query := storage.GetDb().Order("created_at DESC")
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&auditLogs).Error
|
||||
|
||||
return auditLogs, err
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) GetByUser(
|
||||
userID uuid.UUID,
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLog, error) {
|
||||
var auditLogs []*AuditLog
|
||||
|
||||
query := storage.GetDb().
|
||||
Where("user_id = ?", userID).
|
||||
Order("created_at DESC")
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&auditLogs).Error
|
||||
|
||||
return auditLogs, err
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) GetByProject(
|
||||
projectID uuid.UUID,
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLog, error) {
|
||||
var auditLogs []*AuditLog
|
||||
|
||||
query := storage.GetDb().
|
||||
Where("project_id = ?", projectID).
|
||||
Order("created_at DESC")
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&auditLogs).Error
|
||||
|
||||
return auditLogs, err
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) CountGlobal(beforeDate *time.Time) (int64, error) {
|
||||
var count int64
|
||||
query := storage.GetDb().Model(&AuditLog{})
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/service.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
user_models "databasus-backend/internal/features/users/models"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogService struct {
|
||||
auditLogRepository *AuditLogRepository
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *AuditLogService) WriteAuditLog(
|
||||
message string,
|
||||
userID *uuid.UUID,
|
||||
projectID *uuid.UUID,
|
||||
) {
|
||||
auditLog := &AuditLog{
|
||||
UserID: userID,
|
||||
ProjectID: projectID,
|
||||
Message: message,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
err := s.auditLogRepository.Create(auditLog)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to create audit log", "error", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuditLogService) CreateAuditLog(auditLog *AuditLog) error {
|
||||
return s.auditLogRepository.Create(auditLog)
|
||||
}
|
||||
|
||||
func (s *AuditLogService) GetGlobalAuditLogs(
|
||||
user *user_models.User,
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
if user.Role != user_enums.UserRoleAdmin {
|
||||
return nil, errors.New("only administrators can view global audit logs")
|
||||
}
|
||||
|
||||
limit := request.Limit
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
offset := max(request.Offset, 0)
|
||||
|
||||
auditLogs, err := s.auditLogRepository.GetGlobal(limit, offset, request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
total, err := s.auditLogRepository.CountGlobal(request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetAuditLogsResponse{
|
||||
AuditLogs: auditLogs,
|
||||
Total: total,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AuditLogService) GetUserAuditLogs(
|
||||
targetUserID uuid.UUID,
|
||||
user *user_models.User,
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
// Users can view their own logs, ADMIN can view any user's logs
|
||||
if user.Role != user_enums.UserRoleAdmin && user.ID != targetUserID {
|
||||
return nil, errors.New("insufficient permissions to view user audit logs")
|
||||
}
|
||||
|
||||
limit := request.Limit
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
offset := max(request.Offset, 0)
|
||||
|
||||
auditLogs, err := s.auditLogRepository.GetByUser(targetUserID, limit, offset, request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetAuditLogsResponse{
|
||||
AuditLogs: auditLogs,
|
||||
Total: int64(len(auditLogs)),
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AuditLogService) GetProjectAuditLogs(
|
||||
projectID uuid.UUID,
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
limit := request.Limit
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
offset := max(request.Offset, 0)
|
||||
|
||||
auditLogs, err := s.auditLogRepository.GetByProject(projectID, limit, offset, request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetAuditLogsResponse{
|
||||
AuditLogs: auditLogs,
|
||||
Total: int64(len(auditLogs)),
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/service_test.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
user_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func Test_AuditLogs_ProjectSpecificLogs(t *testing.T) {
|
||||
service := GetAuditLogService()
|
||||
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
project1ID, project2ID := uuid.New(), uuid.New()
|
||||
|
||||
// Create test logs for projects
|
||||
createAuditLog(service, "Test project1 log first", &user1.UserID, &project1ID)
|
||||
createAuditLog(service, "Test project1 log second", &user2.UserID, &project1ID)
|
||||
createAuditLog(service, "Test project2 log first", &user1.UserID, &project2ID)
|
||||
createAuditLog(service, "Test project2 log second", &user2.UserID, &project2ID)
|
||||
createAuditLog(service, "Test no project log", &user1.UserID, nil)
|
||||
|
||||
request := &GetAuditLogsRequest{Limit: 10, Offset: 0}
|
||||
|
||||
// Test project 1 logs
|
||||
project1Response, err := service.GetProjectAuditLogs(project1ID, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(project1Response.AuditLogs))
|
||||
|
||||
messages := extractMessages(project1Response.AuditLogs)
|
||||
assert.Contains(t, messages, "Test project1 log first")
|
||||
assert.Contains(t, messages, "Test project1 log second")
|
||||
for _, log := range project1Response.AuditLogs {
|
||||
assert.Equal(t, &project1ID, log.ProjectID)
|
||||
}
|
||||
|
||||
// Test project 2 logs
|
||||
project2Response, err := service.GetProjectAuditLogs(project2ID, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(project2Response.AuditLogs))
|
||||
|
||||
messages2 := extractMessages(project2Response.AuditLogs)
|
||||
assert.Contains(t, messages2, "Test project2 log first")
|
||||
assert.Contains(t, messages2, "Test project2 log second")
|
||||
|
||||
// Test pagination
|
||||
limitedResponse, err := service.GetProjectAuditLogs(project1ID,
|
||||
&GetAuditLogsRequest{Limit: 1, Offset: 0})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(limitedResponse.AuditLogs))
|
||||
assert.Equal(t, 1, limitedResponse.Limit)
|
||||
|
||||
// Test beforeDate filter
|
||||
beforeTime := time.Now().UTC().Add(-1 * time.Minute)
|
||||
filteredResponse, err := service.GetProjectAuditLogs(project1ID,
|
||||
&GetAuditLogsRequest{Limit: 10, BeforeDate: &beforeTime})
|
||||
assert.NoError(t, err)
|
||||
for _, log := range filteredResponse.AuditLogs {
|
||||
assert.True(t, log.CreatedAt.Before(beforeTime))
|
||||
}
|
||||
}
|
||||
|
||||
func createAuditLog(service *AuditLogService, message string, userID, projectID *uuid.UUID) {
|
||||
service.WriteAuditLog(message, userID, projectID)
|
||||
}
|
||||
|
||||
func extractMessages(logs []*AuditLog) []string {
|
||||
messages := make([]string, len(logs))
|
||||
for i, log := range logs {
|
||||
messages[i] = log.Message
|
||||
}
|
||||
return messages
|
||||
}
|
||||
|
||||
func createTimedLog(db *gorm.DB, userID *uuid.UUID, message string, createdAt time.Time) {
|
||||
log := &AuditLog{
|
||||
ID: uuid.New(),
|
||||
UserID: userID,
|
||||
Message: message,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
db.Create(log)
|
||||
}
|
||||
|
||||
```
|
||||
@@ -1,74 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
For DI files use implicit fields declaration styles (espesially
|
||||
for controllers, services, repositories, use cases, etc., not simple
|
||||
data structures).
|
||||
|
||||
So, instead of:
|
||||
|
||||
var orderController = &OrderController{
|
||||
orderService: orderService,
|
||||
botUserService: bot_users.GetBotUserService(),
|
||||
botService: bots.GetBotService(),
|
||||
userService: users.GetUserService(),
|
||||
}
|
||||
|
||||
Use:
|
||||
|
||||
var orderController = &OrderController{
|
||||
orderService,
|
||||
bot_users.GetBotUserService(),
|
||||
bots.GetBotService(),
|
||||
users.GetUserService(),
|
||||
}
|
||||
|
||||
This is needed to avoid forgetting to update DI style
|
||||
when we add new dependency.
|
||||
|
||||
---
|
||||
|
||||
Please force such usage if file look like this (see some
|
||||
services\controllers\repos definitions and getters):
|
||||
|
||||
var orderBackgroundService = &OrderBackgroundService{
|
||||
orderService: orderService,
|
||||
orderPaymentRepository: orderPaymentRepository,
|
||||
botService: bots.GetBotService(),
|
||||
paymentSettingsService: payment_settings.GetPaymentSettingsService(),
|
||||
|
||||
orderSubscriptionListeners: []OrderSubscriptionListener{},
|
||||
}
|
||||
|
||||
var orderController = &OrderController{
|
||||
orderService: orderService,
|
||||
botUserService: bot_users.GetBotUserService(),
|
||||
botService: bots.GetBotService(),
|
||||
userService: users.GetUserService(),
|
||||
}
|
||||
|
||||
func GetUniquePaymentRepository() *repositories.UniquePaymentRepository {
|
||||
return uniquePaymentRepository
|
||||
}
|
||||
|
||||
func GetOrderPaymentRepository() *repositories.OrderPaymentRepository {
|
||||
return orderPaymentRepository
|
||||
}
|
||||
|
||||
func GetOrderService() *OrderService {
|
||||
return orderService
|
||||
}
|
||||
|
||||
func GetOrderController() *OrderController {
|
||||
return orderController
|
||||
}
|
||||
|
||||
func GetOrderBackgroundService() *OrderBackgroundService {
|
||||
return orderBackgroundService
|
||||
}
|
||||
|
||||
func GetOrderRepository() *repositories.OrderRepository {
|
||||
return orderRepository
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
When writting migrations:
|
||||
|
||||
- write them for PostgreSQL
|
||||
- for PRIMARY UUID keys use gen_random_uuid()
|
||||
- for time use TIMESTAMPTZ (timestamp with zone)
|
||||
- split table, constraint and indexes declaration (table first, them other one by one)
|
||||
- format SQL in pretty way (add spaces, align columns types), constraints split by lines. The example:
|
||||
|
||||
CREATE TABLE marketplace_info (
|
||||
bot_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
title TEXT NOT NULL,
|
||||
description TEXT NOT NULL,
|
||||
short_description TEXT NOT NULL,
|
||||
tutorial_url TEXT,
|
||||
info_order BIGINT NOT NULL DEFAULT 0,
|
||||
is_published BOOLEAN NOT NULL DEFAULT FALSE
|
||||
);
|
||||
|
||||
ALTER TABLE marketplace_info_images
|
||||
ADD CONSTRAINT fk_marketplace_info_images_bot_id
|
||||
FOREIGN KEY (bot_id)
|
||||
REFERENCES marketplace_info (bot_id);
|
||||
@@ -1,12 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
When applying changes, do not forget to refactor old code.
|
||||
|
||||
You can shortify, make more readable, improve code quality, etc.
|
||||
Common logic can be extracted to functions, constants, files, etc.
|
||||
|
||||
After each large change with more than ~50-100 lines of code - always run `make lint` (from backend root folder) and, if you change frontend, run `npm run format` (from frontend root folder).
|
||||
@@ -1,147 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
After writing tests, always launch them and verify that they pass.
|
||||
|
||||
## Test Naming Format
|
||||
|
||||
Use these naming patterns:
|
||||
|
||||
- `Test_WhatWeDo_WhatWeExpect`
|
||||
- `Test_WhatWeDo_WhichConditions_WhatWeExpect`
|
||||
|
||||
## Examples from Real Codebase:
|
||||
|
||||
- `Test_CreateApiKey_WhenUserIsProjectOwner_ApiKeyCreated`
|
||||
- `Test_UpdateProject_WhenUserIsProjectAdmin_ProjectUpdated`
|
||||
- `Test_DeleteApiKey_WhenUserIsProjectMember_ReturnsForbidden`
|
||||
- `Test_GetProjectAuditLogs_WithDifferentUserRoles_EnforcesPermissionsCorrectly`
|
||||
- `Test_ProjectLifecycleE2E_CompletesSuccessfully`
|
||||
|
||||
## Testing Philosophy
|
||||
|
||||
**Prefer Controllers Over Unit Tests:**
|
||||
|
||||
- Test through HTTP endpoints via controllers whenever possible
|
||||
- Avoid testing repositories, services in isolation - test via API instead
|
||||
- Only use unit tests for complex model logic when no API exists
|
||||
- Name test files `controller_test.go` or `service_test.go`, not `integration_test.go`
|
||||
|
||||
**Extract Common Logic to Testing Utilities:**
|
||||
|
||||
- Create `testing.go` or `testing/testing.go` files for shared test utilities
|
||||
- Extract router creation, user setup, models creation helpers (in API, not just structs creation)
|
||||
- Reuse common patterns across different test files
|
||||
|
||||
**Refactor Existing Tests:**
|
||||
|
||||
- When working with existing tests, always look for opportunities to refactor and improve
|
||||
- Extract repetitive setup code to common utilities
|
||||
- Simplify complex tests by breaking them into smaller, focused tests
|
||||
- Replace inline test data creation with reusable helper functions
|
||||
- Consolidate similar test patterns across different test files
|
||||
- Make tests more readable and maintainable for other developers
|
||||
|
||||
## Testing Utilities Structure
|
||||
|
||||
**Create `testing.go` or `testing/testing.go` files with common utilities:**
|
||||
|
||||
```go
|
||||
package projects_testing
|
||||
|
||||
// CreateTestRouter creates unified router for all controllers
|
||||
func CreateTestRouter(controllers ...ControllerInterface) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
v1 := router.Group("/api/v1")
|
||||
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
|
||||
|
||||
for _, controller := range controllers {
|
||||
if routerGroup, ok := protected.(*gin.RouterGroup); ok {
|
||||
controller.RegisterRoutes(routerGroup)
|
||||
}
|
||||
}
|
||||
return router
|
||||
}
|
||||
|
||||
// CreateTestProjectViaAPI creates project through HTTP API
|
||||
func CreateTestProjectViaAPI(name string, owner *users_dto.SignInResponseDTO, router *gin.Engine) (*projects_models.Project, string) {
|
||||
request := projects_dto.CreateProjectRequestDTO{Name: name}
|
||||
w := MakeAPIRequest(router, "POST", "/api/v1/projects", "Bearer "+owner.Token, request)
|
||||
// Handle response...
|
||||
return project, owner.Token
|
||||
}
|
||||
|
||||
// AddMemberToProject adds member via API call
|
||||
func AddMemberToProject(project *projects_models.Project, member *users_dto.SignInResponseDTO, role users_enums.ProjectRole, ownerToken string, router *gin.Engine) {
|
||||
// Implementation...
|
||||
}
|
||||
```
|
||||
|
||||
## Controller Test Examples
|
||||
|
||||
**Permission-based testing:**
|
||||
|
||||
```go
|
||||
func Test_CreateApiKey_WhenUserIsProjectOwner_ApiKeyCreated(t *testing.T) {
|
||||
router := CreateApiKeyTestRouter(GetProjectController(), GetMembershipController())
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
project, _ := projects_testing.CreateTestProjectViaAPI("Test Project", owner, router)
|
||||
|
||||
request := CreateApiKeyRequestDTO{Name: "Test API Key"}
|
||||
var response ApiKey
|
||||
test_utils.MakePostRequestAndUnmarshal(t, router, "/api/v1/projects/api-keys/"+project.ID.String(), "Bearer "+owner.Token, request, http.StatusOK, &response)
|
||||
|
||||
assert.Equal(t, "Test API Key", response.Name)
|
||||
assert.NotEmpty(t, response.Token)
|
||||
}
|
||||
```
|
||||
|
||||
**Cross-project security testing:**
|
||||
|
||||
```go
|
||||
func Test_UpdateApiKey_WithApiKeyFromDifferentProject_ReturnsBadRequest(t *testing.T) {
|
||||
router := CreateApiKeyTestRouter(GetProjectController(), GetMembershipController())
|
||||
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
project1, _ := projects_testing.CreateTestProjectViaAPI("Project 1", owner1, router)
|
||||
project2, _ := projects_testing.CreateTestProjectViaAPI("Project 2", owner2, router)
|
||||
|
||||
apiKey := CreateTestApiKey("Cross Project Key", project1.ID, owner1.Token, router)
|
||||
|
||||
// Try to update via different project endpoint
|
||||
request := UpdateApiKeyRequestDTO{Name: &"Hacked Key"}
|
||||
resp := test_utils.MakePutRequest(t, router, "/api/v1/projects/api-keys/"+project2.ID.String()+"/"+apiKey.ID.String(), "Bearer "+owner2.Token, request, http.StatusBadRequest)
|
||||
|
||||
assert.Contains(t, string(resp.Body), "API key does not belong to this project")
|
||||
}
|
||||
```
|
||||
|
||||
**E2E lifecycle testing:**
|
||||
|
||||
```go
|
||||
func Test_ProjectLifecycleE2E_CompletesSuccessfully(t *testing.T) {
|
||||
router := projects_testing.CreateTestRouter(GetProjectController(), GetMembershipController())
|
||||
|
||||
// 1. Create project
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
project := projects_testing.CreateTestProject("E2E Project", owner, router)
|
||||
|
||||
// 2. Add member
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
projects_testing.AddMemberToProject(project, member, users_enums.ProjectRoleMember, owner.Token, router)
|
||||
|
||||
// 3. Promote to admin
|
||||
projects_testing.ChangeMemberRole(project, member.UserID, users_enums.ProjectRoleAdmin, owner.Token, router)
|
||||
|
||||
// 4. Transfer ownership
|
||||
projects_testing.TransferProjectOwnership(project, member.UserID, owner.Token, router)
|
||||
|
||||
// 5. Verify new owner can manage project
|
||||
finalProject := projects_testing.GetProject(project.ID, member.Token, router)
|
||||
assert.Equal(t, project.ID, finalProject.ID)
|
||||
}
|
||||
```
|
||||
@@ -1,6 +0,0 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
Always use time.Now().UTC() instead of time.Now()
|
||||
@@ -2,8 +2,13 @@
|
||||
DEV_DB_NAME=databasus
|
||||
DEV_DB_USERNAME=postgres
|
||||
DEV_DB_PASSWORD=Q1234567
|
||||
#app
|
||||
# app
|
||||
ENV_MODE=development
|
||||
# logging
|
||||
SHOW_DB_INSTALLATION_VERIFICATION_LOGS=true
|
||||
# tests
|
||||
TEST_LOCALHOST=localhost
|
||||
IS_SKIP_EXTERNAL_RESOURCES_TESTS=false
|
||||
# db
|
||||
DATABASE_DSN=host=dev-db user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
|
||||
DATABASE_URL=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -26,8 +25,10 @@ import (
|
||||
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/restores"
|
||||
"databasus-backend/internal/features/restores/restoring"
|
||||
"databasus-backend/internal/features/storages"
|
||||
system_healthcheck "databasus-backend/internal/features/system/healthcheck"
|
||||
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
users_controllers "databasus-backend/internal/features/users/controllers"
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
@@ -59,6 +60,8 @@ func main() {
|
||||
cache_utils.TestCacheConnection()
|
||||
|
||||
if config.GetEnv().IsPrimaryNode {
|
||||
log.Info("Clearing cache...")
|
||||
|
||||
err := cache_utils.ClearAllCache()
|
||||
if err != nil {
|
||||
log.Error("Failed to clear cache", "error", err)
|
||||
@@ -239,7 +242,7 @@ func setUpDependencies() {
|
||||
notifiers.SetupDependencies()
|
||||
storages.SetupDependencies()
|
||||
backups_config.SetupDependencies()
|
||||
backups_cancellation.SetupDependencies()
|
||||
task_cancellation.SetupDependencies()
|
||||
}
|
||||
|
||||
func runBackgroundTasks(log *slog.Logger) {
|
||||
@@ -257,20 +260,24 @@ func runBackgroundTasks(log *slog.Logger) {
|
||||
cancel()
|
||||
}()
|
||||
|
||||
err := files_utils.CleanFolder(config.GetEnv().TempFolder)
|
||||
if err != nil {
|
||||
log.Error("Failed to clean temp folder", "error", err)
|
||||
}
|
||||
|
||||
if config.GetEnv().IsPrimaryNode {
|
||||
log.Info("Starting primary node background tasks...")
|
||||
|
||||
err := files_utils.CleanFolder(config.GetEnv().TempFolder)
|
||||
if err != nil {
|
||||
log.Error("Failed to clean temp folder", "error", err)
|
||||
}
|
||||
|
||||
go runWithPanicLogging(log, "backup background service", func() {
|
||||
backuping.GetBackupsScheduler().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "backup cleaner background service", func() {
|
||||
backuping.GetBackupCleaner().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "restore background service", func() {
|
||||
restores.GetRestoreBackgroundService().Run(ctx)
|
||||
restoring.GetRestoresScheduler().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "healthcheck attempt background service", func() {
|
||||
@@ -284,18 +291,30 @@ func runBackgroundTasks(log *slog.Logger) {
|
||||
go runWithPanicLogging(log, "download token cleanup background service", func() {
|
||||
backups_download.GetDownloadTokenBackgroundService().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "backup nodes registry background service", func() {
|
||||
backuping.GetBackupNodesRegistry().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "restore nodes registry background service", func() {
|
||||
restoring.GetRestoreNodesRegistry().Run(ctx)
|
||||
})
|
||||
} else {
|
||||
log.Info("Skipping primary node tasks as not primary node")
|
||||
}
|
||||
|
||||
if config.GetEnv().IsBackupNode {
|
||||
if config.GetEnv().IsProcessingNode {
|
||||
log.Info("Starting backup node background tasks...")
|
||||
|
||||
go runWithPanicLogging(log, "backup node", func() {
|
||||
backuping.GetBackuperNode().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "restore node", func() {
|
||||
restoring.GetRestorerNode().Run(ctx)
|
||||
})
|
||||
} else {
|
||||
log.Info("Skipping backup node tasks as not backup node")
|
||||
log.Info("Skipping backup/restore node tasks as not backup node")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/ilyakaznacheev/cleanenv"
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
@@ -30,10 +29,14 @@ type EnvVariables struct {
|
||||
MariadbInstallDir string `env:"MARIADB_INSTALL_DIR"`
|
||||
MongodbInstallDir string `env:"MONGODB_INSTALL_DIR"`
|
||||
|
||||
NodeID string
|
||||
TestLocalhost string `env:"TEST_LOCALHOST"`
|
||||
|
||||
ShowDbInstallationVerificationLogs bool `env:"SHOW_DB_INSTALLATION_VERIFICATION_LOGS"`
|
||||
IsSkipExternalResourcesTests bool `env:"IS_SKIP_EXTERNAL_RESOURCES_TESTS"`
|
||||
|
||||
IsManyNodesMode bool `env:"IS_MANY_NODES_MODE"`
|
||||
IsPrimaryNode bool `env:"IS_PRIMARY_NODE"`
|
||||
IsBackupNode bool `env:"IS_BACKUP_NODE"`
|
||||
IsProcessingNode bool `env:"IS_PROCESSING_NODE"`
|
||||
NodeNetworkThroughputMBs int `env:"NODE_NETWORK_THROUGHPUT_MBPS"`
|
||||
|
||||
DataFolder string
|
||||
@@ -169,6 +172,16 @@ func loadEnvVariables() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Set default value for ShowDbInstallationVerificationLogs if not defined
|
||||
if os.Getenv("SHOW_DB_INSTALLATION_VERIFICATION_LOGS") == "" {
|
||||
env.ShowDbInstallationVerificationLogs = true
|
||||
}
|
||||
|
||||
// Set default value for IsSkipExternalTests if not defined
|
||||
if os.Getenv("IS_SKIP_EXTERNAL_RESOURCES_TESTS") == "" {
|
||||
env.IsSkipExternalResourcesTests = false
|
||||
}
|
||||
|
||||
for _, arg := range os.Args {
|
||||
if strings.Contains(arg, "test") {
|
||||
env.IsTesting = true
|
||||
@@ -192,25 +205,48 @@ func loadEnvVariables() {
|
||||
log.Info("ENV_MODE loaded", "mode", env.EnvMode)
|
||||
|
||||
env.PostgresesInstallDir = filepath.Join(backendRoot, "tools", "postgresql")
|
||||
tools.VerifyPostgresesInstallation(log, env.EnvMode, env.PostgresesInstallDir)
|
||||
tools.VerifyPostgresesInstallation(
|
||||
log,
|
||||
env.EnvMode,
|
||||
env.PostgresesInstallDir,
|
||||
env.ShowDbInstallationVerificationLogs,
|
||||
)
|
||||
|
||||
env.MysqlInstallDir = filepath.Join(backendRoot, "tools", "mysql")
|
||||
tools.VerifyMysqlInstallation(log, env.EnvMode, env.MysqlInstallDir)
|
||||
tools.VerifyMysqlInstallation(
|
||||
log,
|
||||
env.EnvMode,
|
||||
env.MysqlInstallDir,
|
||||
env.ShowDbInstallationVerificationLogs,
|
||||
)
|
||||
|
||||
env.MariadbInstallDir = filepath.Join(backendRoot, "tools", "mariadb")
|
||||
tools.VerifyMariadbInstallation(log, env.EnvMode, env.MariadbInstallDir)
|
||||
tools.VerifyMariadbInstallation(
|
||||
log,
|
||||
env.EnvMode,
|
||||
env.MariadbInstallDir,
|
||||
env.ShowDbInstallationVerificationLogs,
|
||||
)
|
||||
|
||||
env.MongodbInstallDir = filepath.Join(backendRoot, "tools", "mongodb")
|
||||
tools.VerifyMongodbInstallation(log, env.EnvMode, env.MongodbInstallDir)
|
||||
tools.VerifyMongodbInstallation(
|
||||
log,
|
||||
env.EnvMode,
|
||||
env.MongodbInstallDir,
|
||||
env.ShowDbInstallationVerificationLogs,
|
||||
)
|
||||
|
||||
env.NodeID = uuid.New().String()
|
||||
if env.NodeNetworkThroughputMBs == 0 {
|
||||
env.NodeNetworkThroughputMBs = 125 // 1 Gbit/s
|
||||
}
|
||||
|
||||
if !env.IsManyNodesMode {
|
||||
env.IsPrimaryNode = true
|
||||
env.IsBackupNode = true
|
||||
env.IsProcessingNode = true
|
||||
}
|
||||
|
||||
if env.TestLocalhost == "" {
|
||||
env.TestLocalhost = "localhost"
|
||||
}
|
||||
|
||||
// Valkey
|
||||
|
||||
@@ -2,34 +2,50 @@ package audit_logs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type AuditLogBackgroundService struct {
|
||||
auditLogService *AuditLogService
|
||||
logger *slog.Logger
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (s *AuditLogBackgroundService) Run(ctx context.Context) {
|
||||
s.logger.Info("Starting audit log cleanup background service")
|
||||
wasAlreadyRun := s.hasRun.Load()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
s.runOnce.Do(func() {
|
||||
s.hasRun.Store(true)
|
||||
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
s.logger.Info("Starting audit log cleanup background service")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.cleanOldAuditLogs(); err != nil {
|
||||
s.logger.Error("Failed to clean old audit logs", "error", err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.cleanOldAuditLogs(); err != nil {
|
||||
s.logger.Error("Failed to clean old audit logs", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
@@ -14,8 +17,10 @@ var auditLogController = &AuditLogController{
|
||||
auditLogService,
|
||||
}
|
||||
var auditLogBackgroundService = &AuditLogBackgroundService{
|
||||
auditLogService,
|
||||
logger.GetLogger(),
|
||||
auditLogService: auditLogService,
|
||||
logger: logger.GetLogger(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
|
||||
func GetAuditLogService() *AuditLogService {
|
||||
@@ -30,8 +35,23 @@ func GetAuditLogBackgroundService() *AuditLogBackgroundService {
|
||||
return auditLogBackgroundService
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
users_services.GetUserService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
users_services.GetUserService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,22 +2,30 @@ package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"databasus-backend/internal/config"
|
||||
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/storages"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/storages"
|
||||
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
)
|
||||
|
||||
const (
|
||||
heartbeatTickerInterval = 15 * time.Second
|
||||
backuperHeathcheckThreshold = 5 * time.Minute
|
||||
)
|
||||
|
||||
type BackuperNode struct {
|
||||
@@ -28,77 +36,93 @@ type BackuperNode struct {
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
storageService *storages.StorageService
|
||||
notificationSender backups_core.NotificationSender
|
||||
backupCancelManager *backups_cancellation.BackupCancelManager
|
||||
nodesRegistry *BackupNodesRegistry
|
||||
backupCancelManager *tasks_cancellation.TaskCancelManager
|
||||
backupNodesRegistry *BackupNodesRegistry
|
||||
logger *slog.Logger
|
||||
createBackupUseCase backups_core.CreateBackupUsecase
|
||||
nodeID uuid.UUID
|
||||
|
||||
lastHeartbeat time.Time
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (n *BackuperNode) Run(ctx context.Context) {
|
||||
n.lastHeartbeat = time.Now().UTC()
|
||||
wasAlreadyRun := n.hasRun.Load()
|
||||
|
||||
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
|
||||
n.runOnce.Do(func() {
|
||||
n.hasRun.Store(true)
|
||||
|
||||
backupNode := BackupNode{
|
||||
ID: n.nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: time.Now().UTC(),
|
||||
}
|
||||
n.lastHeartbeat = time.Now().UTC()
|
||||
|
||||
if err := n.nodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
|
||||
n.logger.Error("Failed to register node in registry", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
|
||||
|
||||
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
|
||||
n.MakeBackup(backupID, isCallNotifier)
|
||||
if err := n.nodesRegistry.PublishBackupCompletion(n.nodeID.String(), backupID); err != nil {
|
||||
n.logger.Error(
|
||||
"Failed to publish backup completion",
|
||||
"error",
|
||||
err,
|
||||
"backupID",
|
||||
backupID,
|
||||
)
|
||||
backupNode := BackupNode{
|
||||
ID: n.nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: time.Now().UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
if err := n.nodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID.String(), backupHandler); err != nil {
|
||||
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := n.nodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
|
||||
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
|
||||
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
|
||||
n.logger.Error("Failed to register node in registry", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
|
||||
ticker := time.NewTicker(15 * time.Second)
|
||||
defer ticker.Stop()
|
||||
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
|
||||
go func() {
|
||||
n.MakeBackup(backupID, isCallNotifier)
|
||||
if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil {
|
||||
n.logger.Error(
|
||||
"Failed to publish backup completion",
|
||||
"error",
|
||||
err,
|
||||
"backupID",
|
||||
backupID,
|
||||
)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
|
||||
|
||||
if err := n.nodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
|
||||
n.logger.Error("Failed to unregister node from registry", "error", err)
|
||||
err := n.backupNodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID, backupHandler)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := n.backupNodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
|
||||
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return
|
||||
case <-ticker.C:
|
||||
n.sendHeartbeat(&backupNode)
|
||||
ticker := time.NewTicker(heartbeatTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
|
||||
|
||||
if err := n.backupNodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
|
||||
n.logger.Error("Failed to unregister node from registry", "error", err)
|
||||
}
|
||||
|
||||
return
|
||||
case <-ticker.C:
|
||||
n.sendHeartbeat(&backupNode)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", n))
|
||||
}
|
||||
}
|
||||
|
||||
func (n *BackuperNode) IsBackuperRunning() bool {
|
||||
return n.lastHeartbeat.After(time.Now().UTC().Add(-5 * time.Minute))
|
||||
return n.lastHeartbeat.After(time.Now().UTC().Add(-backuperHeathcheckThreshold))
|
||||
}
|
||||
|
||||
func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
|
||||
@@ -135,21 +159,41 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
|
||||
|
||||
start := time.Now().UTC()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
n.backupCancelManager.RegisterTask(backup.ID, cancel)
|
||||
defer n.backupCancelManager.UnregisterTask(backup.ID)
|
||||
|
||||
backupProgressListener := func(
|
||||
completedMBs float64,
|
||||
) {
|
||||
backup.BackupSizeMb = completedMBs
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
// Check size limit (0 = unlimited)
|
||||
if backupConfig.MaxBackupSizeMB > 0 &&
|
||||
completedMBs > float64(backupConfig.MaxBackupSizeMB) {
|
||||
errMsg := fmt.Sprintf(
|
||||
"backup size (%.2f MB) exceeded maximum allowed size (%d MB)",
|
||||
completedMBs,
|
||||
backupConfig.MaxBackupSizeMB,
|
||||
)
|
||||
|
||||
backup.Status = backups_core.BackupStatusFailed
|
||||
backup.IsSkipRetry = true
|
||||
backup.FailMessage = &errMsg
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save backup with size exceeded error", "error", err)
|
||||
}
|
||||
cancel() // Cancel the backup context
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to update backup progress", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
n.backupCancelManager.RegisterBackup(backup.ID, cancel)
|
||||
defer n.backupCancelManager.UnregisterBackup(backup.ID)
|
||||
|
||||
backupMetadata, err := n.createBackupUseCase.Execute(
|
||||
ctx,
|
||||
backup.ID,
|
||||
@@ -159,8 +203,42 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
|
||||
backupProgressListener,
|
||||
)
|
||||
if err != nil {
|
||||
// Check if backup was already marked as failed by progress listener (e.g., size limit exceeded)
|
||||
// If so, skip error handling to avoid overwriting the status
|
||||
currentBackup, fetchErr := n.backupRepository.FindByID(backup.ID)
|
||||
if fetchErr == nil && currentBackup.Status == backups_core.BackupStatusFailed {
|
||||
n.logger.Warn(
|
||||
"Backup already marked as failed by progress listener, skipping error handling",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"failMessage",
|
||||
*currentBackup.FailMessage,
|
||||
)
|
||||
|
||||
// Still call notification for size limit failures
|
||||
n.SendBackupNotification(
|
||||
backupConfig,
|
||||
currentBackup,
|
||||
backups_config.NotificationBackupFailed,
|
||||
currentBackup.FailMessage,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
errMsg := err.Error()
|
||||
|
||||
// Log detailed error information for debugging
|
||||
n.logger.Error("Backup execution failed",
|
||||
"backupId", backup.ID,
|
||||
"databaseId", databaseID,
|
||||
"databaseType", database.Type,
|
||||
"storageId", storage.ID,
|
||||
"storageType", storage.Type,
|
||||
"error", err,
|
||||
"errorMessage", errMsg,
|
||||
)
|
||||
|
||||
// Check if backup was cancelled (not due to shutdown)
|
||||
isCancelled := strings.Contains(errMsg, "backup cancelled") ||
|
||||
strings.Contains(errMsg, "context canceled") ||
|
||||
@@ -168,6 +246,12 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
|
||||
isShutdown := strings.Contains(errMsg, "shutdown")
|
||||
|
||||
if isCancelled && !isShutdown {
|
||||
n.logger.Warn("Backup was cancelled by user or system",
|
||||
"backupId", backup.ID,
|
||||
"isCancelled", isCancelled,
|
||||
"isShutdown", isShutdown,
|
||||
)
|
||||
|
||||
backup.Status = backups_core.BackupStatusCanceled
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
backup.BackupSizeMb = 0
|
||||
@@ -337,8 +421,7 @@ func (n *BackuperNode) SendBackupNotification(
|
||||
|
||||
func (n *BackuperNode) sendHeartbeat(backupNode *BackupNode) {
|
||||
n.lastHeartbeat = time.Now().UTC()
|
||||
backupNode.LastHeartbeat = time.Now().UTC()
|
||||
if err := n.nodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *backupNode); err != nil {
|
||||
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *backupNode); err != nil {
|
||||
n.logger.Error("Failed to send heartbeat", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
common "databasus-backend/internal/features/backups/backups/common"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -18,7 +15,6 @@ import (
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
@@ -158,35 +154,120 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
type CreateFailedBackupUsecase struct {
|
||||
}
|
||||
func Test_BackupSizeLimits(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
func (uc *CreateFailedBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*common.BackupMetadata, error) {
|
||||
backupProgressListener(10)
|
||||
return nil, errors.New("backup failed")
|
||||
}
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
type CreateSuccessBackupUsecase struct{}
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
func (uc *CreateSuccessBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*common.BackupMetadata, error) {
|
||||
backupProgressListener(10)
|
||||
return &common.BackupMetadata{
|
||||
EncryptionSalt: nil,
|
||||
EncryptionIV: nil,
|
||||
Encryption: backups_config.BackupEncryptionNone,
|
||||
}, nil
|
||||
t.Run("UnlimitedSize_MaxBackupSizeMBIsZero_BackupCompletes", func(t *testing.T) {
|
||||
// Enable backups with unlimited size (0)
|
||||
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
backupConfig.MaxBackupSizeMB = 0 // unlimited
|
||||
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.createBackupUseCase = &CreateLargeBackupUsecase{}
|
||||
|
||||
// Create a backup record
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode.MakeBackup(backup.ID, false)
|
||||
|
||||
// Verify backup completed successfully even with large size
|
||||
updatedBackup, err := backupRepository.FindByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
|
||||
assert.Equal(t, float64(10000), updatedBackup.BackupSizeMb)
|
||||
assert.Nil(t, updatedBackup.FailMessage)
|
||||
})
|
||||
|
||||
t.Run("SizeExceeded_BackupFailedWithIsSkipRetry", func(t *testing.T) {
|
||||
// Enable backups with 5 MB limit
|
||||
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
backupConfig.MaxBackupSizeMB = 5
|
||||
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.createBackupUseCase = &CreateProgressiveBackupUsecase{}
|
||||
|
||||
// Create a backup record
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode.MakeBackup(backup.ID, false)
|
||||
|
||||
// Verify backup was marked as failed with IsSkipRetry=true
|
||||
updatedBackup, err := backupRepository.FindByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusFailed, updatedBackup.Status)
|
||||
assert.True(t, updatedBackup.IsSkipRetry)
|
||||
assert.NotNil(t, updatedBackup.FailMessage)
|
||||
assert.Contains(t, *updatedBackup.FailMessage, "exceeded maximum allowed size")
|
||||
assert.Contains(t, *updatedBackup.FailMessage, "10.00 MB")
|
||||
assert.Contains(t, *updatedBackup.FailMessage, "5 MB")
|
||||
assert.Greater(t, updatedBackup.BackupSizeMb, float64(5))
|
||||
})
|
||||
|
||||
t.Run("SizeWithinLimit_BackupCompletes", func(t *testing.T) {
|
||||
// Enable backups with 100 MB limit
|
||||
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
backupConfig.MaxBackupSizeMB = 100
|
||||
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.createBackupUseCase = &CreateMediumBackupUsecase{}
|
||||
|
||||
// Create a backup record
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backuperNode.MakeBackup(backup.ID, false)
|
||||
|
||||
// Verify backup completed successfully
|
||||
updatedBackup, err := backupRepository.FindByID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
|
||||
assert.Equal(t, float64(50), updatedBackup.BackupSizeMb)
|
||||
assert.Nil(t, updatedBackup.FailMessage)
|
||||
})
|
||||
}
|
||||
|
||||
242
backend/internal/features/backups/backups/backuping/cleaner.go
Normal file
242
backend/internal/features/backups/backups/backuping/cleaner.go
Normal file
@@ -0,0 +1,242 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/storages"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/period"
|
||||
)
|
||||
|
||||
const (
|
||||
cleanerTickerInterval = 1 * time.Minute
|
||||
)
|
||||
|
||||
type BackupCleaner struct {
|
||||
backupRepository *backups_core.BackupRepository
|
||||
storageService *storages.StorageService
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
fieldEncryptor util_encryption.FieldEncryptor
|
||||
logger *slog.Logger
|
||||
backupRemoveListeners []backups_core.BackupRemoveListener
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) Run(ctx context.Context) {
|
||||
wasAlreadyRun := c.hasRun.Load()
|
||||
|
||||
c.runOnce.Do(func() {
|
||||
c.hasRun.Store(true)
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(cleanerTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := c.cleanOldBackups(); err != nil {
|
||||
c.logger.Error("Failed to clean old backups", "error", err)
|
||||
}
|
||||
|
||||
if err := c.cleanExceededBackups(); err != nil {
|
||||
c.logger.Error("Failed to clean exceeded backups", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", c))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) DeleteBackup(backup *backups_core.Backup) error {
|
||||
for _, listener := range c.backupRemoveListeners {
|
||||
if err := listener.OnBeforeBackupRemove(backup); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
storage, err := c.storageService.GetStorageByID(backup.StorageID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = storage.DeleteFile(c.fieldEncryptor, backup.ID)
|
||||
if err != nil {
|
||||
// we do not return error here, because sometimes clean up performed
|
||||
// before unavailable storage removal or change - therefore we should
|
||||
// proceed even in case of error. It's possible that some S3 or
|
||||
// storage is not available yet, it should not block us
|
||||
c.logger.Error("Failed to delete backup file", "error", err)
|
||||
}
|
||||
|
||||
return c.backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) AddBackupRemoveListener(listener backups_core.BackupRemoveListener) {
|
||||
c.backupRemoveListeners = append(c.backupRemoveListeners, listener)
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanOldBackups() error {
|
||||
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
backupStorePeriod := backupConfig.StorePeriod
|
||||
|
||||
if backupStorePeriod == period.PeriodForever {
|
||||
continue
|
||||
}
|
||||
|
||||
storeDuration := backupStorePeriod.ToDuration()
|
||||
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
|
||||
|
||||
oldBackups, err := c.backupRepository.FindBackupsBeforeDate(
|
||||
backupConfig.DatabaseID,
|
||||
dateBeforeBackupsShouldBeDeleted,
|
||||
)
|
||||
if err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to find old backups for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, backup := range oldBackups {
|
||||
if err := c.DeleteBackup(backup); err != nil {
|
||||
c.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
c.logger.Info(
|
||||
"Deleted old backup",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanExceededBackups() error {
|
||||
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
if backupConfig.MaxBackupsTotalSizeMB <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := c.cleanExceededBackupsForDatabase(
|
||||
backupConfig.DatabaseID,
|
||||
backupConfig.MaxBackupsTotalSizeMB,
|
||||
); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to clean exceeded backups for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *BackupCleaner) cleanExceededBackupsForDatabase(
|
||||
databaseID uuid.UUID,
|
||||
limitperDbMB int64,
|
||||
) error {
|
||||
for {
|
||||
backupsTotalSizeMB, err := c.backupRepository.GetTotalSizeByDatabase(databaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if backupsTotalSizeMB <= float64(limitperDbMB) {
|
||||
break
|
||||
}
|
||||
|
||||
oldestBackups, err := c.backupRepository.FindOldestByDatabaseExcludingInProgress(
|
||||
databaseID,
|
||||
1,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(oldestBackups) == 0 {
|
||||
c.logger.Warn(
|
||||
"No backups to delete but still over limit",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"totalSizeMB",
|
||||
backupsTotalSizeMB,
|
||||
"limitMB",
|
||||
limitperDbMB,
|
||||
)
|
||||
break
|
||||
}
|
||||
|
||||
backup := oldestBackups[0]
|
||||
if err := c.DeleteBackup(backup); err != nil {
|
||||
c.logger.Error(
|
||||
"Failed to delete exceeded backup",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
c.logger.Info(
|
||||
"Deleted exceeded backup",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"backupSizeMB",
|
||||
backup.BackupSizeMb,
|
||||
"totalSizeMB",
|
||||
backupsTotalSizeMB,
|
||||
"limitMB",
|
||||
limitperDbMB,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,491 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/storage"
|
||||
"databasus-backend/internal/util/period"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_CleanOldBackups_DeletesBackupsOlderThanStorePeriod(t *testing.T) {
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
// Create backup interval
|
||||
interval := createTestInterval()
|
||||
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodWeek,
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create backups with different ages
|
||||
now := time.Now().UTC()
|
||||
oldBackup1 := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10,
|
||||
CreatedAt: now.Add(-10 * 24 * time.Hour), // 10 days old
|
||||
}
|
||||
oldBackup2 := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10,
|
||||
CreatedAt: now.Add(-8 * 24 * time.Hour), // 8 days old
|
||||
}
|
||||
recentBackup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10,
|
||||
CreatedAt: now.Add(-3 * 24 * time.Hour), // 3 days old
|
||||
}
|
||||
|
||||
err = backupRepository.Save(oldBackup1)
|
||||
assert.NoError(t, err)
|
||||
err = backupRepository.Save(oldBackup2)
|
||||
assert.NoError(t, err)
|
||||
err = backupRepository.Save(recentBackup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Run cleanup
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanOldBackups()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify old backups deleted, recent backup remains
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(remainingBackups))
|
||||
assert.Equal(t, recentBackup.ID, remainingBackups[0].ID)
|
||||
}
|
||||
|
||||
func Test_CleanOldBackups_SkipsDatabaseWithForeverStorePeriod(t *testing.T) {
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
// Create backup interval
|
||||
interval := createTestInterval()
|
||||
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create very old backup
|
||||
oldBackup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10,
|
||||
CreatedAt: time.Now().UTC().Add(-365 * 24 * time.Hour), // 1 year old
|
||||
}
|
||||
err = backupRepository.Save(oldBackup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Run cleanup
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanOldBackups()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify backup still exists
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(remainingBackups))
|
||||
assert.Equal(t, oldBackup.ID, remainingBackups[0].ID)
|
||||
}
|
||||
|
||||
func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
// Create backup interval
|
||||
interval := createTestInterval()
|
||||
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
MaxBackupsTotalSizeMB: 100, // 100 MB limit
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create 3 backups totaling 50MB (under limit)
|
||||
for i := 0; i < 3; i++ {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 16.67,
|
||||
CreatedAt: time.Now().UTC().Add(-time.Duration(i) * time.Hour),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Run cleanup
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanExceededBackups()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify all backups remain
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, len(remainingBackups))
|
||||
}
|
||||
|
||||
func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T) {
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
// Create backup interval
|
||||
interval := createTestInterval()
|
||||
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
MaxBackupsTotalSizeMB: 30, // 30 MB limit
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create 5 backups of 10MB each (total 50MB, over 30MB limit)
|
||||
now := time.Now().UTC()
|
||||
var backupIDs []uuid.UUID
|
||||
for i := 0; i < 5; i++ {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10,
|
||||
CreatedAt: now.Add(-time.Duration(4-i) * time.Hour), // Oldest first
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
backupIDs = append(backupIDs, backup.ID)
|
||||
}
|
||||
|
||||
// Run cleanup
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanExceededBackups()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify 2 oldest backups deleted, 3 newest remain
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, len(remainingBackups))
|
||||
|
||||
// Check that the newest 3 backups remain
|
||||
remainingIDs := make(map[uuid.UUID]bool)
|
||||
for _, backup := range remainingBackups {
|
||||
remainingIDs[backup.ID] = true
|
||||
}
|
||||
assert.False(t, remainingIDs[backupIDs[0]]) // Oldest deleted
|
||||
assert.False(t, remainingIDs[backupIDs[1]]) // 2nd oldest deleted
|
||||
assert.True(t, remainingIDs[backupIDs[2]]) // 3rd remains
|
||||
assert.True(t, remainingIDs[backupIDs[3]]) // 4th remains
|
||||
assert.True(t, remainingIDs[backupIDs[4]]) // Newest remains
|
||||
}
|
||||
|
||||
func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
// Create backup interval
|
||||
interval := createTestInterval()
|
||||
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
MaxBackupsTotalSizeMB: 50, // 50 MB limit
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
// Create 3 completed backups of 30MB each
|
||||
completedBackups := make([]*backups_core.Backup, 3)
|
||||
for i := 0; i < 3; i++ {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 30,
|
||||
CreatedAt: now.Add(-time.Duration(3-i) * time.Hour),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
completedBackups[i] = backup
|
||||
}
|
||||
|
||||
// Create 1 in-progress backup (should be excluded from size calculation and deletion)
|
||||
inProgressBackup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
BackupSizeMb: 10,
|
||||
CreatedAt: now,
|
||||
}
|
||||
err = backupRepository.Save(inProgressBackup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Run cleanup
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanExceededBackups()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify: only completed backups deleted, in-progress remains
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Should have in-progress + 1 completed (total 40MB completed + 10MB in-progress)
|
||||
assert.GreaterOrEqual(t, len(remainingBackups), 2)
|
||||
|
||||
// Verify in-progress backup still exists
|
||||
var inProgressFound bool
|
||||
for _, backup := range remainingBackups {
|
||||
if backup.ID == inProgressBackup.ID {
|
||||
inProgressFound = true
|
||||
assert.Equal(t, backups_core.BackupStatusInProgress, backup.Status)
|
||||
}
|
||||
}
|
||||
assert.True(t, inProgressFound, "In-progress backup should not be deleted")
|
||||
}
|
||||
|
||||
func Test_CleanExceededBackups_WithZeroLimit_SkipsDatabase(t *testing.T) {
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
// Create backup interval
|
||||
interval := createTestInterval()
|
||||
|
||||
backupConfig := &backups_config.BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodForever,
|
||||
StorageID: &storage.ID,
|
||||
MaxBackupsTotalSizeMB: 0, // No size limit
|
||||
BackupIntervalID: interval.ID,
|
||||
BackupInterval: interval,
|
||||
}
|
||||
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create large backups
|
||||
for i := 0; i < 10; i++ {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 100,
|
||||
CreatedAt: time.Now().UTC().Add(-time.Duration(i) * time.Hour),
|
||||
}
|
||||
err = backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Run cleanup
|
||||
cleaner := GetBackupCleaner()
|
||||
err = cleaner.cleanExceededBackups()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify all backups remain
|
||||
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 10, len(remainingBackups))
|
||||
}
|
||||
|
||||
func Test_GetTotalSizeByDatabase_CalculatesCorrectly(t *testing.T) {
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
// Create completed backups
|
||||
completedBackup1 := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
completedBackup2 := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 20.3,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
// Create failed backup (should be included)
|
||||
failedBackup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusFailed,
|
||||
BackupSizeMb: 5.2,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
// Create in-progress backup (should be excluded)
|
||||
inProgressBackup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
BackupSizeMb: 100,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
err := backupRepository.Save(completedBackup1)
|
||||
assert.NoError(t, err)
|
||||
err = backupRepository.Save(completedBackup2)
|
||||
assert.NoError(t, err)
|
||||
err = backupRepository.Save(failedBackup)
|
||||
assert.NoError(t, err)
|
||||
err = backupRepository.Save(inProgressBackup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Calculate total size
|
||||
totalSize, err := backupRepository.GetTotalSizeByDatabase(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Should be 10.5 + 20.3 + 5.2 = 36.0 (excluding in-progress 100)
|
||||
assert.InDelta(t, 36.0, totalSize, 0.1)
|
||||
}
|
||||
|
||||
// Mock listener for testing
|
||||
type mockBackupRemoveListener struct {
|
||||
onBeforeBackupRemove func(*backups_core.Backup) error
|
||||
}
|
||||
|
||||
func (m *mockBackupRemoveListener) OnBeforeBackupRemove(backup *backups_core.Backup) error {
|
||||
if m.onBeforeBackupRemove != nil {
|
||||
return m.onBeforeBackupRemove(backup)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Test_DeleteBackup_WhenStorageDeleteFails_BackupStillRemovedFromDatabase verifies resilience
|
||||
// when storage becomes unavailable. Even if storage.DeleteFile fails (e.g., storage is offline,
|
||||
// credentials changed, or storage was deleted), the backup record should still be removed from
|
||||
// the database. This prevents orphaned backup records when storage is no longer accessible.
|
||||
func Test_DeleteBackup_WhenStorageDeleteFails_BackupStillRemovedFromDatabase(t *testing.T) {
|
||||
router := CreateTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
testStorage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, testStorage, notifier)
|
||||
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: testStorage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err := backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
cleaner := GetBackupCleaner()
|
||||
|
||||
err = cleaner.DeleteBackup(backup)
|
||||
assert.NoError(t, err, "DeleteBackup should succeed even when storage file doesn't exist")
|
||||
|
||||
deletedBackup, err := backupRepository.FindByID(backup.ID)
|
||||
assert.Error(t, err, "Backup should not exist in database")
|
||||
assert.Nil(t, deletedBackup)
|
||||
}
|
||||
|
||||
func createTestInterval() *intervals.Interval {
|
||||
timeOfDay := "04:00"
|
||||
interval := &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
|
||||
err := storage.GetDb().Create(interval).Error
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return interval
|
||||
}
|
||||
@@ -1,71 +1,83 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/usecases"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var backupRepository = &backups_core.BackupRepository{}
|
||||
|
||||
var backupCancelManager = backups_cancellation.GetBackupCancelManager()
|
||||
var taskCancelManager = tasks_cancellation.GetTaskCancelManager()
|
||||
|
||||
var nodesRegistry = &BackupNodesRegistry{
|
||||
cache_utils.GetValkeyClient(),
|
||||
logger.GetLogger(),
|
||||
cache_utils.DefaultCacheTimeout,
|
||||
cache_utils.NewPubSubManager(),
|
||||
cache_utils.NewPubSubManager(),
|
||||
var backupCleaner = &BackupCleaner{
|
||||
backupRepository: backupRepository,
|
||||
storageService: storages.GetStorageService(),
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
logger: logger.GetLogger(),
|
||||
backupRemoveListeners: []backups_core.BackupRemoveListener{},
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
|
||||
var backupNodesRegistry = &BackupNodesRegistry{
|
||||
client: cache_utils.GetValkeyClient(),
|
||||
logger: logger.GetLogger(),
|
||||
timeout: cache_utils.DefaultCacheTimeout,
|
||||
pubsubBackups: cache_utils.NewPubSubManager(),
|
||||
pubsubCompletions: cache_utils.NewPubSubManager(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
|
||||
func getNodeID() uuid.UUID {
|
||||
nodeIDStr := config.GetEnv().NodeID
|
||||
nodeID, err := uuid.Parse(nodeIDStr)
|
||||
if err != nil {
|
||||
logger.GetLogger().Error("Failed to parse node ID from config", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
return nodeID
|
||||
return uuid.New()
|
||||
}
|
||||
|
||||
var backuperNode = &BackuperNode{
|
||||
databases.GetDatabaseService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
notifiers.GetNotifierService(),
|
||||
backupCancelManager,
|
||||
nodesRegistry,
|
||||
logger.GetLogger(),
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
getNodeID(),
|
||||
time.Time{},
|
||||
databaseService: databases.GetDatabaseService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
workspaceService: workspaces_services.GetWorkspaceService(),
|
||||
backupRepository: backupRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
notificationSender: notifiers.GetNotifierService(),
|
||||
backupCancelManager: taskCancelManager,
|
||||
backupNodesRegistry: backupNodesRegistry,
|
||||
logger: logger.GetLogger(),
|
||||
createBackupUseCase: usecases.GetCreateBackupUsecase(),
|
||||
nodeID: getNodeID(),
|
||||
lastHeartbeat: time.Time{},
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
|
||||
var backupsScheduler = &BackupsScheduler{
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
backupCancelManager,
|
||||
nodesRegistry,
|
||||
time.Now().UTC(),
|
||||
logger.GetLogger(),
|
||||
make(map[uuid.UUID]BackupToNodeRelation),
|
||||
backuperNode,
|
||||
backupRepository: backupRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
taskCancelManager: taskCancelManager,
|
||||
backupNodesRegistry: backupNodesRegistry,
|
||||
lastBackupTime: time.Now().UTC(),
|
||||
logger: logger.GetLogger(),
|
||||
backupToNodeRelations: make(map[uuid.UUID]BackupToNodeRelation),
|
||||
backuperNode: backuperNode,
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
|
||||
func GetBackupsScheduler() *BackupsScheduler {
|
||||
@@ -75,3 +87,11 @@ func GetBackupsScheduler() *BackupsScheduler {
|
||||
func GetBackuperNode() *BackuperNode {
|
||||
return backuperNode
|
||||
}
|
||||
|
||||
func GetBackupNodesRegistry() *BackupNodesRegistry {
|
||||
return backupNodesRegistry
|
||||
}
|
||||
|
||||
func GetBackupCleaner() *BackupCleaner {
|
||||
return backupCleaner
|
||||
}
|
||||
|
||||
@@ -6,6 +6,11 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type BackupToNodeRelation struct {
|
||||
NodeID uuid.UUID `json:"nodeId"`
|
||||
BackupsIDs []uuid.UUID `json:"backupsIds"`
|
||||
}
|
||||
|
||||
type BackupNode struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
ThroughputMBs int `json:"throughputMBs"`
|
||||
@@ -18,17 +23,12 @@ type BackupNodeStats struct {
|
||||
}
|
||||
|
||||
type BackupSubmitMessage struct {
|
||||
NodeID string `json:"nodeId"`
|
||||
BackupID string `json:"backupId"`
|
||||
IsCallNotifier bool `json:"isCallNotifier"`
|
||||
NodeID uuid.UUID `json:"nodeId"`
|
||||
BackupID uuid.UUID `json:"backupId"`
|
||||
IsCallNotifier bool `json:"isCallNotifier"`
|
||||
}
|
||||
|
||||
type BackupCompletionMessage struct {
|
||||
NodeID string `json:"nodeId"`
|
||||
BackupID string `json:"backupId"`
|
||||
}
|
||||
|
||||
type BackupToNodeRelation struct {
|
||||
NodeID uuid.UUID `json:"nodeId"`
|
||||
BackupsIDs []uuid.UUID `json:"backupsIds"`
|
||||
NodeID uuid.UUID `json:"nodeId"`
|
||||
BackupID uuid.UUID `json:"backupId"`
|
||||
}
|
||||
|
||||
@@ -1,8 +1,18 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
common "databasus-backend/internal/features/backups/backups/common"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
@@ -17,3 +27,168 @@ func (m *MockNotificationSender) SendNotification(
|
||||
) {
|
||||
m.Called(notifier, title, message)
|
||||
}
|
||||
|
||||
type CreateFailedBackupUsecase struct{}
|
||||
|
||||
func (uc *CreateFailedBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*common.BackupMetadata, error) {
|
||||
backupProgressListener(10)
|
||||
return nil, errors.New("backup failed")
|
||||
}
|
||||
|
||||
type CreateSuccessBackupUsecase struct{}
|
||||
|
||||
func (uc *CreateSuccessBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*common.BackupMetadata, error) {
|
||||
backupProgressListener(10)
|
||||
return &common.BackupMetadata{
|
||||
EncryptionSalt: nil,
|
||||
EncryptionIV: nil,
|
||||
Encryption: backups_config.BackupEncryptionNone,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateLargeBackupUsecase simulates a large backup (10000 MB)
|
||||
type CreateLargeBackupUsecase struct{}
|
||||
|
||||
func (uc *CreateLargeBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*common.BackupMetadata, error) {
|
||||
backupProgressListener(10000)
|
||||
return &common.BackupMetadata{
|
||||
EncryptionSalt: nil,
|
||||
EncryptionIV: nil,
|
||||
Encryption: backups_config.BackupEncryptionNone,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateProgressiveBackupUsecase simulates progressive size updates that exceed limit
|
||||
type CreateProgressiveBackupUsecase struct{}
|
||||
|
||||
func (uc *CreateProgressiveBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*common.BackupMetadata, error) {
|
||||
// Simulate progressive backup that grows beyond limit
|
||||
backupProgressListener(1)
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
backupProgressListener(3)
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
backupProgressListener(5)
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
backupProgressListener(10) // This exceeds the 5 MB limit
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
// Should not reach here due to cancellation
|
||||
return &common.BackupMetadata{
|
||||
EncryptionSalt: nil,
|
||||
EncryptionIV: nil,
|
||||
Encryption: backups_config.BackupEncryptionNone,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateMediumBackupUsecase simulates a 50 MB backup
|
||||
type CreateMediumBackupUsecase struct{}
|
||||
|
||||
func (uc *CreateMediumBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*common.BackupMetadata, error) {
|
||||
backupProgressListener(50)
|
||||
return &common.BackupMetadata{
|
||||
EncryptionSalt: nil,
|
||||
EncryptionIV: nil,
|
||||
Encryption: backups_config.BackupEncryptionNone,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// MockTrackingBackupUsecase tracks backup use case calls for testing parallel execution
|
||||
type MockTrackingBackupUsecase struct {
|
||||
callCount atomic.Int32
|
||||
calledBackupIDs chan uuid.UUID
|
||||
}
|
||||
|
||||
func NewMockTrackingBackupUsecase() *MockTrackingBackupUsecase {
|
||||
return &MockTrackingBackupUsecase{
|
||||
calledBackupIDs: make(chan uuid.UUID, 10),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockTrackingBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*common.BackupMetadata, error) {
|
||||
m.callCount.Add(1)
|
||||
|
||||
// Send backup ID to channel (non-blocking)
|
||||
select {
|
||||
case m.calledBackupIDs <- backupID:
|
||||
default:
|
||||
}
|
||||
|
||||
// Simulate backup work
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
backupProgressListener(10)
|
||||
|
||||
return &common.BackupMetadata{
|
||||
EncryptionSalt: nil,
|
||||
EncryptionIV: nil,
|
||||
Encryption: backups_config.BackupEncryptionNone,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *MockTrackingBackupUsecase) GetCallCount() int32 {
|
||||
return m.callCount.Load()
|
||||
}
|
||||
|
||||
func (m *MockTrackingBackupUsecase) GetCalledBackupIDs() []uuid.UUID {
|
||||
ids := []uuid.UUID{}
|
||||
for {
|
||||
select {
|
||||
case id := <-m.calledBackupIDs:
|
||||
ids = append(ids, id)
|
||||
default:
|
||||
return ids
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
@@ -15,25 +17,70 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
nodeInfoKeyPrefix = "node:"
|
||||
nodeInfoKeyPrefix = "backup:node:"
|
||||
nodeInfoKeySuffix = ":info"
|
||||
nodeActiveBackupsPrefix = "node:"
|
||||
nodeActiveBackupsPrefix = "backup:node:"
|
||||
nodeActiveBackupsSuffix = ":active_backups"
|
||||
backupSubmitChannel = "backup:submit"
|
||||
backupCompletionChannel = "backup:completion"
|
||||
|
||||
deadNodeThreshold = 2 * time.Minute
|
||||
cleanupTickerInterval = 1 * time.Second
|
||||
)
|
||||
|
||||
// BackupNodesRegistry helps to sync backups scheduler and backup nodes.
|
||||
//
|
||||
// Features:
|
||||
// - Track node availability and load level
|
||||
// - Assign from scheduler to node backups needed to be processed
|
||||
// - Notify scheduler from node about backup completion
|
||||
//
|
||||
// Important things to remember:
|
||||
// - Nodes without heartbeat for more than 2 minutes are not included
|
||||
// in available nodes list and stats
|
||||
//
|
||||
// Cleanup dead nodes performed on 2 levels:
|
||||
// - List and stats functions do not return dead nodes
|
||||
// - Periodically dead nodes are cleaned up in cache (to not
|
||||
// accumulate too many dead nodes in cache)
|
||||
type BackupNodesRegistry struct {
|
||||
client valkey.Client
|
||||
logger *slog.Logger
|
||||
timeout time.Duration
|
||||
pubsubBackups *cache_utils.PubSubManager
|
||||
pubsubCompletions *cache_utils.PubSubManager
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) Run(ctx context.Context) {
|
||||
wasAlreadyRun := r.hasRun.Load()
|
||||
|
||||
r.runOnce.Do(func() {
|
||||
r.hasRun.Store(true)
|
||||
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(cleanupTickerInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", r))
|
||||
}
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) GetAvailableNodes() ([]BackupNode, error) {
|
||||
@@ -76,13 +123,30 @@ func (r *BackupNodesRegistry) GetAvailableNodes() ([]BackupNode, error) {
|
||||
return nil, fmt.Errorf("failed to pipeline get node keys: %w", err)
|
||||
}
|
||||
|
||||
threshold := time.Now().UTC().Add(-deadNodeThreshold)
|
||||
var nodes []BackupNode
|
||||
|
||||
for key, data := range keyDataMap {
|
||||
// Skip if the key doesn't exist (data is empty)
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var node BackupNode
|
||||
if err := json.Unmarshal(data, &node); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal node data", "key", key, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip nodes with zero/uninitialized heartbeat
|
||||
if node.LastHeartbeat.IsZero() {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.LastHeartbeat.Before(threshold) {
|
||||
continue
|
||||
}
|
||||
|
||||
nodes = append(nodes, node)
|
||||
}
|
||||
|
||||
@@ -129,18 +193,54 @@ func (r *BackupNodesRegistry) GetBackupNodesStats() ([]BackupNodeStats, error) {
|
||||
return nil, fmt.Errorf("failed to pipeline get active backups keys: %w", err)
|
||||
}
|
||||
|
||||
var stats []BackupNodeStats
|
||||
for key, data := range keyDataMap {
|
||||
var nodeInfoKeys []string
|
||||
nodeIDToStatsKey := make(map[string]string)
|
||||
for key := range keyDataMap {
|
||||
nodeID := r.extractNodeIDFromKey(key, nodeActiveBackupsPrefix, nodeActiveBackupsSuffix)
|
||||
nodeIDStr := nodeID.String()
|
||||
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeIDStr, nodeInfoKeySuffix)
|
||||
nodeInfoKeys = append(nodeInfoKeys, infoKey)
|
||||
nodeIDToStatsKey[infoKey] = key
|
||||
}
|
||||
|
||||
count, err := r.parseIntFromBytes(data)
|
||||
nodeInfoMap, err := r.pipelineGetKeys(nodeInfoKeys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pipeline get node info keys: %w", err)
|
||||
}
|
||||
|
||||
threshold := time.Now().UTC().Add(-deadNodeThreshold)
|
||||
var stats []BackupNodeStats
|
||||
for infoKey, nodeData := range nodeInfoMap {
|
||||
// Skip if the info key doesn't exist (nodeData is empty)
|
||||
if len(nodeData) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var node BackupNode
|
||||
if err := json.Unmarshal(nodeData, &node); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal node data", "key", infoKey, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip nodes with zero/uninitialized heartbeat
|
||||
if node.LastHeartbeat.IsZero() {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.LastHeartbeat.Before(threshold) {
|
||||
continue
|
||||
}
|
||||
|
||||
statsKey := nodeIDToStatsKey[infoKey]
|
||||
tasksData := keyDataMap[statsKey]
|
||||
count, err := r.parseIntFromBytes(tasksData)
|
||||
if err != nil {
|
||||
r.logger.Warn("Failed to parse active backups count", "key", key, "error", err)
|
||||
r.logger.Warn("Failed to parse active backups count", "key", statsKey, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
stat := BackupNodeStats{
|
||||
ID: nodeID,
|
||||
ID: node.ID,
|
||||
ActiveBackups: int(count),
|
||||
}
|
||||
stats = append(stats, stat)
|
||||
@@ -149,11 +249,11 @@ func (r *BackupNodesRegistry) GetBackupNodesStats() ([]BackupNodeStats, error) {
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) IncrementBackupsInProgress(nodeID string) error {
|
||||
func (r *BackupNodesRegistry) IncrementBackupsInProgress(nodeID uuid.UUID) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID, nodeActiveBackupsSuffix)
|
||||
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID.String(), nodeActiveBackupsSuffix)
|
||||
result := r.client.Do(ctx, r.client.B().Incr().Key(key).Build())
|
||||
|
||||
if result.Error() != nil {
|
||||
@@ -167,11 +267,11 @@ func (r *BackupNodesRegistry) IncrementBackupsInProgress(nodeID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) DecrementBackupsInProgress(nodeID string) error {
|
||||
func (r *BackupNodesRegistry) DecrementBackupsInProgress(nodeID uuid.UUID) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID, nodeActiveBackupsSuffix)
|
||||
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID.String(), nodeActiveBackupsSuffix)
|
||||
result := r.client.Do(ctx, r.client.B().Decr().Key(key).Build())
|
||||
|
||||
if result.Error() != nil {
|
||||
@@ -198,6 +298,10 @@ func (r *BackupNodesRegistry) DecrementBackupsInProgress(nodeID string) error {
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) HearthbeatNodeInRegistry(now time.Time, backupNode BackupNode) error {
|
||||
if now.IsZero() {
|
||||
return fmt.Errorf("cannot register node with zero heartbeat timestamp")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
@@ -247,7 +351,7 @@ func (r *BackupNodesRegistry) UnregisterNodeFromRegistry(backupNode BackupNode)
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) AssignBackupToNode(
|
||||
targetNodeID string,
|
||||
targetNodeID uuid.UUID,
|
||||
backupID uuid.UUID,
|
||||
isCallNotifier bool,
|
||||
) error {
|
||||
@@ -255,7 +359,7 @@ func (r *BackupNodesRegistry) AssignBackupToNode(
|
||||
|
||||
message := BackupSubmitMessage{
|
||||
NodeID: targetNodeID,
|
||||
BackupID: backupID.String(),
|
||||
BackupID: backupID,
|
||||
IsCallNotifier: isCallNotifier,
|
||||
}
|
||||
|
||||
@@ -273,7 +377,7 @@ func (r *BackupNodesRegistry) AssignBackupToNode(
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) SubscribeNodeForBackupsAssignment(
|
||||
nodeID string,
|
||||
nodeID uuid.UUID,
|
||||
handler func(backupID uuid.UUID, isCallNotifier bool),
|
||||
) error {
|
||||
ctx := context.Background()
|
||||
@@ -289,19 +393,7 @@ func (r *BackupNodesRegistry) SubscribeNodeForBackupsAssignment(
|
||||
return
|
||||
}
|
||||
|
||||
backupID, err := uuid.Parse(msg.BackupID)
|
||||
if err != nil {
|
||||
r.logger.Warn(
|
||||
"Failed to parse backup ID from message",
|
||||
"backupId",
|
||||
msg.BackupID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
handler(backupID, msg.IsCallNotifier)
|
||||
handler(msg.BackupID, msg.IsCallNotifier)
|
||||
}
|
||||
|
||||
err := r.pubsubBackups.Subscribe(ctx, backupSubmitChannel, wrappedHandler)
|
||||
@@ -323,12 +415,12 @@ func (r *BackupNodesRegistry) UnsubscribeNodeForBackupsAssignments() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) PublishBackupCompletion(nodeID string, backupID uuid.UUID) error {
|
||||
func (r *BackupNodesRegistry) PublishBackupCompletion(nodeID uuid.UUID, backupID uuid.UUID) error {
|
||||
ctx := context.Background()
|
||||
|
||||
message := BackupCompletionMessage{
|
||||
NodeID: nodeID,
|
||||
BackupID: backupID.String(),
|
||||
BackupID: backupID,
|
||||
}
|
||||
|
||||
messageJSON, err := json.Marshal(message)
|
||||
@@ -345,7 +437,7 @@ func (r *BackupNodesRegistry) PublishBackupCompletion(nodeID string, backupID uu
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) SubscribeForBackupsCompletions(
|
||||
handler func(nodeID string, backupID uuid.UUID),
|
||||
handler func(nodeID uuid.UUID, backupID uuid.UUID),
|
||||
) error {
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -356,19 +448,7 @@ func (r *BackupNodesRegistry) SubscribeForBackupsCompletions(
|
||||
return
|
||||
}
|
||||
|
||||
backupID, err := uuid.Parse(msg.BackupID)
|
||||
if err != nil {
|
||||
r.logger.Warn(
|
||||
"Failed to parse backup ID from completion message",
|
||||
"backupId",
|
||||
msg.BackupID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
handler(msg.NodeID, backupID)
|
||||
handler(msg.NodeID, msg.BackupID)
|
||||
}
|
||||
|
||||
err := r.pubsubCompletions.Subscribe(ctx, backupCompletionChannel, wrappedHandler)
|
||||
@@ -446,3 +526,108 @@ func (r *BackupNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) cleanupDeadNodes() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
var allKeys []string
|
||||
cursor := uint64(0)
|
||||
pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix
|
||||
|
||||
for {
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf("failed to scan node keys: %w", result.Error())
|
||||
}
|
||||
|
||||
scanResult, err := result.AsScanEntry()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse scan result: %w", err)
|
||||
}
|
||||
|
||||
allKeys = append(allKeys, scanResult.Elements...)
|
||||
|
||||
cursor = scanResult.Cursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(allKeys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
keyDataMap, err := r.pipelineGetKeys(allKeys)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to pipeline get node keys: %w", err)
|
||||
}
|
||||
|
||||
threshold := time.Now().UTC().Add(-deadNodeThreshold)
|
||||
var deadNodeKeys []string
|
||||
|
||||
for key, data := range keyDataMap {
|
||||
// Skip if the key doesn't exist (data is empty)
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var node BackupNode
|
||||
if err := json.Unmarshal(data, &node); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal node data during cleanup", "key", key, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip nodes with zero/uninitialized heartbeat
|
||||
if node.LastHeartbeat.IsZero() {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.LastHeartbeat.Before(threshold) {
|
||||
nodeID := node.ID.String()
|
||||
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeID, nodeInfoKeySuffix)
|
||||
statsKey := fmt.Sprintf(
|
||||
"%s%s%s",
|
||||
nodeActiveBackupsPrefix,
|
||||
nodeID,
|
||||
nodeActiveBackupsSuffix,
|
||||
)
|
||||
|
||||
deadNodeKeys = append(deadNodeKeys, infoKey, statsKey)
|
||||
r.logger.Info(
|
||||
"Marking node for cleanup",
|
||||
"nodeID", nodeID,
|
||||
"lastHeartbeat", node.LastHeartbeat,
|
||||
"threshold", threshold,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if len(deadNodeKeys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
delCtx, delCancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer delCancel()
|
||||
|
||||
result := r.client.Do(
|
||||
delCtx,
|
||||
r.client.B().Del().Key(deadNodeKeys...).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf("failed to delete dead node keys: %w", result.Error())
|
||||
}
|
||||
|
||||
deletedCount, err := result.AsInt64()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse deleted count: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Cleaned up dead nodes", "deletedKeysCount", deletedCount)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,10 @@ package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -36,7 +40,7 @@ func Test_UnregisterNodeFromRegistry_RemovesNodeAndCounter(t *testing.T) {
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node.ID.String())
|
||||
err = registry.IncrementBackupsInProgress(node.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.UnregisterNodeFromRegistry(node)
|
||||
@@ -100,7 +104,7 @@ func Test_IncrementBackupsInProgress_IncrementsCounter(t *testing.T) {
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node.ID.String())
|
||||
err = registry.IncrementBackupsInProgress(node.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
@@ -109,7 +113,7 @@ func Test_IncrementBackupsInProgress_IncrementsCounter(t *testing.T) {
|
||||
assert.Equal(t, node.ID, stats[0].ID)
|
||||
assert.Equal(t, 1, stats[0].ActiveBackups)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node.ID.String())
|
||||
err = registry.IncrementBackupsInProgress(node.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err = registry.GetBackupNodesStats()
|
||||
@@ -127,25 +131,25 @@ func Test_DecrementBackupsInProgress_DecrementsCounter(t *testing.T) {
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node.ID.String())
|
||||
err = registry.IncrementBackupsInProgress(node.ID)
|
||||
assert.NoError(t, err)
|
||||
err = registry.IncrementBackupsInProgress(node.ID.String())
|
||||
err = registry.IncrementBackupsInProgress(node.ID)
|
||||
assert.NoError(t, err)
|
||||
err = registry.IncrementBackupsInProgress(node.ID.String())
|
||||
err = registry.IncrementBackupsInProgress(node.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, stats[0].ActiveBackups)
|
||||
|
||||
err = registry.DecrementBackupsInProgress(node.ID.String())
|
||||
err = registry.DecrementBackupsInProgress(node.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err = registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, stats[0].ActiveBackups)
|
||||
|
||||
err = registry.DecrementBackupsInProgress(node.ID.String())
|
||||
err = registry.DecrementBackupsInProgress(node.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err = registry.GetBackupNodesStats()
|
||||
@@ -162,7 +166,7 @@ func Test_DecrementBackupsInProgress_WhenNegative_ResetsToZero(t *testing.T) {
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.DecrementBackupsInProgress(node.ID.String())
|
||||
err = registry.DecrementBackupsInProgress(node.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
@@ -188,19 +192,19 @@ func Test_GetBackupNodesStats_ReturnsStatsForAllNodes(t *testing.T) {
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node1.ID.String())
|
||||
err = registry.IncrementBackupsInProgress(node1.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node2.ID.String())
|
||||
err = registry.IncrementBackupsInProgress(node2.ID)
|
||||
assert.NoError(t, err)
|
||||
err = registry.IncrementBackupsInProgress(node2.ID.String())
|
||||
err = registry.IncrementBackupsInProgress(node2.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node3.ID.String())
|
||||
err = registry.IncrementBackupsInProgress(node3.ID)
|
||||
assert.NoError(t, err)
|
||||
err = registry.IncrementBackupsInProgress(node3.ID.String())
|
||||
err = registry.IncrementBackupsInProgress(node3.ID)
|
||||
assert.NoError(t, err)
|
||||
err = registry.IncrementBackupsInProgress(node3.ID.String())
|
||||
err = registry.IncrementBackupsInProgress(node3.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
@@ -274,12 +278,12 @@ func Test_BackupCounters_TrackedSeparatelyPerNode(t *testing.T) {
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node1.ID.String())
|
||||
err = registry.IncrementBackupsInProgress(node1.ID)
|
||||
assert.NoError(t, err)
|
||||
err = registry.IncrementBackupsInProgress(node1.ID.String())
|
||||
err = registry.IncrementBackupsInProgress(node1.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node2.ID.String())
|
||||
err = registry.IncrementBackupsInProgress(node2.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
@@ -294,7 +298,7 @@ func Test_BackupCounters_TrackedSeparatelyPerNode(t *testing.T) {
|
||||
assert.Equal(t, 2, statsMap[node1.ID])
|
||||
assert.Equal(t, 1, statsMap[node2.ID])
|
||||
|
||||
err = registry.DecrementBackupsInProgress(node1.ID.String())
|
||||
err = registry.DecrementBackupsInProgress(node1.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err = registry.GetBackupNodesStats()
|
||||
@@ -366,13 +370,239 @@ func Test_HearthbeatNodeInRegistry_UpdatesLastHeartbeat(t *testing.T) {
|
||||
assert.True(t, nodes[0].LastHeartbeat.After(originalHeartbeat))
|
||||
}
|
||||
|
||||
func Test_HearthbeatNodeInRegistry_RejectsZeroTimestamp(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Time{}, node)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "zero heartbeat timestamp")
|
||||
|
||||
nodes, err := registry.GetAvailableNodes()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, nodes, 0)
|
||||
}
|
||||
|
||||
func Test_GetAvailableNodes_ExcludesStaleNodesFromCache(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node1 := createTestBackupNode()
|
||||
node2 := createTestBackupNode()
|
||||
node3 := createTestBackupNode()
|
||||
defer cleanupTestNode(registry, node1)
|
||||
defer cleanupTestNode(registry, node2)
|
||||
defer cleanupTestNode(registry, node3)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
|
||||
assert.NoError(t, err)
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
|
||||
assert.NoError(t, err)
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
defer cancel()
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
|
||||
result := registry.client.Do(ctx, registry.client.B().Get().Key(key).Build())
|
||||
assert.NoError(t, result.Error())
|
||||
|
||||
data, err := result.AsBytes()
|
||||
assert.NoError(t, err)
|
||||
|
||||
var node BackupNode
|
||||
err = json.Unmarshal(data, &node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
node.LastHeartbeat = time.Now().UTC().Add(-3 * time.Minute)
|
||||
modifiedData, err := json.Marshal(node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
defer setCancel()
|
||||
setResult := registry.client.Do(
|
||||
setCtx,
|
||||
registry.client.B().Set().Key(key).Value(string(modifiedData)).Build(),
|
||||
)
|
||||
assert.NoError(t, setResult.Error())
|
||||
|
||||
nodes, err := registry.GetAvailableNodes()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, nodes, 2)
|
||||
|
||||
nodeIDs := make(map[uuid.UUID]bool)
|
||||
for _, n := range nodes {
|
||||
nodeIDs[n.ID] = true
|
||||
}
|
||||
assert.True(t, nodeIDs[node1.ID])
|
||||
assert.False(t, nodeIDs[node2.ID])
|
||||
assert.True(t, nodeIDs[node3.ID])
|
||||
}
|
||||
|
||||
func Test_GetBackupNodesStats_ExcludesStaleNodesFromCache(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node1 := createTestBackupNode()
|
||||
node2 := createTestBackupNode()
|
||||
node3 := createTestBackupNode()
|
||||
defer cleanupTestNode(registry, node1)
|
||||
defer cleanupTestNode(registry, node2)
|
||||
defer cleanupTestNode(registry, node3)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
|
||||
assert.NoError(t, err)
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
|
||||
assert.NoError(t, err)
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node1.ID)
|
||||
assert.NoError(t, err)
|
||||
err = registry.IncrementBackupsInProgress(node2.ID)
|
||||
assert.NoError(t, err)
|
||||
err = registry.IncrementBackupsInProgress(node3.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
defer cancel()
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
|
||||
result := registry.client.Do(ctx, registry.client.B().Get().Key(key).Build())
|
||||
assert.NoError(t, result.Error())
|
||||
|
||||
data, err := result.AsBytes()
|
||||
assert.NoError(t, err)
|
||||
|
||||
var node BackupNode
|
||||
err = json.Unmarshal(data, &node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
node.LastHeartbeat = time.Now().UTC().Add(-3 * time.Minute)
|
||||
modifiedData, err := json.Marshal(node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
defer setCancel()
|
||||
setResult := registry.client.Do(
|
||||
setCtx,
|
||||
registry.client.B().Set().Key(key).Value(string(modifiedData)).Build(),
|
||||
)
|
||||
assert.NoError(t, setResult.Error())
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, stats, 2)
|
||||
|
||||
statsMap := make(map[uuid.UUID]int)
|
||||
for _, stat := range stats {
|
||||
statsMap[stat.ID] = stat.ActiveBackups
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, statsMap[node1.ID])
|
||||
_, hasNode2 := statsMap[node2.ID]
|
||||
assert.False(t, hasNode2)
|
||||
assert.Equal(t, 1, statsMap[node3.ID])
|
||||
}
|
||||
|
||||
func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
registry := createTestRegistry()
|
||||
node1 := createTestBackupNode()
|
||||
node2 := createTestBackupNode()
|
||||
defer cleanupTestNode(registry, node1)
|
||||
defer cleanupTestNode(registry, node2)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
|
||||
assert.NoError(t, err)
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node1.ID)
|
||||
assert.NoError(t, err)
|
||||
err = registry.IncrementBackupsInProgress(node2.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
defer cancel()
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
|
||||
result := registry.client.Do(ctx, registry.client.B().Get().Key(key).Build())
|
||||
assert.NoError(t, result.Error())
|
||||
|
||||
data, err := result.AsBytes()
|
||||
assert.NoError(t, err)
|
||||
|
||||
var node BackupNode
|
||||
err = json.Unmarshal(data, &node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
node.LastHeartbeat = time.Now().UTC().Add(-3 * time.Minute)
|
||||
modifiedData, err := json.Marshal(node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
defer setCancel()
|
||||
setResult := registry.client.Do(
|
||||
setCtx,
|
||||
registry.client.B().Set().Key(key).Value(string(modifiedData)).Build(),
|
||||
)
|
||||
assert.NoError(t, setResult.Error())
|
||||
|
||||
err = registry.cleanupDeadNodes()
|
||||
assert.NoError(t, err)
|
||||
|
||||
checkCtx, checkCancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
defer checkCancel()
|
||||
|
||||
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
|
||||
infoResult := registry.client.Do(checkCtx, registry.client.B().Get().Key(infoKey).Build())
|
||||
assert.Error(t, infoResult.Error())
|
||||
|
||||
counterKey := fmt.Sprintf(
|
||||
"%s%s%s",
|
||||
nodeActiveBackupsPrefix,
|
||||
node2.ID.String(),
|
||||
nodeActiveBackupsSuffix,
|
||||
)
|
||||
counterCtx, counterCancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
defer counterCancel()
|
||||
counterResult := registry.client.Do(
|
||||
counterCtx,
|
||||
registry.client.B().Get().Key(counterKey).Build(),
|
||||
)
|
||||
assert.Error(t, counterResult.Error())
|
||||
|
||||
activeInfoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node1.ID.String(), nodeInfoKeySuffix)
|
||||
activeCtx, activeCancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
defer activeCancel()
|
||||
activeResult := registry.client.Do(
|
||||
activeCtx,
|
||||
registry.client.B().Get().Key(activeInfoKey).Build(),
|
||||
)
|
||||
assert.NoError(t, activeResult.Error())
|
||||
|
||||
nodes, err := registry.GetAvailableNodes()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, node1.ID, nodes[0].ID)
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, stats, 1)
|
||||
assert.Equal(t, node1.ID, stats[0].ID)
|
||||
}
|
||||
|
||||
func createTestRegistry() *BackupNodesRegistry {
|
||||
return &BackupNodesRegistry{
|
||||
cache_utils.GetValkeyClient(),
|
||||
logger.GetLogger(),
|
||||
cache_utils.DefaultCacheTimeout,
|
||||
cache_utils.NewPubSubManager(),
|
||||
cache_utils.NewPubSubManager(),
|
||||
client: cache_utils.GetValkeyClient(),
|
||||
logger: logger.GetLogger(),
|
||||
timeout: cache_utils.DefaultCacheTimeout,
|
||||
pubsubBackups: cache_utils.NewPubSubManager(),
|
||||
pubsubCompletions: cache_utils.NewPubSubManager(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -394,7 +624,7 @@ func Test_AssignBackupTonode_PublishesJsonMessageToChannel(t *testing.T) {
|
||||
node := createTestBackupNode()
|
||||
backupID := uuid.New()
|
||||
|
||||
err := registry.AssignBackupToNode(node.ID.String(), backupID, true)
|
||||
err := registry.AssignBackupToNode(node.ID, backupID, true)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -410,12 +640,12 @@ func Test_SubscribeNodeForBackupsAssignment_ReceivesSubmittedBackupsForMatchingN
|
||||
receivedBackupID <- id
|
||||
}
|
||||
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node.ID, handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.AssignBackupToNode(node.ID.String(), backupID, true)
|
||||
err = registry.AssignBackupToNode(node.ID, backupID, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -439,12 +669,12 @@ func Test_SubscribeNodeForBackupsAssignment_FiltersOutBackupsForDifferentNode(t
|
||||
receivedBackupID <- id
|
||||
}
|
||||
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node1.ID.String(), handler)
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node1.ID, handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.AssignBackupToNode(node2.ID.String(), backupID, false)
|
||||
err = registry.AssignBackupToNode(node2.ID, backupID, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -467,15 +697,15 @@ func Test_SubscribeNodeForBackupsAssignment_ParsesJsonAndBackupIdCorrectly(t *te
|
||||
receivedBackups <- id
|
||||
}
|
||||
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node.ID, handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.AssignBackupToNode(node.ID.String(), backupID1, true)
|
||||
err = registry.AssignBackupToNode(node.ID, backupID1, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.AssignBackupToNode(node.ID.String(), backupID2, false)
|
||||
err = registry.AssignBackupToNode(node.ID, backupID2, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
received1 := <-receivedBackups
|
||||
@@ -497,7 +727,7 @@ func Test_SubscribeNodeForBackupsAssignment_HandlesInvalidJson(t *testing.T) {
|
||||
receivedBackupID <- id
|
||||
}
|
||||
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node.ID, handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
@@ -525,12 +755,12 @@ func Test_UnsubscribeNodeForBackupsAssignments_StopsReceivingMessages(t *testing
|
||||
receivedBackupID <- id
|
||||
}
|
||||
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node.ID, handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.AssignBackupToNode(node.ID.String(), backupID1, true)
|
||||
err = registry.AssignBackupToNode(node.ID, backupID1, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
received := <-receivedBackupID
|
||||
@@ -541,7 +771,7 @@ func Test_UnsubscribeNodeForBackupsAssignments_StopsReceivingMessages(t *testing
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.AssignBackupToNode(node.ID.String(), backupID2, false)
|
||||
err = registry.AssignBackupToNode(node.ID, backupID2, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -559,10 +789,10 @@ func Test_SubscribeNodeForBackupsAssignment_WhenAlreadySubscribed_ReturnsError(t
|
||||
|
||||
handler := func(id uuid.UUID, isCallNotifier bool) {}
|
||||
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node.ID, handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
|
||||
err = registry.SubscribeNodeForBackupsAssignment(node.ID, handler)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "already subscribed")
|
||||
}
|
||||
@@ -593,25 +823,25 @@ func Test_MultipleNodes_EachReceivesOnlyTheirBackups(t *testing.T) {
|
||||
handler2 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups2 <- id }
|
||||
handler3 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups3 <- id }
|
||||
|
||||
err := registry1.SubscribeNodeForBackupsAssignment(node1.ID.String(), handler1)
|
||||
err := registry1.SubscribeNodeForBackupsAssignment(node1.ID, handler1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry2.SubscribeNodeForBackupsAssignment(node2.ID.String(), handler2)
|
||||
err = registry2.SubscribeNodeForBackupsAssignment(node2.ID, handler2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry3.SubscribeNodeForBackupsAssignment(node3.ID.String(), handler3)
|
||||
err = registry3.SubscribeNodeForBackupsAssignment(node3.ID, handler3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
submitRegistry := createTestRegistry()
|
||||
err = submitRegistry.AssignBackupToNode(node1.ID.String(), backupID1, true)
|
||||
err = submitRegistry.AssignBackupToNode(node1.ID, backupID1, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = submitRegistry.AssignBackupToNode(node2.ID.String(), backupID2, false)
|
||||
err = submitRegistry.AssignBackupToNode(node2.ID, backupID2, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = submitRegistry.AssignBackupToNode(node3.ID.String(), backupID3, true)
|
||||
err = submitRegistry.AssignBackupToNode(node3.ID, backupID3, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -660,7 +890,7 @@ func Test_PublishBackupCompletion_PublishesMessageToChannel(t *testing.T) {
|
||||
node := createTestBackupNode()
|
||||
backupID := uuid.New()
|
||||
|
||||
err := registry.PublishBackupCompletion(node.ID.String(), backupID)
|
||||
err := registry.PublishBackupCompletion(node.ID, backupID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -672,8 +902,8 @@ func Test_SubscribeForBackupsCompletions_ReceivesCompletedBackups(t *testing.T)
|
||||
defer registry.UnsubscribeForBackupsCompletions()
|
||||
|
||||
receivedBackupID := make(chan uuid.UUID, 1)
|
||||
receivedNodeID := make(chan string, 1)
|
||||
handler := func(nodeID string, backupID uuid.UUID) {
|
||||
receivedNodeID := make(chan uuid.UUID, 1)
|
||||
handler := func(nodeID uuid.UUID, backupID uuid.UUID) {
|
||||
receivedNodeID <- nodeID
|
||||
receivedBackupID <- backupID
|
||||
}
|
||||
@@ -683,12 +913,12 @@ func Test_SubscribeForBackupsCompletions_ReceivesCompletedBackups(t *testing.T)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.PublishBackupCompletion(node.ID.String(), backupID)
|
||||
err = registry.PublishBackupCompletion(node.ID, backupID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case receivedNode := <-receivedNodeID:
|
||||
assert.Equal(t, node.ID.String(), receivedNode)
|
||||
assert.Equal(t, node.ID, receivedNode)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Timeout waiting for node ID")
|
||||
}
|
||||
@@ -710,7 +940,7 @@ func Test_SubscribeForBackupsCompletions_ParsesJsonCorrectly(t *testing.T) {
|
||||
defer registry.UnsubscribeForBackupsCompletions()
|
||||
|
||||
receivedBackups := make(chan uuid.UUID, 2)
|
||||
handler := func(nodeID string, backupID uuid.UUID) {
|
||||
handler := func(nodeID uuid.UUID, backupID uuid.UUID) {
|
||||
receivedBackups <- backupID
|
||||
}
|
||||
|
||||
@@ -719,10 +949,10 @@ func Test_SubscribeForBackupsCompletions_ParsesJsonCorrectly(t *testing.T) {
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.PublishBackupCompletion(node.ID.String(), backupID1)
|
||||
err = registry.PublishBackupCompletion(node.ID, backupID1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.PublishBackupCompletion(node.ID.String(), backupID2)
|
||||
err = registry.PublishBackupCompletion(node.ID, backupID2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
received1 := <-receivedBackups
|
||||
@@ -739,7 +969,7 @@ func Test_SubscribeForBackupsCompletions_HandlesInvalidJson(t *testing.T) {
|
||||
defer registry.UnsubscribeForBackupsCompletions()
|
||||
|
||||
receivedBackupID := make(chan uuid.UUID, 1)
|
||||
handler := func(nodeID string, backupID uuid.UUID) {
|
||||
handler := func(nodeID uuid.UUID, backupID uuid.UUID) {
|
||||
receivedBackupID <- backupID
|
||||
}
|
||||
|
||||
@@ -767,7 +997,7 @@ func Test_UnsubscribeForBackupsCompletions_StopsReceivingMessages(t *testing.T)
|
||||
backupID2 := uuid.New()
|
||||
|
||||
receivedBackupID := make(chan uuid.UUID, 2)
|
||||
handler := func(nodeID string, backupID uuid.UUID) {
|
||||
handler := func(nodeID uuid.UUID, backupID uuid.UUID) {
|
||||
receivedBackupID <- backupID
|
||||
}
|
||||
|
||||
@@ -776,7 +1006,7 @@ func Test_UnsubscribeForBackupsCompletions_StopsReceivingMessages(t *testing.T)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.PublishBackupCompletion(node.ID.String(), backupID1)
|
||||
err = registry.PublishBackupCompletion(node.ID, backupID1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
received := <-receivedBackupID
|
||||
@@ -787,7 +1017,7 @@ func Test_UnsubscribeForBackupsCompletions_StopsReceivingMessages(t *testing.T)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.PublishBackupCompletion(node.ID.String(), backupID2)
|
||||
err = registry.PublishBackupCompletion(node.ID, backupID2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -802,7 +1032,7 @@ func Test_SubscribeForBackupsCompletions_WhenAlreadySubscribed_ReturnsError(t *t
|
||||
registry := createTestRegistry()
|
||||
defer registry.UnsubscribeForBackupsCompletions()
|
||||
|
||||
handler := func(nodeID string, backupID uuid.UUID) {}
|
||||
handler := func(nodeID uuid.UUID, backupID uuid.UUID) {}
|
||||
|
||||
err := registry.SubscribeForBackupsCompletions(handler)
|
||||
assert.NoError(t, err)
|
||||
@@ -834,9 +1064,9 @@ func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) {
|
||||
receivedBackups2 := make(chan uuid.UUID, 3)
|
||||
receivedBackups3 := make(chan uuid.UUID, 3)
|
||||
|
||||
handler1 := func(nodeID string, backupID uuid.UUID) { receivedBackups1 <- backupID }
|
||||
handler2 := func(nodeID string, backupID uuid.UUID) { receivedBackups2 <- backupID }
|
||||
handler3 := func(nodeID string, backupID uuid.UUID) { receivedBackups3 <- backupID }
|
||||
handler1 := func(nodeID uuid.UUID, backupID uuid.UUID) { receivedBackups1 <- backupID }
|
||||
handler2 := func(nodeID uuid.UUID, backupID uuid.UUID) { receivedBackups2 <- backupID }
|
||||
handler3 := func(nodeID uuid.UUID, backupID uuid.UUID) { receivedBackups3 <- backupID }
|
||||
|
||||
err := registry1.SubscribeForBackupsCompletions(handler1)
|
||||
assert.NoError(t, err)
|
||||
@@ -850,13 +1080,13 @@ func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
publishRegistry := createTestRegistry()
|
||||
err = publishRegistry.PublishBackupCompletion(node1.ID.String(), backupID1)
|
||||
err = publishRegistry.PublishBackupCompletion(node1.ID, backupID1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = publishRegistry.PublishBackupCompletion(node2.ID.String(), backupID2)
|
||||
err = publishRegistry.PublishBackupCompletion(node2.ID, backupID2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = publishRegistry.PublishBackupCompletion(node3.ID.String(), backupID3)
|
||||
err = publishRegistry.PublishBackupCompletion(node3.ID, backupID3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
receivedAll1 := []uuid.UUID{}
|
||||
|
||||
@@ -2,89 +2,105 @@ package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"databasus-backend/internal/config"
|
||||
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/storages"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/period"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
)
|
||||
|
||||
const (
|
||||
schedulerStartupDelay = 1 * time.Minute
|
||||
schedulerTickerInterval = 1 * time.Minute
|
||||
schedulerHealthcheckThreshold = 5 * time.Minute
|
||||
)
|
||||
|
||||
type BackupsScheduler struct {
|
||||
backupRepository *backups_core.BackupRepository
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
storageService *storages.StorageService
|
||||
backupCancelManager *backups_cancellation.BackupCancelManager
|
||||
nodesRegistry *BackupNodesRegistry
|
||||
taskCancelManager *task_cancellation.TaskCancelManager
|
||||
backupNodesRegistry *BackupNodesRegistry
|
||||
|
||||
lastBackupTime time.Time
|
||||
logger *slog.Logger
|
||||
|
||||
backupToNodeRelations map[uuid.UUID]BackupToNodeRelation
|
||||
backuperNode *BackuperNode
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) Run(ctx context.Context) {
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
wasAlreadyRun := s.hasRun.Load()
|
||||
|
||||
if config.GetEnv().IsManyNodesMode {
|
||||
// wait other nodes to start
|
||||
time.Sleep(1 * time.Minute)
|
||||
}
|
||||
s.runOnce.Do(func() {
|
||||
s.hasRun.Store(true)
|
||||
|
||||
if err := s.failBackupsInProgress(); err != nil {
|
||||
s.logger.Error("Failed to fail backups in progress", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
|
||||
if err := s.nodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted); err != nil {
|
||||
s.logger.Error("Failed to subscribe to backup completions", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := s.nodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
|
||||
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
|
||||
if config.GetEnv().IsManyNodesMode {
|
||||
// wait other nodes to start
|
||||
time.Sleep(schedulerStartupDelay)
|
||||
}
|
||||
}()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if err := s.failBackupsInProgress(); err != nil {
|
||||
s.logger.Error("Failed to fail backups in progress", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
err := s.backupNodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to subscribe to backup completions", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
defer func() {
|
||||
if err := s.backupNodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
|
||||
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.cleanOldBackups(); err != nil {
|
||||
s.logger.Error("Failed to clean old backups", "error", err)
|
||||
}
|
||||
|
||||
if err := s.checkDeadNodesAndFailBackups(); err != nil {
|
||||
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
|
||||
}
|
||||
|
||||
if err := s.runPendingBackups(); err != nil {
|
||||
s.logger.Error("Failed to run pending backups", "error", err)
|
||||
}
|
||||
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(schedulerTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.checkDeadNodesAndFailBackups(); err != nil {
|
||||
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
|
||||
}
|
||||
|
||||
if err := s.runPendingBackups(); err != nil {
|
||||
s.logger.Error("Failed to run pending backups", "error", err)
|
||||
}
|
||||
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) IsSchedulerRunning() bool {
|
||||
// if last backup time is more than 5 minutes ago, return false
|
||||
return s.lastBackupTime.After(time.Now().UTC().Add(-5 * time.Minute))
|
||||
return s.lastBackupTime.After(time.Now().UTC().Add(-schedulerHealthcheckThreshold))
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) failBackupsInProgress() error {
|
||||
@@ -93,12 +109,10 @@ func (s *BackupsScheduler) failBackupsInProgress() error {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Println("Backups in progress", len(backupsInProgress))
|
||||
|
||||
for _, backup := range backupsInProgress {
|
||||
if err := s.backupCancelManager.CancelBackup(backup.ID); err != nil {
|
||||
if err := s.taskCancelManager.CancelTask(backup.ID); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to cancel backup via context manager",
|
||||
"Failed to cancel backup via task cancel manager",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
@@ -144,6 +158,33 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
|
||||
return
|
||||
}
|
||||
|
||||
// Check for existing in-progress backups
|
||||
inProgressBackups, err := s.backupRepository.FindByDatabaseIdAndStatus(
|
||||
databaseID,
|
||||
backups_core.BackupStatusInProgress,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to check for in-progress backups",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if len(inProgressBackups) > 0 {
|
||||
s.logger.Warn(
|
||||
"Backup already in progress for database, skipping new backup",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"existingBackupId",
|
||||
inProgressBackups[0].ID,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
leastBusyNodeID, err := s.calculateLeastBusyNode()
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
@@ -175,7 +216,7 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.nodesRegistry.IncrementBackupsInProgress(leastBusyNodeID.String()); err != nil {
|
||||
if err := s.backupNodesRegistry.IncrementBackupsInProgress(*leastBusyNodeID); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to increment backups in progress",
|
||||
"nodeId",
|
||||
@@ -188,7 +229,7 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.nodesRegistry.AssignBackupToNode(leastBusyNodeID.String(), backup.ID, isCallNotifier); err != nil {
|
||||
if err := s.backupNodesRegistry.AssignBackupToNode(*leastBusyNodeID, backup.ID, isCallNotifier); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to submit backup",
|
||||
"nodeId",
|
||||
@@ -198,7 +239,7 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
if decrementErr := s.nodesRegistry.DecrementBackupsInProgress(leastBusyNodeID.String()); decrementErr != nil {
|
||||
if decrementErr := s.backupNodesRegistry.DecrementBackupsInProgress(*leastBusyNodeID); decrementErr != nil {
|
||||
s.logger.Error(
|
||||
"Failed to decrement backups in progress after submit failure",
|
||||
"nodeId",
|
||||
@@ -244,6 +285,10 @@ func (s *BackupsScheduler) GetRemainedBackupTryCount(lastBackup *backups_core.Ba
|
||||
return 0
|
||||
}
|
||||
|
||||
if lastBackup.IsSkipRetry {
|
||||
return 0
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(lastBackup.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
@@ -276,74 +321,6 @@ func (s *BackupsScheduler) GetRemainedBackupTryCount(lastBackup *backups_core.Ba
|
||||
return maxFailedTriesCount - len(lastFailedBackups)
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) cleanOldBackups() error {
|
||||
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
backupStorePeriod := backupConfig.StorePeriod
|
||||
|
||||
if backupStorePeriod == period.PeriodForever {
|
||||
continue
|
||||
}
|
||||
|
||||
storeDuration := backupStorePeriod.ToDuration()
|
||||
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
|
||||
|
||||
oldBackups, err := s.backupRepository.FindBackupsBeforeDate(
|
||||
backupConfig.DatabaseID,
|
||||
dateBeforeBackupsShouldBeDeleted,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to find old backups for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, backup := range oldBackups {
|
||||
storage, err := s.storageService.GetStorageByID(backup.StorageID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to get storage by ID",
|
||||
"storageId",
|
||||
backup.StorageID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
err = storage.DeleteFile(encryptor, backup.ID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
|
||||
}
|
||||
|
||||
if err := s.backupRepository.DeleteByID(backup.ID); err != nil {
|
||||
s.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
s.logger.Info(
|
||||
"Deleted old backup",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) runPendingBackups() error {
|
||||
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
@@ -393,7 +370,7 @@ func (s *BackupsScheduler) runPendingBackups() error {
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
|
||||
nodes, err := s.nodesRegistry.GetAvailableNodes()
|
||||
nodes, err := s.backupNodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get available nodes: %w", err)
|
||||
}
|
||||
@@ -402,7 +379,7 @@ func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
|
||||
return nil, fmt.Errorf("no nodes available")
|
||||
}
|
||||
|
||||
stats, err := s.nodesRegistry.GetBackupNodesStats()
|
||||
stats, err := s.backupNodesRegistry.GetBackupNodesStats()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get backup nodes stats: %w", err)
|
||||
}
|
||||
@@ -415,14 +392,9 @@ func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
|
||||
var bestNode *BackupNode
|
||||
var bestScore float64 = -1
|
||||
|
||||
now := time.Now().UTC()
|
||||
for i := range nodes {
|
||||
node := &nodes[i]
|
||||
|
||||
if now.Sub(node.LastHeartbeat) > 2*time.Minute {
|
||||
continue
|
||||
}
|
||||
|
||||
activeBackups := statsMap[node.ID]
|
||||
|
||||
var score float64
|
||||
@@ -445,16 +417,11 @@ func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
|
||||
return &bestNode.ID, nil
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUID) {
|
||||
nodeID, err := uuid.Parse(nodeIDStr)
|
||||
func (s *BackupsScheduler) onBackupCompleted(nodeID uuid.UUID, backupID uuid.UUID) {
|
||||
// Verify this task is actually a backup (registry contains multiple task types)
|
||||
_, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to parse node ID from completion message",
|
||||
"nodeId",
|
||||
nodeIDStr,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
// Not a backup task, ignore it
|
||||
return
|
||||
}
|
||||
|
||||
@@ -498,7 +465,7 @@ func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUI
|
||||
s.backupToNodeRelations[nodeID] = relation
|
||||
}
|
||||
|
||||
if err := s.nodesRegistry.DecrementBackupsInProgress(nodeIDStr); err != nil {
|
||||
if err := s.backupNodesRegistry.DecrementBackupsInProgress(nodeID); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to decrement backups in progress",
|
||||
"nodeId",
|
||||
@@ -512,18 +479,14 @@ func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUI
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) checkDeadNodesAndFailBackups() error {
|
||||
nodes, err := s.nodesRegistry.GetAvailableNodes()
|
||||
nodes, err := s.backupNodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get available nodes: %w", err)
|
||||
}
|
||||
|
||||
aliveNodeIDs := make(map[uuid.UUID]bool)
|
||||
now := time.Now().UTC()
|
||||
|
||||
for _, node := range nodes {
|
||||
if now.Sub(node.LastHeartbeat) <= 2*time.Minute {
|
||||
aliveNodeIDs[node.ID] = true
|
||||
}
|
||||
aliveNodeIDs[node.ID] = true
|
||||
}
|
||||
|
||||
for nodeID, relation := range s.backupToNodeRelations {
|
||||
@@ -572,7 +535,7 @@ func (s *BackupsScheduler) checkDeadNodesAndFailBackups() error {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := s.nodesRegistry.DecrementBackupsInProgress(nodeID.String()); err != nil {
|
||||
if err := s.backupNodesRegistry.DecrementBackupsInProgress(nodeID); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to decrement backups in progress for dead node",
|
||||
"nodeId",
|
||||
|
||||
@@ -42,8 +42,8 @@ func Test_RunPendingBackups_WhenLastBackupWasYesterday_CreatesNewBackup(t *testi
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
@@ -111,8 +111,8 @@ func Test_RunPendingBackups_WhenLastBackupWasRecentlyCompleted_SkipsBackup(t *te
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
@@ -179,8 +179,8 @@ func Test_RunPendingBackups_WhenLastBackupFailedAndRetriesDisabled_SkipsBackup(t
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
@@ -251,8 +251,8 @@ func Test_RunPendingBackups_WhenLastBackupFailedAndRetriesEnabled_CreatesNewBack
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
@@ -324,8 +324,8 @@ func Test_RunPendingBackups_WhenFailedBackupsExceedMaxRetries_SkipsBackup(t *tes
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
@@ -396,8 +396,8 @@ func Test_RunPendingBackups_WhenBackupsDisabled_SkipsBackup(t *testing.T) {
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
@@ -449,6 +449,8 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
var mockNodeID uuid.UUID
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
@@ -457,9 +459,15 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
// Clean up mock node
|
||||
if mockNodeID != uuid.Nil {
|
||||
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: mockNodeID})
|
||||
}
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
@@ -479,7 +487,7 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Register mock node without subscribing to backups (simulates node crash after registration)
|
||||
mockNodeID := uuid.New()
|
||||
mockNodeID = uuid.New()
|
||||
err = CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -493,7 +501,7 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
|
||||
assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status)
|
||||
|
||||
// Verify Valkey counter was incremented when backup was assigned
|
||||
stats, err := nodesRegistry.GetBackupNodesStats()
|
||||
stats, err := backupNodesRegistry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
foundStat := false
|
||||
for _, stat := range stats {
|
||||
@@ -523,7 +531,7 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
|
||||
assert.Contains(t, *backups[0].FailMessage, "node unavailability")
|
||||
|
||||
// Verify Valkey counter was decremented after backup failed
|
||||
stats, err = nodesRegistry.GetBackupNodesStats()
|
||||
stats, err = backupNodesRegistry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
for _, stat := range stats {
|
||||
if stat.ID == mockNodeID {
|
||||
@@ -531,11 +539,109 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
|
||||
}
|
||||
}
|
||||
|
||||
// Node info should still exist in registry (not removed by checkDeadNodesAndFailBackups)
|
||||
node, err := GetNodeFromRegistry(mockNodeID)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_OnBackupCompleted_WhenTaskIsNotBackup_SkipsProcessing(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
var mockNodeID uuid.UUID
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
// Clean up mock node
|
||||
if mockNodeID != uuid.Nil {
|
||||
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: mockNodeID})
|
||||
}
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, node)
|
||||
assert.Equal(t, mockNodeID, node.ID)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Register mock node
|
||||
mockNodeID = uuid.New()
|
||||
err = CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Start a backup and assign it to the node
|
||||
GetBackupsScheduler().StartBackup(database.ID, false)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1)
|
||||
assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status)
|
||||
|
||||
// Get initial state of the registry
|
||||
initialStats, err := backupNodesRegistry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
var initialActiveTasks int
|
||||
for _, stat := range initialStats {
|
||||
if stat.ID == mockNodeID {
|
||||
initialActiveTasks = stat.ActiveBackups
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, initialActiveTasks, "Should have 1 active task")
|
||||
|
||||
// Call onBackupCompleted with a random UUID (not a backup ID)
|
||||
nonBackupTaskID := uuid.New()
|
||||
GetBackupsScheduler().onBackupCompleted(mockNodeID, nonBackupTaskID)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify: Active tasks counter should remain the same (not decremented)
|
||||
stats, err := backupNodesRegistry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
for _, stat := range stats {
|
||||
if stat.ID == mockNodeID {
|
||||
assert.Equal(t, initialActiveTasks, stat.ActiveBackups,
|
||||
"Active tasks should not change for non-backup task")
|
||||
}
|
||||
}
|
||||
|
||||
// Verify: backup should still be in progress (not modified)
|
||||
backups, err = backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1)
|
||||
assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status,
|
||||
"Backup status should not change for non-backup task completion")
|
||||
|
||||
// Verify: backupToNodeRelations should still contain the node
|
||||
scheduler := GetBackupsScheduler()
|
||||
_, exists := scheduler.backupToNodeRelations[mockNodeID]
|
||||
assert.True(t, exists, "Node should still be in backupToNodeRelations")
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
@@ -549,6 +655,14 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
|
||||
node3ID := uuid.New()
|
||||
now := time.Now().UTC()
|
||||
|
||||
defer func() {
|
||||
// Clean up all mock nodes
|
||||
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node1ID})
|
||||
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node2ID})
|
||||
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node3ID})
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
err := CreateMockNodeInRegistry(node1ID, 100, now)
|
||||
assert.NoError(t, err)
|
||||
err = CreateMockNodeInRegistry(node2ID, 100, now)
|
||||
@@ -557,17 +671,17 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
for range 5 {
|
||||
err = nodesRegistry.IncrementBackupsInProgress(node1ID.String())
|
||||
err = backupNodesRegistry.IncrementBackupsInProgress(node1ID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
for range 2 {
|
||||
err = nodesRegistry.IncrementBackupsInProgress(node2ID.String())
|
||||
err = backupNodesRegistry.IncrementBackupsInProgress(node2ID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
for range 8 {
|
||||
err = nodesRegistry.IncrementBackupsInProgress(node3ID.String())
|
||||
err = backupNodesRegistry.IncrementBackupsInProgress(node3ID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -584,17 +698,24 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
|
||||
node50MBsID := uuid.New()
|
||||
now := time.Now().UTC()
|
||||
|
||||
defer func() {
|
||||
// Clean up all mock nodes
|
||||
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node100MBsID})
|
||||
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node50MBsID})
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
err := CreateMockNodeInRegistry(node100MBsID, 100, now)
|
||||
assert.NoError(t, err)
|
||||
err = CreateMockNodeInRegistry(node50MBsID, 50, now)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for range 10 {
|
||||
err = nodesRegistry.IncrementBackupsInProgress(node100MBsID.String())
|
||||
err = backupNodesRegistry.IncrementBackupsInProgress(node100MBsID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
err = nodesRegistry.IncrementBackupsInProgress(node50MBsID.String())
|
||||
err = backupNodesRegistry.IncrementBackupsInProgress(node50MBsID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
leastBusyNodeID, err := GetBackupsScheduler().calculateLeastBusyNode()
|
||||
@@ -622,9 +743,11 @@ func Test_FailBackupsInProgress_WhenSchedulerStarts_CancelsBackupsAndUpdatesStat
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
@@ -707,3 +830,492 @@ func Test_FailBackupsInProgress_WhenSchedulerStarts_CancelsBackupsAndUpdatesStat
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
// Start scheduler so it can handle task completions
|
||||
scheduler := CreateTestScheduler()
|
||||
schedulerCancel := StartSchedulerForTest(t, scheduler)
|
||||
defer schedulerCancel()
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Get initial active task count
|
||||
stats, err := backupNodesRegistry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
var initialActiveTasks int
|
||||
for _, stat := range stats {
|
||||
if stat.ID == backuperNode.nodeID {
|
||||
initialActiveTasks = stat.ActiveBackups
|
||||
break
|
||||
}
|
||||
}
|
||||
t.Logf("Initial active tasks: %d", initialActiveTasks)
|
||||
|
||||
// Start backup
|
||||
scheduler.StartBackup(database.ID, false)
|
||||
|
||||
// Wait for backup to complete
|
||||
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
|
||||
|
||||
// Verify backup was created and completed
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backups[0].Status)
|
||||
|
||||
// Wait for active task count to decrease
|
||||
decreased := WaitForActiveTasksDecrease(
|
||||
t,
|
||||
backuperNode.nodeID,
|
||||
initialActiveTasks+1,
|
||||
10*time.Second,
|
||||
)
|
||||
assert.True(t, decreased, "Active task count should have decreased after backup completion")
|
||||
|
||||
// Verify final active task count equals initial count
|
||||
finalStats, err := backupNodesRegistry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
for _, stat := range finalStats {
|
||||
if stat.ID == backuperNode.nodeID {
|
||||
t.Logf("Final active tasks: %d", stat.ActiveBackups)
|
||||
assert.Equal(t, initialActiveTasks, stat.ActiveBackups,
|
||||
"Active task count should return to initial value after backup completion")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
// Start scheduler so it can handle task completions
|
||||
scheduler := CreateTestScheduler()
|
||||
schedulerCancel := StartSchedulerForTest(t, scheduler)
|
||||
defer schedulerCancel()
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Set wrong password to cause backup failure
|
||||
// We need to bypass service layer validation which would fail on connection test
|
||||
database.Postgresql.Password = "intentionally_wrong_password"
|
||||
dbRepo := &databases.DatabaseRepository{}
|
||||
_, err := dbRepo.Save(database)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Get initial active task count
|
||||
stats, err := backupNodesRegistry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
var initialActiveTasks int
|
||||
for _, stat := range stats {
|
||||
if stat.ID == backuperNode.nodeID {
|
||||
initialActiveTasks = stat.ActiveBackups
|
||||
break
|
||||
}
|
||||
}
|
||||
t.Logf("Initial active tasks: %d", initialActiveTasks)
|
||||
|
||||
// Start backup
|
||||
scheduler.StartBackup(database.ID, false)
|
||||
|
||||
// Wait for backup to fail
|
||||
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
|
||||
|
||||
// Verify backup was created and failed
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1)
|
||||
assert.Equal(t, backups_core.BackupStatusFailed, backups[0].Status)
|
||||
assert.NotNil(t, backups[0].FailMessage)
|
||||
if backups[0].FailMessage != nil {
|
||||
t.Logf("Backup failed with message: %s", *backups[0].FailMessage)
|
||||
}
|
||||
|
||||
// Wait for active task count to decrease
|
||||
decreased := WaitForActiveTasksDecrease(
|
||||
t,
|
||||
backuperNode.nodeID,
|
||||
initialActiveTasks+1,
|
||||
10*time.Second,
|
||||
)
|
||||
assert.True(t, decreased, "Active task count should have decreased after backup failure")
|
||||
|
||||
// Verify final active task count equals initial count
|
||||
finalStats, err := backupNodesRegistry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
for _, stat := range finalStats {
|
||||
if stat.ID == backuperNode.nodeID {
|
||||
t.Logf("Final active tasks: %d", stat.ActiveBackups)
|
||||
assert.Equal(t, initialActiveTasks, stat.ActiveBackups,
|
||||
"Active task count should return to initial value after backup failure")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_StartBackup_WhenBackupAlreadyInProgress_SkipsNewBackup(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create an in-progress backup manually
|
||||
inProgressBackup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
BackupSizeMb: 0,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err = backupRepository.Save(inProgressBackup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Try to start a new backup - should be skipped
|
||||
GetBackupsScheduler().StartBackup(database.ID, false)
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Verify only 1 backup exists (the original in-progress one)
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1)
|
||||
assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status)
|
||||
assert.Equal(t, inProgressBackup.ID, backups[0].ID)
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_RunPendingBackups_WhenLastBackupFailedWithIsSkipRetry_SkipsBackupEvenWithRetriesEnabled(
|
||||
t *testing.T,
|
||||
) {
|
||||
cache_utils.ClearAllCache()
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Enable backups with retries enabled and high retry count
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
backupConfig.IsRetryIfFailed = true
|
||||
backupConfig.MaxFailedTriesCount = 5
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create a failed backup with IsSkipRetry set to true
|
||||
failMessage := "backup failed due to size limit exceeded"
|
||||
backupRepository.Save(&backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: backups_core.BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
IsSkipRetry: true,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
|
||||
// Verify GetRemainedBackupTryCount returns 0 even though retries are enabled
|
||||
lastBackup, err := backupRepository.FindLastByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, lastBackup)
|
||||
|
||||
remainedTries := GetBackupsScheduler().GetRemainedBackupTryCount(lastBackup)
|
||||
assert.Equal(t, 0, remainedTries, "Should return 0 tries when IsSkipRetry is true")
|
||||
|
||||
// Run the scheduler
|
||||
GetBackupsScheduler().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify no new backup was created (still only 1 backup exists)
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1, "No retry should be attempted when IsSkipRetry is true")
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_StartBackup_When2BackupsStartedForDifferentDatabases_BothUseCasesAreCalled(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
// Create mock tracking use case
|
||||
mockUseCase := NewMockTrackingBackupUsecase()
|
||||
|
||||
// Create BackuperNode with mock use case
|
||||
backuperNode := CreateTestBackuperNodeWithUseCase(mockUseCase)
|
||||
cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
// Create scheduler
|
||||
scheduler := CreateTestScheduler()
|
||||
schedulerCancel := StartSchedulerForTest(t, scheduler)
|
||||
defer schedulerCancel()
|
||||
|
||||
// Setup test data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
|
||||
// Create 2 separate databases
|
||||
database1 := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
database2 := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
// Cleanup backups for database1
|
||||
backups1, _ := backupRepository.FindByDatabaseID(database1.ID)
|
||||
for _, backup := range backups1 {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
// Cleanup backups for database2
|
||||
backups2, _ := backupRepository.FindByDatabaseID(database2.ID)
|
||||
for _, backup := range backups2 {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database1)
|
||||
databases.RemoveTestDatabase(database2)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Enable backups for database1
|
||||
backupConfig1, err := backups_config.GetBackupConfigService().
|
||||
GetBackupConfigByDbId(database1.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig1.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig1.IsBackupsEnabled = true
|
||||
backupConfig1.StorePeriod = period.PeriodWeek
|
||||
backupConfig1.Storage = storage
|
||||
backupConfig1.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Enable backups for database2
|
||||
backupConfig2, err := backups_config.GetBackupConfigService().
|
||||
GetBackupConfigByDbId(database2.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backupConfig2.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig2.IsBackupsEnabled = true
|
||||
backupConfig2.StorePeriod = period.PeriodWeek
|
||||
backupConfig2.Storage = storage
|
||||
backupConfig2.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Start 2 backups simultaneously
|
||||
t.Log("Starting backup for database1")
|
||||
scheduler.StartBackup(database1.ID, false)
|
||||
|
||||
t.Log("Starting backup for database2")
|
||||
scheduler.StartBackup(database2.ID, false)
|
||||
|
||||
// Wait up to 10 seconds for both backups to complete
|
||||
t.Log("Waiting for both backups to complete...")
|
||||
|
||||
success := assert.Eventually(t, func() bool {
|
||||
callCount := mockUseCase.GetCallCount()
|
||||
t.Logf("Current call count: %d/2", callCount)
|
||||
return callCount == 2
|
||||
}, 10*time.Second, 200*time.Millisecond, "Both use cases should be called within 10 seconds")
|
||||
|
||||
if !success {
|
||||
t.Logf("Test failed: Only %d out of 2 use cases were called", mockUseCase.GetCallCount())
|
||||
}
|
||||
|
||||
// Verify both backup IDs were received
|
||||
calledBackupIDs := mockUseCase.GetCalledBackupIDs()
|
||||
t.Logf("Called backup IDs: %v", calledBackupIDs)
|
||||
assert.Len(t, calledBackupIDs, 2, "Both backup IDs should be tracked")
|
||||
|
||||
// Verify both backups exist in repository and are completed
|
||||
backups1, err := backupRepository.FindByDatabaseID(database1.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups1, 1, "Database1 should have 1 backup")
|
||||
if len(backups1) > 0 {
|
||||
t.Logf("Database1 backup status: %s", backups1[0].Status)
|
||||
}
|
||||
|
||||
backups2, err := backupRepository.FindByDatabaseID(database2.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups2, 1, "Database2 should have 1 backup")
|
||||
if len(backups2) > 0 {
|
||||
t.Logf("Database2 backup status: %s", backups2[0].Status)
|
||||
}
|
||||
|
||||
// Verify both backups completed successfully
|
||||
if len(backups1) > 0 {
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backups1[0].Status,
|
||||
"Database1 backup should be completed")
|
||||
}
|
||||
|
||||
if len(backups2) > 0 {
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backups2[0].Status,
|
||||
"Database2 backup should be completed")
|
||||
}
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package backuping
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -35,19 +37,56 @@ func CreateTestRouter() *gin.Engine {
|
||||
|
||||
func CreateTestBackuperNode() *BackuperNode {
|
||||
return &BackuperNode{
|
||||
databases.GetDatabaseService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
notifiers.GetNotifierService(),
|
||||
backupCancelManager,
|
||||
nodesRegistry,
|
||||
logger.GetLogger(),
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
uuid.New(),
|
||||
time.Time{},
|
||||
databaseService: databases.GetDatabaseService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
workspaceService: workspaces_services.GetWorkspaceService(),
|
||||
backupRepository: backupRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
notificationSender: notifiers.GetNotifierService(),
|
||||
backupCancelManager: taskCancelManager,
|
||||
backupNodesRegistry: backupNodesRegistry,
|
||||
logger: logger.GetLogger(),
|
||||
createBackupUseCase: usecases.GetCreateBackupUsecase(),
|
||||
nodeID: uuid.New(),
|
||||
lastHeartbeat: time.Time{},
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func CreateTestBackuperNodeWithUseCase(useCase backups_core.CreateBackupUsecase) *BackuperNode {
|
||||
return &BackuperNode{
|
||||
databaseService: databases.GetDatabaseService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
workspaceService: workspaces_services.GetWorkspaceService(),
|
||||
backupRepository: backupRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
notificationSender: notifiers.GetNotifierService(),
|
||||
backupCancelManager: taskCancelManager,
|
||||
backupNodesRegistry: backupNodesRegistry,
|
||||
logger: logger.GetLogger(),
|
||||
createBackupUseCase: useCase,
|
||||
nodeID: uuid.New(),
|
||||
lastHeartbeat: time.Time{},
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func CreateTestScheduler() *BackupsScheduler {
|
||||
return &BackupsScheduler{
|
||||
backupRepository: backupRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
taskCancelManager: taskCancelManager,
|
||||
backupNodesRegistry: backupNodesRegistry,
|
||||
lastBackupTime: time.Now().UTC(),
|
||||
logger: logger.GetLogger(),
|
||||
backupToNodeRelations: make(map[uuid.UUID]BackupToNodeRelation),
|
||||
backuperNode: CreateTestBackuperNode(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,7 +152,7 @@ func StartBackuperNodeForTest(t *testing.T, backuperNode *BackuperNode) context.
|
||||
// Poll registry for node presence instead of fixed sleep
|
||||
deadline := time.Now().UTC().Add(5 * time.Second)
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
nodes, err := nodesRegistry.GetAvailableNodes()
|
||||
nodes, err := backupNodesRegistry.GetAvailableNodes()
|
||||
if err == nil {
|
||||
for _, node := range nodes {
|
||||
if node.ID == backuperNode.nodeID {
|
||||
@@ -138,6 +177,34 @@ func StartBackuperNodeForTest(t *testing.T, backuperNode *BackuperNode) context.
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartSchedulerForTest starts the BackupsScheduler in a goroutine for testing.
|
||||
// The scheduler subscribes to task completions and manages backup lifecycle.
|
||||
// Returns a context cancel function that should be deferred to stop the scheduler.
|
||||
func StartSchedulerForTest(t *testing.T, scheduler *BackupsScheduler) context.CancelFunc {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
scheduler.Run(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Give scheduler time to subscribe to completions
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
t.Log("BackupsScheduler started")
|
||||
|
||||
return func() {
|
||||
cancel()
|
||||
select {
|
||||
case <-done:
|
||||
t.Log("BackupsScheduler stopped gracefully")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Log("BackupsScheduler stop timeout")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StopBackuperNodeForTest stops the BackuperNode by canceling its context.
|
||||
// It waits for the node to unregister from the registry.
|
||||
func StopBackuperNodeForTest(t *testing.T, cancel context.CancelFunc, backuperNode *BackuperNode) {
|
||||
@@ -146,7 +213,7 @@ func StopBackuperNodeForTest(t *testing.T, cancel context.CancelFunc, backuperNo
|
||||
// Wait for node to unregister from registry
|
||||
deadline := time.Now().UTC().Add(2 * time.Second)
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
nodes, err := nodesRegistry.GetAvailableNodes()
|
||||
nodes, err := backupNodesRegistry.GetAvailableNodes()
|
||||
if err == nil {
|
||||
found := false
|
||||
for _, node := range nodes {
|
||||
@@ -173,7 +240,7 @@ func CreateMockNodeInRegistry(nodeID uuid.UUID, throughputMBs int, lastHeartbeat
|
||||
LastHeartbeat: lastHeartbeat,
|
||||
}
|
||||
|
||||
return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
|
||||
return backupNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
|
||||
}
|
||||
|
||||
func UpdateNodeHeartbeatDirectly(
|
||||
@@ -187,11 +254,11 @@ func UpdateNodeHeartbeatDirectly(
|
||||
LastHeartbeat: lastHeartbeat,
|
||||
}
|
||||
|
||||
return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
|
||||
return backupNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
|
||||
}
|
||||
|
||||
func GetNodeFromRegistry(nodeID uuid.UUID) (*BackupNode, error) {
|
||||
nodes, err := nodesRegistry.GetAvailableNodes()
|
||||
nodes, err := backupNodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -204,3 +271,48 @@ func GetNodeFromRegistry(nodeID uuid.UUID) (*BackupNode, error) {
|
||||
|
||||
return nil, fmt.Errorf("node not found")
|
||||
}
|
||||
|
||||
// WaitForActiveTasksDecrease waits for the active task count to decrease below the initial count.
|
||||
// It polls the registry every 500ms until the count decreases or the timeout is reached.
|
||||
// Returns true if the count decreased, false if timeout was reached.
|
||||
func WaitForActiveTasksDecrease(
|
||||
t *testing.T,
|
||||
nodeID uuid.UUID,
|
||||
initialCount int,
|
||||
timeout time.Duration,
|
||||
) bool {
|
||||
deadline := time.Now().UTC().Add(timeout)
|
||||
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
stats, err := backupNodesRegistry.GetBackupNodesStats()
|
||||
if err != nil {
|
||||
t.Logf("WaitForActiveTasksDecrease: error getting node stats: %v", err)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, stat := range stats {
|
||||
if stat.ID == nodeID {
|
||||
t.Logf(
|
||||
"WaitForActiveTasksDecrease: current active tasks = %d (initial = %d)",
|
||||
stat.ActiveBackups,
|
||||
initialCount,
|
||||
)
|
||||
if stat.ActiveBackups < initialCount {
|
||||
t.Logf(
|
||||
"WaitForActiveTasksDecrease: active tasks decreased from %d to %d",
|
||||
initialCount,
|
||||
stat.ActiveBackups,
|
||||
)
|
||||
return true
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Logf("WaitForActiveTasksDecrease: timeout waiting for active tasks to decrease")
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,75 +0,0 @@
|
||||
package backups_cancellation
|
||||
|
||||
import (
|
||||
"context"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const backupCancelChannel = "backup:cancel"
|
||||
|
||||
type BackupCancelManager struct {
|
||||
mu sync.RWMutex
|
||||
cancelFuncs map[uuid.UUID]context.CancelFunc
|
||||
pubsub *cache_utils.PubSubManager
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (m *BackupCancelManager) StartSubscription() {
|
||||
ctx := context.Background()
|
||||
|
||||
handler := func(message string) {
|
||||
backupID, err := uuid.Parse(message)
|
||||
if err != nil {
|
||||
m.logger.Error("Invalid backup ID in cancel message", "message", message, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
cancelFunc, exists := m.cancelFuncs[backupID]
|
||||
if exists {
|
||||
cancelFunc()
|
||||
delete(m.cancelFuncs, backupID)
|
||||
m.logger.Info("Cancelled backup via Pub/Sub", "backupID", backupID)
|
||||
}
|
||||
}
|
||||
|
||||
err := m.pubsub.Subscribe(ctx, backupCancelChannel, handler)
|
||||
if err != nil {
|
||||
m.logger.Error("Failed to subscribe to backup cancel channel", "error", err)
|
||||
} else {
|
||||
m.logger.Info("Successfully subscribed to backup cancel channel")
|
||||
}
|
||||
}
|
||||
|
||||
func (m *BackupCancelManager) RegisterBackup(backupID uuid.UUID, cancelFunc context.CancelFunc) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.cancelFuncs[backupID] = cancelFunc
|
||||
m.logger.Debug("Registered backup", "backupID", backupID)
|
||||
}
|
||||
|
||||
func (m *BackupCancelManager) CancelBackup(backupID uuid.UUID) error {
|
||||
ctx := context.Background()
|
||||
|
||||
err := m.pubsub.Publish(ctx, backupCancelChannel, backupID.String())
|
||||
if err != nil {
|
||||
m.logger.Error("Failed to publish cancel message", "backupID", backupID, "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
m.logger.Info("Published backup cancel message", "backupID", backupID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *BackupCancelManager) UnregisterBackup(backupID uuid.UUID) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.cancelFuncs, backupID)
|
||||
m.logger.Debug("Unregistered backup", "backupID", backupID)
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package backups_cancellation
|
||||
|
||||
import (
|
||||
"context"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/logger"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var backupCancelManager = &BackupCancelManager{
|
||||
sync.RWMutex{},
|
||||
make(map[uuid.UUID]context.CancelFunc),
|
||||
cache_utils.NewPubSubManager(),
|
||||
logger.GetLogger(),
|
||||
}
|
||||
|
||||
func GetBackupCancelManager() *BackupCancelManager {
|
||||
return backupCancelManager
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
backupCancelManager.StartSubscription()
|
||||
}
|
||||
@@ -913,7 +913,7 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Register a cancellable context for the backup
|
||||
GetBackupService().backupCancelManager.RegisterBackup(backup.ID, func() {})
|
||||
GetBackupService().taskCancelManager.RegisterTask(backup.ID, func() {})
|
||||
|
||||
resp := test_utils.MakePostRequest(
|
||||
t,
|
||||
@@ -1156,7 +1156,7 @@ func createTestDatabase(
|
||||
Type: databases.DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
|
||||
@@ -15,6 +15,7 @@ type Backup struct {
|
||||
|
||||
Status BackupStatus `json:"status" gorm:"column:status;not null"`
|
||||
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`
|
||||
IsSkipRetry bool `json:"isSkipRetry" gorm:"column:is_skip_retry;type:boolean;not null"`
|
||||
|
||||
BackupSizeMb float64 `json:"backupSizeMb" gorm:"column:backup_size_mb;default:0"`
|
||||
|
||||
|
||||
@@ -212,3 +212,36 @@ func (r *BackupRepository) CountByDatabaseID(databaseID uuid.UUID) (int64, error
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *BackupRepository) GetTotalSizeByDatabase(databaseID uuid.UUID) (float64, error) {
|
||||
var totalSize float64
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Model(&Backup{}).
|
||||
Select("COALESCE(SUM(backup_size_mb), 0)").
|
||||
Where("database_id = ? AND status != ?", databaseID, BackupStatusInProgress).
|
||||
Scan(&totalSize).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return totalSize, nil
|
||||
}
|
||||
|
||||
func (r *BackupRepository) FindOldestByDatabaseExcludingInProgress(
|
||||
databaseID uuid.UUID,
|
||||
limit int,
|
||||
) ([]*Backup, error) {
|
||||
var backups []*Backup
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Where("database_id = ? AND status != ?", databaseID, BackupStatusInProgress).
|
||||
Order("created_at ASC").
|
||||
Limit(limit).
|
||||
Find(&backups).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return backups, nil
|
||||
}
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
"databasus-backend/internal/features/backups/backups/usecases"
|
||||
@@ -12,6 +14,7 @@ import (
|
||||
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
@@ -19,25 +22,26 @@ import (
|
||||
|
||||
var backupRepository = &backups_core.BackupRepository{}
|
||||
|
||||
var backupCancelManager = backups_cancellation.GetBackupCancelManager()
|
||||
var taskCancelManager = task_cancellation.GetTaskCancelManager()
|
||||
|
||||
var backupService = &BackupService{
|
||||
databaseService: databases.GetDatabaseService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
backupRepository: backupRepository,
|
||||
notifierService: notifiers.GetNotifierService(),
|
||||
notificationSender: notifiers.GetNotifierService(),
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
secretKeyService: encryption_secrets.GetSecretKeyService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
createBackupUseCase: usecases.GetCreateBackupUsecase(),
|
||||
logger: logger.GetLogger(),
|
||||
backupRemoveListeners: []backups_core.BackupRemoveListener{},
|
||||
workspaceService: workspaces_services.GetWorkspaceService(),
|
||||
auditLogService: audit_logs.GetAuditLogService(),
|
||||
backupCancelManager: backupCancelManager,
|
||||
downloadTokenService: backups_download.GetDownloadTokenService(),
|
||||
backupSchedulerService: backuping.GetBackupsScheduler(),
|
||||
databases.GetDatabaseService(),
|
||||
storages.GetStorageService(),
|
||||
backupRepository,
|
||||
notifiers.GetNotifierService(),
|
||||
notifiers.GetNotifierService(),
|
||||
backups_config.GetBackupConfigService(),
|
||||
encryption_secrets.GetSecretKeyService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
logger.GetLogger(),
|
||||
[]backups_core.BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
audit_logs.GetAuditLogService(),
|
||||
taskCancelManager,
|
||||
backups_download.GetDownloadTokenService(),
|
||||
backuping.GetBackupsScheduler(),
|
||||
backuping.GetBackupCleaner(),
|
||||
}
|
||||
|
||||
var backupController = &BackupController{
|
||||
@@ -52,11 +56,26 @@ func GetBackupController() *BackupController {
|
||||
return backupController
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
backups_config.
|
||||
GetBackupConfigService().
|
||||
SetDatabaseStorageChangeListener(backupService)
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
databases.GetDatabaseService().AddDbRemoveListener(backupService)
|
||||
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
|
||||
func SetupDependencies() {
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
backups_config.
|
||||
GetBackupConfigService().
|
||||
SetDatabaseStorageChangeListener(backupService)
|
||||
|
||||
databases.GetDatabaseService().AddDbRemoveListener(backupService)
|
||||
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,33 +2,49 @@ package backups_download
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type DownloadTokenBackgroundService struct {
|
||||
downloadTokenService *DownloadTokenService
|
||||
logger *slog.Logger
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (s *DownloadTokenBackgroundService) Run(ctx context.Context) {
|
||||
s.logger.Info("Starting download token cleanup background service")
|
||||
wasAlreadyRun := s.hasRun.Load()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
s.runOnce.Do(func() {
|
||||
s.hasRun.Store(true)
|
||||
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
s.logger.Info("Starting download token cleanup background service")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
|
||||
s.logger.Error("Failed to clean expired download tokens", "error", err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
|
||||
s.logger.Error("Failed to clean expired download tokens", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/logger"
|
||||
@@ -30,8 +33,10 @@ func init() {
|
||||
}
|
||||
|
||||
downloadTokenBackgroundService = &DownloadTokenBackgroundService{
|
||||
downloadTokenService,
|
||||
logger.GetLogger(),
|
||||
downloadTokenService: downloadTokenService,
|
||||
logger: logger.GetLogger(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
"databasus-backend/internal/features/backups/backups/encryption"
|
||||
@@ -18,6 +17,7 @@ import (
|
||||
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
users_models "databasus-backend/internal/features/users/models"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
@@ -43,9 +43,10 @@ type BackupService struct {
|
||||
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
auditLogService *audit_logs.AuditLogService
|
||||
backupCancelManager *backups_cancellation.BackupCancelManager
|
||||
taskCancelManager *task_cancellation.TaskCancelManager
|
||||
downloadTokenService *backups_download.DownloadTokenService
|
||||
backupSchedulerService *backuping.BackupsScheduler
|
||||
backupCleaner *backuping.BackupCleaner
|
||||
}
|
||||
|
||||
func (s *BackupService) AddBackupRemoveListener(listener backups_core.BackupRemoveListener) {
|
||||
@@ -189,7 +190,7 @@ func (s *BackupService) DeleteBackup(
|
||||
database.WorkspaceID,
|
||||
)
|
||||
|
||||
return s.deleteBackup(backup)
|
||||
return s.backupCleaner.DeleteBackup(backup)
|
||||
}
|
||||
|
||||
func (s *BackupService) GetBackup(backupID uuid.UUID) (*backups_core.Backup, error) {
|
||||
@@ -226,7 +227,7 @@ func (s *BackupService) CancelBackup(
|
||||
return errors.New("backup is not in progress")
|
||||
}
|
||||
|
||||
if err := s.backupCancelManager.CancelBackup(backupID); err != nil {
|
||||
if err := s.taskCancelManager.CancelTask(backupID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -292,29 +293,6 @@ func (s *BackupService) GetBackupFile(
|
||||
return reader, backup, database, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) deleteBackup(backup *backups_core.Backup) error {
|
||||
for _, listener := range s.backupRemoveListeners {
|
||||
if err := listener.OnBeforeBackupRemove(backup); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
storage, err := s.storageService.GetStorageByID(backup.StorageID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = storage.DeleteFile(s.fieldEncryptor, backup.ID)
|
||||
if err != nil {
|
||||
// we do not return error here, because sometimes clean up performed
|
||||
// before unavailable storage removal or change - therefore we should
|
||||
// proceed even in case of error
|
||||
s.logger.Error("Failed to delete backup file", "error", err)
|
||||
}
|
||||
|
||||
return s.backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
|
||||
dbBackupsInProgress, err := s.backupRepository.FindByDatabaseIdAndStatus(
|
||||
databaseID,
|
||||
@@ -336,7 +314,7 @@ func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
|
||||
}
|
||||
|
||||
for _, dbBackup := range dbBackups {
|
||||
err := s.deleteBackup(dbBackup)
|
||||
err := s.backupCleaner.DeleteBackup(dbBackup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -75,3 +75,23 @@ func WaitForBackupCompletion(
|
||||
|
||||
t.Logf("WaitForBackupCompletion: timeout waiting for backup to complete")
|
||||
}
|
||||
|
||||
// CreateTestBackup creates a simple test backup record for testing purposes
|
||||
func CreateTestBackup(databaseID, storageID uuid.UUID) *backups_core.Backup {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: databaseID,
|
||||
StorageID: storageID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10.5,
|
||||
BackupDurationMs: 1000,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
repo := &backups_core.BackupRepository{}
|
||||
if err := repo.Save(backup); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return backup
|
||||
}
|
||||
|
||||
@@ -1462,7 +1462,7 @@ func createTestDatabaseViaAPI(
|
||||
Type: databases.DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
package backups_config
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var backupConfigRepository = &BackupConfigRepository{}
|
||||
@@ -28,6 +32,21 @@ func GetBackupConfigService() *BackupConfigService {
|
||||
return backupConfigService
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService)
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,6 +31,11 @@ type BackupConfig struct {
|
||||
MaxFailedTriesCount int `json:"maxFailedTriesCount" gorm:"column:max_failed_tries_count;type:int;not null"`
|
||||
|
||||
Encryption BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
|
||||
|
||||
// MaxBackupSizeMB limits individual backup size. 0 = unlimited.
|
||||
MaxBackupSizeMB int64 `json:"maxBackupSizeMb" gorm:"column:max_backup_size_mb;type:int;not null"`
|
||||
// MaxBackupsTotalSizeMB limits total size of all backups. 0 = unlimited.
|
||||
MaxBackupsTotalSizeMB int64 `json:"maxBackupsTotalSizeMb" gorm:"column:max_backups_total_size_mb;type:int;not null"`
|
||||
}
|
||||
|
||||
func (h *BackupConfig) TableName() string {
|
||||
@@ -89,20 +94,30 @@ func (b *BackupConfig) Validate() error {
|
||||
return errors.New("encryption must be NONE or ENCRYPTED")
|
||||
}
|
||||
|
||||
if b.MaxBackupSizeMB < 0 {
|
||||
return errors.New("max backup size must be non-negative")
|
||||
}
|
||||
|
||||
if b.MaxBackupsTotalSizeMB < 0 {
|
||||
return errors.New("max backups total size must be non-negative")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *BackupConfig) Copy(newDatabaseID uuid.UUID) *BackupConfig {
|
||||
return &BackupConfig{
|
||||
DatabaseID: newDatabaseID,
|
||||
IsBackupsEnabled: b.IsBackupsEnabled,
|
||||
StorePeriod: b.StorePeriod,
|
||||
BackupIntervalID: uuid.Nil,
|
||||
BackupInterval: b.BackupInterval.Copy(),
|
||||
StorageID: b.StorageID,
|
||||
SendNotificationsOn: b.SendNotificationsOn,
|
||||
IsRetryIfFailed: b.IsRetryIfFailed,
|
||||
MaxFailedTriesCount: b.MaxFailedTriesCount,
|
||||
Encryption: b.Encryption,
|
||||
DatabaseID: newDatabaseID,
|
||||
IsBackupsEnabled: b.IsBackupsEnabled,
|
||||
StorePeriod: b.StorePeriod,
|
||||
BackupIntervalID: uuid.Nil,
|
||||
BackupInterval: b.BackupInterval.Copy(),
|
||||
StorageID: b.StorageID,
|
||||
SendNotificationsOn: b.SendNotificationsOn,
|
||||
IsRetryIfFailed: b.IsRetryIfFailed,
|
||||
MaxFailedTriesCount: b.MaxFailedTriesCount,
|
||||
Encryption: b.Encryption,
|
||||
MaxBackupSizeMB: b.MaxBackupSizeMB,
|
||||
MaxBackupsTotalSizeMB: b.MaxBackupsTotalSizeMB,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ func getTestPostgresConfig() *postgresql.PostgresqlDatabase {
|
||||
testDbName := "testdb"
|
||||
return &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
@@ -67,7 +67,7 @@ func getTestMariadbConfig() *mariadb.MariadbDatabase {
|
||||
testDbName := "testdb"
|
||||
return &mariadb.MariadbDatabase{
|
||||
Version: tools.MariadbVersion1011,
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
@@ -88,7 +88,7 @@ func getTestMongodbConfig() *mongodb.MongodbDatabase {
|
||||
|
||||
return &mongodb.MongodbDatabase{
|
||||
Version: tools.MongodbVersion7,
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "root",
|
||||
Password: "rootpassword",
|
||||
@@ -829,7 +829,7 @@ func createTestDatabaseViaAPI(
|
||||
Type: DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
|
||||
@@ -515,9 +515,13 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
|
||||
hasProcess := false
|
||||
hasAllPrivileges := false
|
||||
|
||||
// Escape underscores to match MariaDB's grant output format
|
||||
// MariaDB escapes _ as \_ in SHOW GRANTS output
|
||||
// Pattern matches either literal _ or escaped \_
|
||||
escapedDbName := strings.ReplaceAll(regexp.QuoteMeta(database), "_", `(_|\\_)`)
|
||||
dbPatternStr := fmt.Sprintf(
|
||||
`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
|
||||
regexp.QuoteMeta(database),
|
||||
escapedDbName,
|
||||
)
|
||||
dbPattern := regexp.MustCompile(dbPatternStr)
|
||||
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)
|
||||
|
||||
@@ -694,6 +694,115 @@ func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_TestConnection_DatabaseWithUnderscoresAndAllPrivileges_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MariadbVersion
|
||||
port string
|
||||
}{
|
||||
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
|
||||
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
|
||||
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
|
||||
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
|
||||
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
|
||||
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
|
||||
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
|
||||
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
|
||||
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
|
||||
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
|
||||
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMariadbContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
underscoreDbName := "test_all_db"
|
||||
|
||||
_, err := container.DB.Exec(
|
||||
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(
|
||||
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
|
||||
)
|
||||
}()
|
||||
|
||||
underscoreDSN := fmt.Sprintf(
|
||||
"%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
container.Username,
|
||||
container.Password,
|
||||
container.Host,
|
||||
container.Port,
|
||||
underscoreDbName,
|
||||
)
|
||||
underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
|
||||
assert.NoError(t, err)
|
||||
defer underscoreDB.Close()
|
||||
|
||||
_, err = underscoreDB.Exec(`
|
||||
CREATE TABLE all_priv_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(`INSERT INTO all_priv_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
allPrivUsername := fmt.Sprintf("allpriv%s", uuid.New().String()[:8])
|
||||
allPrivPassword := "allprivpass123"
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
allPrivUsername,
|
||||
allPrivPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"GRANT ALL PRIVILEGES ON `%s`.* TO '%s'@'%%'",
|
||||
underscoreDbName,
|
||||
allPrivUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer dropUserSafe(underscoreDB, allPrivUsername)
|
||||
|
||||
mariadbModel := &MariadbDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: allPrivUsername,
|
||||
Password: allPrivPassword,
|
||||
Database: &underscoreDbName,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mariadbModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, mariadbModel.Privileges)
|
||||
assert.Contains(t, mariadbModel.Privileges, "SELECT")
|
||||
assert.Contains(t, mariadbModel.Privileges, "SHOW VIEW")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type MariadbContainer struct {
|
||||
Host string
|
||||
Port int
|
||||
@@ -714,7 +823,7 @@ func connectToMariadbContainer(
|
||||
}
|
||||
|
||||
dbName := "testdb"
|
||||
host := "127.0.0.1"
|
||||
host := config.GetEnv().TestLocalhost
|
||||
username := "root"
|
||||
password := "rootpassword"
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -397,7 +398,7 @@ func connectToMongodbContainer(
|
||||
}
|
||||
|
||||
dbName := "testdb"
|
||||
host := "127.0.0.1"
|
||||
host := config.GetEnv().TestLocalhost
|
||||
username := "root"
|
||||
password := "rootpassword"
|
||||
authDatabase := "admin"
|
||||
@@ -406,11 +407,18 @@ func connectToMongodbContainer(
|
||||
assert.NoError(t, err)
|
||||
|
||||
uri := fmt.Sprintf(
|
||||
"mongodb://%s:%s@%s:%d/%s?authSource=%s",
|
||||
username, password, host, portInt, dbName, authDatabase,
|
||||
"mongodb://%s:%s@%s:%d/%s?authSource=%s&serverSelectionTimeoutMS=5000&connectTimeoutMS=5000",
|
||||
username,
|
||||
password,
|
||||
host,
|
||||
portInt,
|
||||
dbName,
|
||||
authDatabase,
|
||||
)
|
||||
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
clientOptions := options.Client().ApplyURI(uri)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
if err != nil {
|
||||
|
||||
@@ -400,6 +400,7 @@ func HasPrivilege(privileges, priv string) bool {
|
||||
|
||||
func (m *MysqlDatabase) buildDSN(password string, database string) string {
|
||||
tlsConfig := "false"
|
||||
allowCleartext := ""
|
||||
|
||||
if m.IsHttps {
|
||||
err := mysql.RegisterTLSConfig("mysql-skip-verify", &tls.Config{
|
||||
@@ -411,16 +412,18 @@ func (m *MysqlDatabase) buildDSN(password string, database string) string {
|
||||
}
|
||||
|
||||
tlsConfig = "mysql-skip-verify"
|
||||
allowCleartext = "&allowCleartextPasswords=1"
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"%s:%s@tcp(%s:%d)/%s?parseTime=true&timeout=15s&tls=%s&charset=utf8mb4",
|
||||
"%s:%s@tcp(%s:%d)/%s?parseTime=true&timeout=15s&tls=%s&charset=utf8mb4%s",
|
||||
m.Username,
|
||||
password,
|
||||
m.Host,
|
||||
m.Port,
|
||||
database,
|
||||
tlsConfig,
|
||||
allowCleartext,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -486,9 +489,13 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
|
||||
hasProcess := false
|
||||
hasAllPrivileges := false
|
||||
|
||||
// Escape underscores to match MySQL's grant output format
|
||||
// MySQL escapes _ as \_ in SHOW GRANTS output
|
||||
// Pattern matches either literal _ or escaped \_
|
||||
escapedDbName := strings.ReplaceAll(regexp.QuoteMeta(database), "_", `(_|\\_)`)
|
||||
dbPatternStr := fmt.Sprintf(
|
||||
`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
|
||||
regexp.QuoteMeta(database),
|
||||
escapedDbName,
|
||||
)
|
||||
dbPattern := regexp.MustCompile(dbPatternStr)
|
||||
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)
|
||||
|
||||
@@ -674,6 +674,112 @@ func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_TestConnection_DatabaseWithUnderscoresAndAllPrivileges_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MysqlVersion
|
||||
port string
|
||||
}{
|
||||
{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
|
||||
{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
|
||||
{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
|
||||
{"MySQL 9", tools.MysqlVersion9, env.TestMysql90Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMysqlContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
underscoreDbName := "test_all_db"
|
||||
|
||||
_, err := container.DB.Exec(
|
||||
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(
|
||||
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
|
||||
)
|
||||
}()
|
||||
|
||||
underscoreDSN := fmt.Sprintf(
|
||||
"%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
container.Username,
|
||||
container.Password,
|
||||
container.Host,
|
||||
container.Port,
|
||||
underscoreDbName,
|
||||
)
|
||||
underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
|
||||
assert.NoError(t, err)
|
||||
defer underscoreDB.Close()
|
||||
|
||||
_, err = underscoreDB.Exec(`
|
||||
CREATE TABLE all_priv_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(`INSERT INTO all_priv_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
allPrivUsername := fmt.Sprintf("allpriv_%s", uuid.New().String()[:8])
|
||||
allPrivPassword := "allprivpass123"
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
allPrivUsername,
|
||||
allPrivPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"GRANT ALL PRIVILEGES ON `%s`.* TO '%s'@'%%'",
|
||||
underscoreDbName,
|
||||
allPrivUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(
|
||||
fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", allPrivUsername),
|
||||
)
|
||||
}()
|
||||
|
||||
mysqlModel := &MysqlDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: allPrivUsername,
|
||||
Password: allPrivPassword,
|
||||
Database: &underscoreDbName,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mysqlModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, mysqlModel.Privileges)
|
||||
assert.Contains(t, mysqlModel.Privileges, "SELECT")
|
||||
assert.Contains(t, mysqlModel.Privileges, "SHOW VIEW")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type MysqlContainer struct {
|
||||
Host string
|
||||
Port int
|
||||
@@ -694,7 +800,7 @@ func connectToMysqlContainer(
|
||||
}
|
||||
|
||||
dbName := "testdb"
|
||||
host := "127.0.0.1"
|
||||
host := config.GetEnv().TestLocalhost
|
||||
username := "root"
|
||||
password := "rootpassword"
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -85,6 +86,42 @@ func (p *PostgresqlDatabase) Validate() error {
|
||||
return errors.New("cpu count must be greater than 0")
|
||||
}
|
||||
|
||||
// Prevent Databasus from backing up itself
|
||||
// Databasus runs an internal PostgreSQL instance that should not be backed up through the UI
|
||||
// because it would expose internal metadata to non-system administrators.
|
||||
// To properly backup Databasus, see: https://databasus.com/faq#backup-databasus
|
||||
if p.Database != nil && *p.Database != "" {
|
||||
localhostHosts := []string{
|
||||
"localhost",
|
||||
"127.0.0.1",
|
||||
"172.17.0.1",
|
||||
"host.docker.internal",
|
||||
"::1", // IPv6 loopback (equivalent to 127.0.0.1)
|
||||
"::", // IPv6 all interfaces (equivalent to 0.0.0.0)
|
||||
"0.0.0.0", // IPv4 all interfaces
|
||||
}
|
||||
|
||||
isLocalhost := false
|
||||
|
||||
for _, host := range localhostHosts {
|
||||
if strings.EqualFold(p.Host, host) {
|
||||
isLocalhost = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Also check if the host is in the entire 127.0.0.0/8 loopback range
|
||||
if strings.HasPrefix(p.Host, "127.") {
|
||||
isLocalhost = true
|
||||
}
|
||||
|
||||
if isLocalhost && strings.EqualFold(*p.Database, "databasus") {
|
||||
return errors.New(
|
||||
"backing up Databasus internal database is not allowed. To backup Databasus itself, see https://databasus.com/faq#backup-databasus",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -358,10 +395,13 @@ func (p *PostgresqlDatabase) IsUserReadOnly(
|
||||
//
|
||||
// This method performs the following operations atomically in a single transaction:
|
||||
// 1. Creates a PostgreSQL user with a UUID-based password
|
||||
// 2. Grants CONNECT privilege on the database
|
||||
// 3. Grants USAGE on all non-system schemas
|
||||
// 4. Grants SELECT on all existing tables and sequences
|
||||
// 5. Sets default privileges for future tables and sequences
|
||||
// 2. Revokes CREATE privilege on public schema from PUBLIC role
|
||||
// 3. Grants CONNECT privilege on the database
|
||||
// 4. Discovers all user-created schemas
|
||||
// 5. Grants USAGE on all non-system schemas
|
||||
// 6. Grants SELECT on all existing tables and sequences
|
||||
// 7. Sets default privileges for future tables and sequences
|
||||
// 8. Verifies user creation before committing
|
||||
//
|
||||
// Security features:
|
||||
// - Username format: "databasus-{8-char-uuid}" for uniqueness
|
||||
@@ -451,33 +491,56 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
return "", "", fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
// Step 1.5: Revoke CREATE privilege from PUBLIC role on public schema
|
||||
// Step 2: Check if public schema exists and revoke CREATE privilege if it does
|
||||
// This is necessary because all PostgreSQL users inherit CREATE privilege on the
|
||||
// public schema through the PUBLIC role. This is a one-time operation that affects
|
||||
// the entire database, making it more secure by default.
|
||||
// Note: This only affects the public schema; other schemas are unaffected.
|
||||
_, err = tx.Exec(ctx, `REVOKE CREATE ON SCHEMA public FROM PUBLIC`)
|
||||
if err != nil {
|
||||
logger.Error("Failed to revoke CREATE on public from PUBLIC", "error", err)
|
||||
if !strings.Contains(err.Error(), "schema \"public\" does not exist") &&
|
||||
!strings.Contains(err.Error(), "permission denied") {
|
||||
return "", "", fmt.Errorf("failed to revoke CREATE from PUBLIC: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Now revoke from the specific user as well (belt and suspenders)
|
||||
_, err = tx.Exec(ctx, fmt.Sprintf(`REVOKE CREATE ON SCHEMA public FROM "%s"`, baseUsername))
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
"Failed to revoke CREATE on public schema from user",
|
||||
"error",
|
||||
err,
|
||||
"username",
|
||||
baseUsername,
|
||||
var publicSchemaExists bool
|
||||
err = tx.QueryRow(ctx, `
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM information_schema.schemata
|
||||
WHERE schema_name = 'public'
|
||||
)
|
||||
`).Scan(&publicSchemaExists)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to check if public schema exists: %w", err)
|
||||
}
|
||||
|
||||
// Step 2: Grant database connection privilege and revoke TEMP
|
||||
if publicSchemaExists {
|
||||
// Revoke CREATE from PUBLIC role (affects all users)
|
||||
_, err = tx.Exec(ctx, `REVOKE CREATE ON SCHEMA public FROM PUBLIC`)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "permission denied") {
|
||||
logger.Warn(
|
||||
"Failed to revoke CREATE on public from PUBLIC (permission denied)",
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
} else {
|
||||
return "", "", fmt.Errorf("failed to revoke CREATE from PUBLIC on existing public schema: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Now revoke from the specific user as well (belt and suspenders)
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(`REVOKE CREATE ON SCHEMA public FROM "%s"`, baseUsername),
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn(
|
||||
"Failed to revoke CREATE on public schema from user",
|
||||
"error",
|
||||
err,
|
||||
"username",
|
||||
baseUsername,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
logger.Info("Public schema does not exist, skipping CREATE privilege revocation")
|
||||
}
|
||||
|
||||
// Step 3: Grant database connection privilege and revoke TEMP
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(`GRANT CONNECT ON DATABASE "%s" TO "%s"`, *p.Database, baseUsername),
|
||||
@@ -501,7 +564,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
logger.Warn("Failed to revoke TEMP privilege", "error", err, "username", baseUsername)
|
||||
}
|
||||
|
||||
// Step 3: Discover all user-created schemas
|
||||
// Step 4: Discover all user-created schemas
|
||||
rows, err := tx.Query(ctx, `
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
@@ -526,7 +589,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
return "", "", fmt.Errorf("error iterating schemas: %w", err)
|
||||
}
|
||||
|
||||
// Step 4: Grant USAGE on each schema and explicitly prevent CREATE
|
||||
// Step 5: Grant USAGE on each schema and explicitly prevent CREATE
|
||||
for _, schema := range schemas {
|
||||
// Revoke CREATE specifically (handles inheritance from PUBLIC role)
|
||||
_, err = tx.Exec(
|
||||
@@ -555,7 +618,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
}
|
||||
}
|
||||
|
||||
// Step 5: Grant SELECT on ALL existing tables and sequences
|
||||
// Step 6: Grant SELECT on ALL existing tables and sequences
|
||||
grantSelectSQL := fmt.Sprintf(`
|
||||
DO $$
|
||||
DECLARE
|
||||
@@ -577,7 +640,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
return "", "", fmt.Errorf("failed to grant select on tables: %w", err)
|
||||
}
|
||||
|
||||
// Step 6: Set default privileges for FUTURE tables and sequences
|
||||
// Step 7: Set default privileges for FUTURE tables and sequences
|
||||
defaultPrivilegesSQL := fmt.Sprintf(`
|
||||
DO $$
|
||||
DECLARE
|
||||
@@ -599,7 +662,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
return "", "", fmt.Errorf("failed to set default privileges: %w", err)
|
||||
}
|
||||
|
||||
// Step 7: Verify user creation before committing
|
||||
// Step 8: Verify user creation before committing
|
||||
var verifyUsername string
|
||||
err = tx.QueryRow(ctx, fmt.Sprintf(`SELECT rolname FROM pg_roles WHERE rolname = '%s'`, baseUsername)).
|
||||
Scan(&verifyUsername)
|
||||
@@ -815,7 +878,15 @@ func checkBackupPermissions(
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot check SELECT privileges: %w", err)
|
||||
// If the user doesn't have USAGE on the schema, has_table_privilege will fail
|
||||
// with "permission denied for schema". This means they definitely don't have
|
||||
// SELECT privileges, so treat this as missing permissions rather than an error.
|
||||
var pgErr *pgconn.PgError
|
||||
if errors.As(err, &pgErr) && pgErr.Code == "42501" { // insufficient_privilege
|
||||
selectableTableCount = 0
|
||||
} else {
|
||||
return fmt.Errorf("cannot check SELECT privileges: %w", err)
|
||||
}
|
||||
}
|
||||
if selectableTableCount == 0 {
|
||||
missingPrivileges = append(missingPrivileges, "SELECT on tables")
|
||||
|
||||
@@ -599,6 +599,10 @@ func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
|
||||
if config.GetEnv().IsSkipExternalResourcesTests {
|
||||
t.Skip("Skipping Supabase test: IS_SKIP_EXTERNAL_RESOURCES_TESTS is true")
|
||||
}
|
||||
|
||||
env := config.GetEnv()
|
||||
|
||||
if env.TestSupabaseHost == "" {
|
||||
@@ -705,6 +709,607 @@ func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_WithPublicSchema_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
port string
|
||||
}{
|
||||
{"PostgreSQL 12", "12", env.TestPostgres12Port},
|
||||
{"PostgreSQL 13", "13", env.TestPostgres13Port},
|
||||
{"PostgreSQL 14", "14", env.TestPostgres14Port},
|
||||
{"PostgreSQL 15", "15", env.TestPostgres15Port},
|
||||
{"PostgreSQL 16", "16", env.TestPostgres16Port},
|
||||
{"PostgreSQL 17", "17", env.TestPostgres17Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToPostgresContainer(t, tc.port)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`
|
||||
DROP TABLE IF EXISTS public_schema_test CASCADE;
|
||||
CREATE TABLE public_schema_test (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO public_schema_test (data) VALUES ('test1'), ('test2');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, username)
|
||||
assert.NotEmpty(t, password)
|
||||
assert.True(t, strings.HasPrefix(username, "databasus-"))
|
||||
|
||||
readOnlyModel := &PostgresqlDatabase{
|
||||
Version: pgModel.Version,
|
||||
Host: pgModel.Host,
|
||||
Port: pgModel.Port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: pgModel.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
|
||||
ctx,
|
||||
logger,
|
||||
nil,
|
||||
uuid.New(),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, isReadOnly, "User should be read-only")
|
||||
assert.Empty(t, privileges, "Read-only user should have no write privileges")
|
||||
|
||||
readOnlyDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host,
|
||||
container.Port,
|
||||
username,
|
||||
password,
|
||||
container.Database,
|
||||
)
|
||||
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var count int
|
||||
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM public_schema_test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
_, err = readOnlyConn.Exec(
|
||||
"INSERT INTO public_schema_test (data) VALUES ('should-fail')",
|
||||
)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("CREATE TABLE public.hack_table (id INT)")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to drop owned objects: %v", err)
|
||||
}
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_WithoutPublicSchema_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
port string
|
||||
}{
|
||||
{"PostgreSQL 12", "12", env.TestPostgres12Port},
|
||||
{"PostgreSQL 13", "13", env.TestPostgres13Port},
|
||||
{"PostgreSQL 14", "14", env.TestPostgres14Port},
|
||||
{"PostgreSQL 15", "15", env.TestPostgres15Port},
|
||||
{"PostgreSQL 16", "16", env.TestPostgres16Port},
|
||||
{"PostgreSQL 17", "17", env.TestPostgres17Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToPostgresContainer(t, tc.port)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`
|
||||
DROP SCHEMA IF EXISTS public CASCADE;
|
||||
DROP SCHEMA IF EXISTS app_schema CASCADE;
|
||||
DROP SCHEMA IF EXISTS data_schema CASCADE;
|
||||
CREATE SCHEMA app_schema;
|
||||
CREATE SCHEMA data_schema;
|
||||
CREATE TABLE app_schema.users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
username TEXT NOT NULL
|
||||
);
|
||||
CREATE TABLE data_schema.records (
|
||||
id SERIAL PRIMARY KEY,
|
||||
info TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO app_schema.users (username) VALUES ('user1'), ('user2');
|
||||
INSERT INTO data_schema.records (info) VALUES ('record1'), ('record2');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err, "CreateReadOnlyUser should succeed without public schema")
|
||||
assert.NotEmpty(t, username)
|
||||
assert.NotEmpty(t, password)
|
||||
assert.True(t, strings.HasPrefix(username, "databasus-"))
|
||||
|
||||
readOnlyModel := &PostgresqlDatabase{
|
||||
Version: pgModel.Version,
|
||||
Host: pgModel.Host,
|
||||
Port: pgModel.Port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: pgModel.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
|
||||
ctx,
|
||||
logger,
|
||||
nil,
|
||||
uuid.New(),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, isReadOnly, "User should be read-only")
|
||||
assert.Empty(t, privileges, "Read-only user should have no write privileges")
|
||||
|
||||
readOnlyDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host,
|
||||
container.Port,
|
||||
username,
|
||||
password,
|
||||
container.Database,
|
||||
)
|
||||
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var userCount int
|
||||
err = readOnlyConn.Get(&userCount, "SELECT COUNT(*) FROM app_schema.users")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, userCount)
|
||||
|
||||
var recordCount int
|
||||
err = readOnlyConn.Get(&recordCount, "SELECT COUNT(*) FROM data_schema.records")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, recordCount)
|
||||
|
||||
_, err = readOnlyConn.Exec(
|
||||
"INSERT INTO app_schema.users (username) VALUES ('should-fail')",
|
||||
)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("CREATE TABLE app_schema.hack_table (id INT)")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("CREATE TABLE data_schema.hack_table (id INT)")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to drop owned objects: %v", err)
|
||||
}
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`
|
||||
DROP SCHEMA IF EXISTS app_schema CASCADE;
|
||||
DROP SCHEMA IF EXISTS data_schema CASCADE;
|
||||
CREATE SCHEMA IF NOT EXISTS public;
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_PublicSchemaExistsButNoPermissions_ReturnsError(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
port string
|
||||
}{
|
||||
{"PostgreSQL 12", "12", env.TestPostgres12Port},
|
||||
{"PostgreSQL 13", "13", env.TestPostgres13Port},
|
||||
{"PostgreSQL 14", "14", env.TestPostgres14Port},
|
||||
{"PostgreSQL 15", "15", env.TestPostgres15Port},
|
||||
{"PostgreSQL 16", "16", env.TestPostgres16Port},
|
||||
{"PostgreSQL 17", "17", env.TestPostgres17Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToPostgresContainer(t, tc.port)
|
||||
defer container.DB.Close()
|
||||
|
||||
limitedAdminUsername := fmt.Sprintf("limited_admin_%s", uuid.New().String()[:8])
|
||||
limitedAdminPassword := "limited_password_123"
|
||||
|
||||
_, err := container.DB.Exec(`
|
||||
CREATE SCHEMA IF NOT EXISTS public;
|
||||
DROP TABLE IF EXISTS public.permission_test_table CASCADE;
|
||||
CREATE TABLE public.permission_test_table (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO public.permission_test_table (data) VALUES ('test1');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`GRANT CREATE ON SCHEMA public TO PUBLIC`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN CREATEROLE`,
|
||||
limitedAdminUsername,
|
||||
limitedAdminPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
|
||||
container.Database,
|
||||
limitedAdminUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(
|
||||
fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, limitedAdminUsername),
|
||||
)
|
||||
_, _ = container.DB.Exec(
|
||||
fmt.Sprintf(`DROP USER IF EXISTS "%s"`, limitedAdminUsername),
|
||||
)
|
||||
_, _ = container.DB.Exec(`REVOKE CREATE ON SCHEMA public FROM PUBLIC`)
|
||||
}()
|
||||
|
||||
limitedAdminDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host,
|
||||
container.Port,
|
||||
limitedAdminUsername,
|
||||
limitedAdminPassword,
|
||||
container.Database,
|
||||
)
|
||||
limitedAdminConn, err := sqlx.Connect("postgres", limitedAdminDSN)
|
||||
assert.NoError(t, err)
|
||||
defer limitedAdminConn.Close()
|
||||
|
||||
pgModel := &PostgresqlDatabase{
|
||||
Version: tools.GetPostgresqlVersionEnum(tc.version),
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: limitedAdminUsername,
|
||||
Password: limitedAdminPassword,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.Error(
|
||||
t,
|
||||
err,
|
||||
"CreateReadOnlyUser should fail when admin lacks permissions to secure public schema",
|
||||
)
|
||||
if err != nil {
|
||||
errorMsg := err.Error()
|
||||
hasExpectedError := strings.Contains(
|
||||
errorMsg,
|
||||
"failed to revoke CREATE from PUBLIC on existing public schema",
|
||||
) ||
|
||||
strings.Contains(errorMsg, "permission denied for schema public") ||
|
||||
strings.Contains(errorMsg, "failed to grant")
|
||||
assert.True(
|
||||
t,
|
||||
hasExpectedError,
|
||||
"Error should indicate permission issues with public schema, got: %s",
|
||||
errorMsg,
|
||||
)
|
||||
}
|
||||
assert.Empty(t, username)
|
||||
assert.Empty(t, password)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Validate_WhenLocalhostAndDatabasus_ReturnsError(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
username string
|
||||
database string
|
||||
}{
|
||||
{
|
||||
name: "localhost with databasus db",
|
||||
host: "localhost",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1 with databasus db",
|
||||
host: "127.0.0.1",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
{
|
||||
name: "172.17.0.1 with databasus db",
|
||||
host: "172.17.0.1",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
{
|
||||
name: "host.docker.internal with databasus db",
|
||||
host: "host.docker.internal",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
{
|
||||
name: "LOCALHOST (uppercase) with DATABASUS db",
|
||||
host: "LOCALHOST",
|
||||
username: "POSTGRES",
|
||||
database: "DATABASUS",
|
||||
},
|
||||
{
|
||||
name: "LocalHost (mixed case) with DataBasus db",
|
||||
host: "LocalHost",
|
||||
username: "anyuser",
|
||||
database: "DataBasus",
|
||||
},
|
||||
{
|
||||
name: "localhost with databasus and any username",
|
||||
host: "localhost",
|
||||
username: "myuser",
|
||||
database: "databasus",
|
||||
},
|
||||
{
|
||||
name: "::1 (IPv6 loopback) with databasus db",
|
||||
host: "::1",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
{
|
||||
name: ":: (IPv6 all interfaces) with databasus db",
|
||||
host: "::",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
{
|
||||
name: "::1 (uppercase) with DATABASUS db",
|
||||
host: "::1",
|
||||
username: "POSTGRES",
|
||||
database: "DATABASUS",
|
||||
},
|
||||
{
|
||||
name: "0.0.0.0 (all IPv4 interfaces) with databasus db",
|
||||
host: "0.0.0.0",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
{
|
||||
name: "127.0.0.2 (loopback range) with databasus db",
|
||||
host: "127.0.0.2",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
{
|
||||
name: "127.255.255.255 (end of loopback range) with databasus db",
|
||||
host: "127.255.255.255",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
pgModel := &PostgresqlDatabase{
|
||||
Host: tc.host,
|
||||
Port: 5437,
|
||||
Username: tc.username,
|
||||
Password: "somepassword",
|
||||
Database: &tc.database,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
err := pgModel.Validate()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "backing up Databasus internal database is not allowed")
|
||||
assert.Contains(t, err.Error(), "https://databasus.com/faq#backup-databasus")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Validate_WhenNotLocalhostOrNotDatabasus_ValidatesSuccessfully(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
username string
|
||||
database string
|
||||
}{
|
||||
{
|
||||
name: "different host (remote server) with databasus db",
|
||||
host: "192.168.1.100",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
{
|
||||
name: "different database name on localhost",
|
||||
host: "localhost",
|
||||
username: "postgres",
|
||||
database: "myapp",
|
||||
},
|
||||
{
|
||||
name: "all different",
|
||||
host: "db.example.com",
|
||||
username: "appuser",
|
||||
database: "production",
|
||||
},
|
||||
{
|
||||
name: "localhost with postgres database",
|
||||
host: "localhost",
|
||||
username: "postgres",
|
||||
database: "postgres",
|
||||
},
|
||||
{
|
||||
name: "remote host with databasus db name (allowed)",
|
||||
host: "db.example.com",
|
||||
username: "postgres",
|
||||
database: "databasus",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
pgModel := &PostgresqlDatabase{
|
||||
Host: tc.host,
|
||||
Port: 5432,
|
||||
Username: tc.username,
|
||||
Password: "somepassword",
|
||||
Database: &tc.database,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
err := pgModel.Validate()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Validate_WhenDatabaseIsNil_ValidatesSuccessfully(t *testing.T) {
|
||||
pgModel := &PostgresqlDatabase{
|
||||
Host: "localhost",
|
||||
Port: 5437,
|
||||
Username: "postgres",
|
||||
Password: "somepassword",
|
||||
Database: nil,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
err := pgModel.Validate()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenDatabaseIsEmpty_ValidatesSuccessfully(t *testing.T) {
|
||||
emptyDb := ""
|
||||
pgModel := &PostgresqlDatabase{
|
||||
Host: "localhost",
|
||||
Port: 5437,
|
||||
Username: "postgres",
|
||||
Password: "somepassword",
|
||||
Database: &emptyDb,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
err := pgModel.Validate()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_WhenRequiredFieldsMissing_ReturnsError(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
model *PostgresqlDatabase
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "missing host",
|
||||
model: &PostgresqlDatabase{
|
||||
Host: "",
|
||||
Port: 5432,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
CpuCount: 1,
|
||||
},
|
||||
expectedError: "host is required",
|
||||
},
|
||||
{
|
||||
name: "missing port",
|
||||
model: &PostgresqlDatabase{
|
||||
Host: "localhost",
|
||||
Port: 0,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
CpuCount: 1,
|
||||
},
|
||||
expectedError: "port is required",
|
||||
},
|
||||
{
|
||||
name: "missing username",
|
||||
model: &PostgresqlDatabase{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "",
|
||||
Password: "pass",
|
||||
CpuCount: 1,
|
||||
},
|
||||
expectedError: "username is required",
|
||||
},
|
||||
{
|
||||
name: "missing password",
|
||||
model: &PostgresqlDatabase{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "user",
|
||||
Password: "",
|
||||
CpuCount: 1,
|
||||
},
|
||||
expectedError: "password is required",
|
||||
},
|
||||
{
|
||||
name: "invalid cpu count",
|
||||
model: &PostgresqlDatabase{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
CpuCount: 0,
|
||||
},
|
||||
expectedError: "cpu count must be greater than 0",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.model.Validate()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tc.expectedError)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type PostgresContainer struct {
|
||||
Host string
|
||||
Port int
|
||||
@@ -718,7 +1323,7 @@ func connectToPostgresContainer(t *testing.T, port string) *PostgresContainer {
|
||||
dbName := "testdb"
|
||||
password := "testpassword"
|
||||
username := "testuser"
|
||||
host := "localhost"
|
||||
host := config.GetEnv().TestLocalhost
|
||||
|
||||
portInt, err := strconv.Atoi(port)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package databases
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
@@ -37,7 +40,22 @@ func GetDatabaseController() *DatabaseController {
|
||||
return databaseController
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService)
|
||||
notifiers.GetNotifierService().SetNotifierDatabaseCounter(databaseService)
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService)
|
||||
notifiers.GetNotifierService().SetNotifierDatabaseCounter(databaseService)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ func GetTestPostgresConfig() *postgresql.PostgresqlDatabase {
|
||||
testDbName := "testdb"
|
||||
return &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
@@ -48,7 +48,7 @@ func GetTestMariadbConfig() *mariadb.MariadbDatabase {
|
||||
testDbName := "testdb"
|
||||
return &mariadb.MariadbDatabase{
|
||||
Version: tools.MariadbVersion1011,
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
@@ -69,7 +69,7 @@ func GetTestMongodbConfig() *mongodb.MongodbDatabase {
|
||||
|
||||
return &mongodb.MongodbDatabase{
|
||||
Version: tools.MongodbVersion7,
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "root",
|
||||
Password: "rootpassword",
|
||||
|
||||
@@ -2,30 +2,47 @@ package healthcheck_attempt
|
||||
|
||||
import (
|
||||
"context"
|
||||
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
|
||||
)
|
||||
|
||||
type HealthcheckAttemptBackgroundService struct {
|
||||
healthcheckConfigService *healthcheck_config.HealthcheckConfigService
|
||||
checkDatabaseHealthUseCase *CheckDatabaseHealthUseCase
|
||||
logger *slog.Logger
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (s *HealthcheckAttemptBackgroundService) Run(ctx context.Context) {
|
||||
// first healthcheck immediately
|
||||
s.checkDatabases()
|
||||
wasAlreadyRun := s.hasRun.Load()
|
||||
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.checkDatabases()
|
||||
s.runOnce.Do(func() {
|
||||
s.hasRun.Store(true)
|
||||
|
||||
// first healthcheck immediately
|
||||
s.checkDatabases()
|
||||
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.checkDatabases()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package healthcheck_attempt
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"databasus-backend/internal/features/databases"
|
||||
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
@@ -22,9 +25,11 @@ var checkDatabaseHealthUseCase = &CheckDatabaseHealthUseCase{
|
||||
}
|
||||
|
||||
var healthcheckAttemptBackgroundService = &HealthcheckAttemptBackgroundService{
|
||||
healthcheck_config.GetHealthcheckConfigService(),
|
||||
checkDatabaseHealthUseCase,
|
||||
logger.GetLogger(),
|
||||
healthcheckConfigService: healthcheck_config.GetHealthcheckConfigService(),
|
||||
checkDatabaseHealthUseCase: checkDatabaseHealthUseCase,
|
||||
logger: logger.GetLogger(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
var healthcheckAttemptController = &HealthcheckAttemptController{
|
||||
healthcheckAttemptService,
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package healthcheck_config
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/databases"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
@@ -27,8 +30,23 @@ func GetHealthcheckConfigController() *HealthcheckConfigController {
|
||||
return healthcheckConfigController
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
databases.
|
||||
GetDatabaseService().
|
||||
AddDbCreationListener(healthcheckConfigService)
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
databases.
|
||||
GetDatabaseService().
|
||||
AddDbCreationListener(healthcheckConfigService)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package notifiers
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
@@ -32,6 +35,22 @@ func GetNotifierService() *NotifierService {
|
||||
func GetNotifierRepository() *NotifierRepository {
|
||||
return notifierRepository
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService)
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"mime"
|
||||
"net"
|
||||
"net/smtp"
|
||||
"time"
|
||||
@@ -115,16 +116,34 @@ func (e *EmailNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeRFC2047 encodes a string using RFC 2047 MIME encoding for email headers
|
||||
// This ensures compatibility with SMTP servers that don't support SMTPUTF8
|
||||
func encodeRFC2047(s string) string {
|
||||
// mime.QEncoding handles UTF-8 → =?UTF-8?Q?...?= encoding
|
||||
// This allows non-ASCII characters (emojis, accents, etc.) in email headers
|
||||
// while maintaining compatibility with all SMTP servers
|
||||
return mime.QEncoding.Encode("UTF-8", s)
|
||||
}
|
||||
|
||||
func (e *EmailNotifier) buildEmailContent(heading, message, from string) []byte {
|
||||
subject := fmt.Sprintf("Subject: %s\r\n", heading)
|
||||
mime := fmt.Sprintf(
|
||||
// Encode Subject header using RFC 2047 to avoid SMTPUTF8 requirement
|
||||
// This ensures compatibility with SMTP servers that don't support SMTPUTF8
|
||||
encodedSubject := encodeRFC2047(heading)
|
||||
subject := fmt.Sprintf("Subject: %s\r\n", encodedSubject)
|
||||
|
||||
mimeHeaders := fmt.Sprintf(
|
||||
"MIME-version: 1.0;\nContent-Type: %s; charset=\"%s\";\n\n",
|
||||
MIMETypeHTML,
|
||||
MIMECharsetUTF8,
|
||||
)
|
||||
fromHeader := fmt.Sprintf("From: %s\r\n", from)
|
||||
|
||||
// Encode From header display name if it contains non-ASCII
|
||||
encodedFrom := encodeRFC2047(from)
|
||||
fromHeader := fmt.Sprintf("From: %s\r\n", encodedFrom)
|
||||
|
||||
toHeader := fmt.Sprintf("To: %s\r\n", e.TargetEmail)
|
||||
return []byte(fromHeader + toHeader + subject + mime + message)
|
||||
|
||||
return []byte(fromHeader + toHeader + subject + mimeHeaders + message)
|
||||
}
|
||||
|
||||
func (e *EmailNotifier) sendImplicitTLS(
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
package restores
|
||||
|
||||
import (
|
||||
"context"
|
||||
"databasus-backend/internal/features/restores/enums"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
type RestoreBackgroundService struct {
|
||||
restoreRepository *RestoreRepository
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *RestoreBackgroundService) Run(ctx context.Context) {
|
||||
if err := s.failRestoresInProgress(); err != nil {
|
||||
s.logger.Error("Failed to fail restores in progress", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RestoreBackgroundService) failRestoresInProgress() error {
|
||||
restoresInProgress, err := s.restoreRepository.FindByStatus(enums.RestoreStatusInProgress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, restore := range restoresInProgress {
|
||||
failMessage := "Restore failed due to application restart"
|
||||
restore.Status = enums.RestoreStatusFailed
|
||||
restore.FailMessage = &failMessage
|
||||
|
||||
if err := s.restoreRepository.Save(restore); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package restores
|
||||
|
||||
import (
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
"net/http"
|
||||
|
||||
@@ -15,6 +16,7 @@ type RestoreController struct {
|
||||
func (c *RestoreController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
router.GET("/restores/:backupId", c.GetRestores)
|
||||
router.POST("/restores/:backupId/restore", c.RestoreBackup)
|
||||
router.POST("/restores/cancel/:restoreId", c.CancelRestore)
|
||||
}
|
||||
|
||||
// GetRestores
|
||||
@@ -23,7 +25,7 @@ func (c *RestoreController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// @Tags restores
|
||||
// @Produce json
|
||||
// @Param backupId path string true "Backup ID"
|
||||
// @Success 200 {array} models.Restore
|
||||
// @Success 200 {array} restores_core.Restore
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Router /restores/{backupId} [get]
|
||||
@@ -71,7 +73,7 @@ func (c *RestoreController) RestoreBackup(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var requestDTO RestoreBackupRequest
|
||||
var requestDTO restores_core.RestoreBackupRequest
|
||||
if err := ctx.ShouldBindJSON(&requestDTO); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -84,3 +86,33 @@ func (c *RestoreController) RestoreBackup(ctx *gin.Context) {
|
||||
|
||||
ctx.JSON(http.StatusOK, gin.H{"message": "restore started successfully"})
|
||||
}
|
||||
|
||||
// CancelRestore
|
||||
// @Summary Cancel an in-progress restore
|
||||
// @Description Cancel a restore that is currently in progress
|
||||
// @Tags restores
|
||||
// @Param restoreId path string true "Restore ID"
|
||||
// @Success 204
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Router /restores/cancel/{restoreId} [post]
|
||||
func (c *RestoreController) CancelRestore(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
restoreID, err := uuid.Parse(ctx.Param("restoreId"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid restore ID"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.restoreService.CancelRestore(user, restoreID); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
env_config "databasus-backend/internal/config"
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
@@ -24,16 +24,19 @@ import (
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/databases/databases/mysql"
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
"databasus-backend/internal/features/restores/models"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/restores/restoring"
|
||||
"databasus-backend/internal/features/storages"
|
||||
local_storage "databasus-backend/internal/features/storages/models/local"
|
||||
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
users_dto "databasus-backend/internal/features/users/dto"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_models "databasus-backend/internal/features/workspaces/models"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
test_utils "databasus-backend/internal/util/testing"
|
||||
"databasus-backend/internal/util/tools"
|
||||
@@ -46,7 +49,7 @@ func Test_GetRestores_WhenUserIsWorkspaceMember_RestoresReturned(t *testing.T) {
|
||||
|
||||
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
|
||||
|
||||
var restores []*models.Restore
|
||||
var restores []*restores_core.Restore
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
@@ -90,7 +93,7 @@ func Test_GetRestores_WhenUserIsGlobalAdmin_RestoresReturned(t *testing.T) {
|
||||
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
|
||||
var restores []*models.Restore
|
||||
var restores []*restores_core.Restore
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
@@ -105,15 +108,19 @@ func Test_GetRestores_WhenUserIsGlobalAdmin_RestoresReturned(t *testing.T) {
|
||||
|
||||
func Test_RestoreBackup_WhenUserIsWorkspaceMember_RestoreInitiated(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
|
||||
_, cleanup := SetupMockRestoreNode(t)
|
||||
defer cleanup()
|
||||
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
|
||||
|
||||
request := RestoreBackupRequest{
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Host: env_config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
@@ -141,10 +148,10 @@ func Test_RestoreBackup_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing
|
||||
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
|
||||
request := RestoreBackupRequest{
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Host: env_config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
@@ -165,15 +172,19 @@ func Test_RestoreBackup_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing
|
||||
|
||||
func Test_RestoreBackup_WithIsExcludeExtensions_FlagPassedCorrectly(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
|
||||
_, cleanup := SetupMockRestoreNode(t)
|
||||
defer cleanup()
|
||||
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
|
||||
|
||||
request := RestoreBackupRequest{
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Host: env_config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
@@ -195,15 +206,19 @@ func Test_RestoreBackup_WithIsExcludeExtensions_FlagPassedCorrectly(t *testing.T
|
||||
|
||||
func Test_RestoreBackup_AuditLogWritten(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
|
||||
_, cleanup := SetupMockRestoreNode(t)
|
||||
defer cleanup()
|
||||
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
|
||||
|
||||
request := RestoreBackupRequest{
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Host: env_config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
@@ -272,18 +287,25 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
|
||||
// Setup mock node for tests that skip disk validation and reach scheduler
|
||||
if !tc.expectDiskValidated {
|
||||
_, cleanup := SetupMockRestoreNode(t)
|
||||
defer cleanup()
|
||||
}
|
||||
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
var backup *backups_core.Backup
|
||||
var request RestoreBackupRequest
|
||||
var request restores_core.RestoreBackupRequest
|
||||
|
||||
if tc.dbType == databases.DatabaseTypePostgres {
|
||||
_, backup = createTestDatabaseWithBackupForRestore(workspace, owner, router)
|
||||
request = RestoreBackupRequest{
|
||||
request = restores_core.RestoreBackupRequest{
|
||||
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Host: env_config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
@@ -310,10 +332,10 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
backup = createTestBackup(mysqlDB, owner)
|
||||
request = RestoreBackupRequest{
|
||||
request = restores_core.RestoreBackupRequest{
|
||||
MysqlDatabase: &mysql.MysqlDatabase{
|
||||
Version: tools.MysqlVersion80,
|
||||
Host: "localhost",
|
||||
Host: env_config.GetEnv().TestLocalhost,
|
||||
Port: 3306,
|
||||
Username: "root",
|
||||
Password: "password",
|
||||
@@ -353,16 +375,187 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func createTestRouter() *gin.Engine {
|
||||
router := workspaces_testing.CreateTestRouter(
|
||||
workspaces_controllers.GetWorkspaceController(),
|
||||
workspaces_controllers.GetMembershipController(),
|
||||
databases.GetDatabaseController(),
|
||||
backups_config.GetBackupConfigController(),
|
||||
backups.GetBackupController(),
|
||||
GetRestoreController(),
|
||||
func Test_CancelRestore_InProgressRestore_SuccessfullyCancelled(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
tasks_cancellation.SetupDependencies()
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := createTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backupRepo := backups_core.BackupRepository{}
|
||||
backups, _ := backupRepo.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepo.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
restoreRepo := restores_core.RestoreRepository{}
|
||||
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
|
||||
for _, restore := range restores {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
restores, _ = restoreRepo.FindByStatus(restores_core.RestoreStatusCanceled)
|
||||
for _, restore := range restores {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
||||
|
||||
mockUsecase := &restoring.MockBlockingRestoreUsecase{
|
||||
StartedChan: make(chan bool, 1),
|
||||
}
|
||||
restorerNode := restoring.CreateTestRestorerNodeWithUsecase(mockUsecase)
|
||||
|
||||
cancelNode := restoring.StartRestorerNodeForTest(t, restorerNode)
|
||||
defer cancelNode()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
restoreRequest := restores_core.RestoreBackupRequest{
|
||||
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: env_config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
},
|
||||
}
|
||||
|
||||
var restoreResponse map[string]interface{}
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
|
||||
"Bearer "+user.Token,
|
||||
restoreRequest,
|
||||
http.StatusOK,
|
||||
&restoreResponse,
|
||||
)
|
||||
return router
|
||||
|
||||
select {
|
||||
case <-mockUsecase.StartedChan:
|
||||
t.Log("Restore started and is blocking")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Restore did not start within timeout")
|
||||
}
|
||||
|
||||
restoreRepo := &restores_core.RestoreRepository{}
|
||||
restores, err := restoreRepo.FindByBackupID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, len(restores), 0, "At least one restore should exist")
|
||||
|
||||
var restoreID uuid.UUID
|
||||
for _, r := range restores {
|
||||
if r.Status == restores_core.RestoreStatusInProgress {
|
||||
restoreID = r.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.NotEqual(t, uuid.Nil, restoreID, "Should find an in-progress restore")
|
||||
|
||||
resp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/restores/cancel/%s", restoreID.String()),
|
||||
"Bearer "+user.Token,
|
||||
nil,
|
||||
http.StatusNoContent,
|
||||
)
|
||||
|
||||
assert.Equal(t, http.StatusNoContent, resp.StatusCode)
|
||||
|
||||
deadline := time.Now().UTC().Add(3 * time.Second)
|
||||
var restore *restores_core.Restore
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
restore, err = restoreRepo.FindByID(restoreID)
|
||||
assert.NoError(t, err)
|
||||
if restore.Status == restores_core.RestoreStatusCanceled {
|
||||
break
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
assert.Equal(t, restores_core.RestoreStatusCanceled, restore.Status)
|
||||
|
||||
auditLogService := audit_logs.GetAuditLogService()
|
||||
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
|
||||
workspace.ID,
|
||||
&audit_logs.GetAuditLogsRequest{Limit: 100, Offset: 0},
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
foundCancelLog := false
|
||||
for _, log := range auditLogs.AuditLogs {
|
||||
if strings.Contains(log.Message, "Restore cancelled") &&
|
||||
strings.Contains(log.Message, database.Name) {
|
||||
foundCancelLog = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, foundCancelLog, "Cancel audit log should be created")
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_RestoreBackup_WithParallelRestoreInProgress_ReturnsError(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
|
||||
_, cleanup := SetupMockRestoreNode(t)
|
||||
defer cleanup()
|
||||
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
|
||||
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: env_config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
},
|
||||
}
|
||||
|
||||
testResp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusOK,
|
||||
)
|
||||
assert.Contains(t, string(testResp.Body), "restore started successfully")
|
||||
|
||||
testResp2 := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(testResp2.Body), "another restore is already in progress")
|
||||
}
|
||||
|
||||
func createTestRouter() *gin.Engine {
|
||||
return CreateTestRouter()
|
||||
}
|
||||
|
||||
func createTestDatabaseWithBackupForRestore(
|
||||
@@ -433,7 +626,7 @@ func createTestMySQLDatabase(
|
||||
token string,
|
||||
router *gin.Engine,
|
||||
) *databases.Database {
|
||||
env := config.GetEnv()
|
||||
env := env_config.GetEnv()
|
||||
portStr := env.TestMysql80Port
|
||||
if portStr == "" {
|
||||
portStr = "33080"
|
||||
@@ -451,7 +644,7 @@ func createTestMySQLDatabase(
|
||||
Type: databases.DatabaseTypeMysql,
|
||||
Mysql: &mysql.MysqlDatabase{
|
||||
Version: tools.MysqlVersion80,
|
||||
Host: "localhost",
|
||||
Host: env_config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package restores
|
||||
package restores_core
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/features/databases/databases/mariadb"
|
||||
@@ -1,4 +1,4 @@
|
||||
package enums
|
||||
package restores_core
|
||||
|
||||
type RestoreStatus string
|
||||
|
||||
@@ -6,4 +6,5 @@ const (
|
||||
RestoreStatusInProgress RestoreStatus = "IN_PROGRESS"
|
||||
RestoreStatusCompleted RestoreStatus = "COMPLETED"
|
||||
RestoreStatusFailed RestoreStatus = "FAILED"
|
||||
RestoreStatusCanceled RestoreStatus = "CANCELED"
|
||||
)
|
||||
23
backend/internal/features/restores/core/interfaces.go
Normal file
23
backend/internal/features/restores/core/interfaces.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package restores_core
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/storages"
|
||||
)
|
||||
|
||||
type RestoreBackupUsecase interface {
|
||||
Execute(
|
||||
ctx context.Context,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore Restore,
|
||||
originalDB *databases.Database,
|
||||
restoringToDB *databases.Database,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
isExcludeExtensions bool,
|
||||
) error
|
||||
}
|
||||
30
backend/internal/features/restores/core/model.go
Normal file
30
backend/internal/features/restores/core/model.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package restores_core
|
||||
|
||||
import (
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/databases/databases/mariadb"
|
||||
"databasus-backend/internal/features/databases/databases/mongodb"
|
||||
"databasus-backend/internal/features/databases/databases/mysql"
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type Restore struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
|
||||
Status RestoreStatus `json:"status" gorm:"column:status;type:text;not null"`
|
||||
|
||||
BackupID uuid.UUID `json:"backupId" gorm:"column:backup_id;type:uuid;not null"`
|
||||
Backup *backups_core.Backup
|
||||
|
||||
PostgresqlDatabase *postgresql.PostgresqlDatabase `json:"postgresqlDatabase" gorm:"-"`
|
||||
MysqlDatabase *mysql.MysqlDatabase `json:"mysqlDatabase" gorm:"-"`
|
||||
MariadbDatabase *mariadb.MariadbDatabase `json:"mariadbDatabase" gorm:"-"`
|
||||
MongodbDatabase *mongodb.MongodbDatabase `json:"mongodbDatabase" gorm:"-"`
|
||||
|
||||
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`
|
||||
|
||||
RestoreDurationMs int64 `json:"restoreDurationMs" gorm:"column:restore_duration_ms;default:0"`
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at;default:now()"`
|
||||
}
|
||||
91
backend/internal/features/restores/core/repository.go
Normal file
91
backend/internal/features/restores/core/repository.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package restores_core
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/storage"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type RestoreRepository struct{}
|
||||
|
||||
func (r *RestoreRepository) Save(restore *Restore) error {
|
||||
db := storage.GetDb()
|
||||
|
||||
isNew := restore.ID == uuid.Nil
|
||||
if isNew {
|
||||
restore.ID = uuid.New()
|
||||
return db.Create(restore).
|
||||
Omit("Backup", "PostgresqlDatabase", "MysqlDatabase", "MariadbDatabase", "MongodbDatabase").
|
||||
Error
|
||||
}
|
||||
|
||||
return db.Save(restore).
|
||||
Omit("Backup", "PostgresqlDatabase", "MysqlDatabase", "MariadbDatabase", "MongodbDatabase").
|
||||
Error
|
||||
}
|
||||
|
||||
func (r *RestoreRepository) FindByBackupID(backupID uuid.UUID) ([]*Restore, error) {
|
||||
var restores []*Restore
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Backup").
|
||||
Where("backup_id = ?", backupID).
|
||||
Order("created_at DESC").
|
||||
Find(&restores).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return restores, nil
|
||||
}
|
||||
|
||||
func (r *RestoreRepository) FindByID(id uuid.UUID) (*Restore, error) {
|
||||
var restore Restore
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Backup").
|
||||
Where("id = ?", id).
|
||||
First(&restore).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &restore, nil
|
||||
}
|
||||
|
||||
func (r *RestoreRepository) FindByStatus(status RestoreStatus) ([]*Restore, error) {
|
||||
var restores []*Restore
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Backup").
|
||||
Where("status = ?", status).
|
||||
Order("created_at DESC").
|
||||
Find(&restores).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return restores, nil
|
||||
}
|
||||
|
||||
func (r *RestoreRepository) FindInProgressRestoresByDatabaseID(
|
||||
databaseID uuid.UUID,
|
||||
) ([]*Restore, error) {
|
||||
var restores []*Restore
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Backup").
|
||||
Joins("JOIN backups ON backups.id = restores.backup_id").
|
||||
Where("backups.database_id = ? AND restores.status = ?", databaseID, RestoreStatusInProgress).
|
||||
Order("restores.created_at DESC").
|
||||
Find(&restores).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return restores, nil
|
||||
}
|
||||
|
||||
func (r *RestoreRepository) DeleteByID(id uuid.UUID) error {
|
||||
return storage.GetDb().Delete(&Restore{}, "id = ?", id).Error
|
||||
}
|
||||
@@ -1,19 +1,25 @@
|
||||
package restores
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/disk"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/restores/usecases"
|
||||
"databasus-backend/internal/features/storages"
|
||||
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var restoreRepository = &RestoreRepository{}
|
||||
var restoreRepository = &restores_core.RestoreRepository{}
|
||||
var restoreService = &RestoreService{
|
||||
backups.GetBackupService(),
|
||||
restoreRepository,
|
||||
@@ -26,24 +32,32 @@ var restoreService = &RestoreService{
|
||||
audit_logs.GetAuditLogService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
disk.GetDiskService(),
|
||||
tasks_cancellation.GetTaskCancelManager(),
|
||||
}
|
||||
var restoreController = &RestoreController{
|
||||
restoreService,
|
||||
}
|
||||
|
||||
var restoreBackgroundService = &RestoreBackgroundService{
|
||||
restoreRepository,
|
||||
logger.GetLogger(),
|
||||
}
|
||||
|
||||
func GetRestoreController() *RestoreController {
|
||||
return restoreController
|
||||
}
|
||||
|
||||
func GetRestoreBackgroundService() *RestoreBackgroundService {
|
||||
return restoreBackgroundService
|
||||
}
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
backups.GetBackupService().AddBackupRemoveListener(restoreService)
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
backups.GetBackupService().AddBackupRemoveListener(restoreService)
|
||||
backuping.GetBackupCleaner().AddBackupRemoveListener(restoreService)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/restores/enums"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type Restore struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
|
||||
Status enums.RestoreStatus `json:"status" gorm:"column:status;type:text;not null"`
|
||||
|
||||
BackupID uuid.UUID `json:"backupId" gorm:"column:backup_id;type:uuid;not null"`
|
||||
Backup *backups_core.Backup
|
||||
|
||||
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`
|
||||
|
||||
RestoreDurationMs int64 `json:"restoreDurationMs" gorm:"column:restore_duration_ms;default:0"`
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at;default:now()"`
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
package restores
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/features/restores/enums"
|
||||
"databasus-backend/internal/features/restores/models"
|
||||
"databasus-backend/internal/storage"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type RestoreRepository struct{}
|
||||
|
||||
func (r *RestoreRepository) Save(restore *models.Restore) error {
|
||||
db := storage.GetDb()
|
||||
|
||||
isNew := restore.ID == uuid.Nil
|
||||
if isNew {
|
||||
restore.ID = uuid.New()
|
||||
return db.Create(restore).
|
||||
Omit("Backup").
|
||||
Error
|
||||
}
|
||||
|
||||
return db.Save(restore).
|
||||
Omit("Backup").
|
||||
Error
|
||||
}
|
||||
|
||||
func (r *RestoreRepository) FindByBackupID(backupID uuid.UUID) ([]*models.Restore, error) {
|
||||
var restores []*models.Restore
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Backup").
|
||||
Where("backup_id = ?", backupID).
|
||||
Order("created_at DESC").
|
||||
Find(&restores).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return restores, nil
|
||||
}
|
||||
|
||||
func (r *RestoreRepository) FindByID(id uuid.UUID) (*models.Restore, error) {
|
||||
var restore models.Restore
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Backup").
|
||||
Where("id = ?", id).
|
||||
First(&restore).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &restore, nil
|
||||
}
|
||||
|
||||
func (r *RestoreRepository) FindByStatus(status enums.RestoreStatus) ([]*models.Restore, error) {
|
||||
var restores []*models.Restore
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Backup").
|
||||
Where("status = ?", status).
|
||||
Order("created_at DESC").
|
||||
Find(&restores).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return restores, nil
|
||||
}
|
||||
|
||||
func (r *RestoreRepository) DeleteByID(id uuid.UUID) error {
|
||||
return storage.GetDb().Delete(&models.Restore{}, "id = ?", id).Error
|
||||
}
|
||||
85
backend/internal/features/restores/restoring/di.go
Normal file
85
backend/internal/features/restores/restoring/di.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package restoring
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/restores/usecases"
|
||||
"databasus-backend/internal/features/storages"
|
||||
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var restoreRepository = &restores_core.RestoreRepository{}
|
||||
|
||||
var restoreNodesRegistry = &RestoreNodesRegistry{
|
||||
client: cache_utils.GetValkeyClient(),
|
||||
logger: logger.GetLogger(),
|
||||
timeout: cache_utils.DefaultCacheTimeout,
|
||||
pubsubRestores: cache_utils.NewPubSubManager(),
|
||||
pubsubCompletions: cache_utils.NewPubSubManager(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
|
||||
var restoreDatabaseCache = cache_utils.NewCacheUtil[RestoreDatabaseCache](
|
||||
cache_utils.GetValkeyClient(),
|
||||
"restore_db:",
|
||||
)
|
||||
|
||||
var restoreCancelManager = tasks_cancellation.GetTaskCancelManager()
|
||||
|
||||
var restorerNode = &RestorerNode{
|
||||
nodeID: uuid.New(),
|
||||
databaseService: databases.GetDatabaseService(),
|
||||
backupService: backups.GetBackupService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
restoreRepository: restoreRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
restoreNodesRegistry: restoreNodesRegistry,
|
||||
logger: logger.GetLogger(),
|
||||
restoreBackupUsecase: usecases.GetRestoreBackupUsecase(),
|
||||
cacheUtil: restoreDatabaseCache,
|
||||
restoreCancelManager: restoreCancelManager,
|
||||
lastHeartbeat: time.Time{},
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
|
||||
var restoresScheduler = &RestoresScheduler{
|
||||
restoreRepository: restoreRepository,
|
||||
backupService: backups.GetBackupService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
restoreNodesRegistry: restoreNodesRegistry,
|
||||
lastCheckTime: time.Now().UTC(),
|
||||
logger: logger.GetLogger(),
|
||||
restoreToNodeRelations: make(map[uuid.UUID]RestoreToNodeRelation),
|
||||
restorerNode: restorerNode,
|
||||
cacheUtil: restoreDatabaseCache,
|
||||
completionSubscriptionID: uuid.Nil,
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
|
||||
func GetRestoresScheduler() *RestoresScheduler {
|
||||
return restoresScheduler
|
||||
}
|
||||
|
||||
func GetRestorerNode() *RestorerNode {
|
||||
return restorerNode
|
||||
}
|
||||
|
||||
func GetRestoreNodesRegistry() *RestoreNodesRegistry {
|
||||
return restoreNodesRegistry
|
||||
}
|
||||
45
backend/internal/features/restores/restoring/dto.go
Normal file
45
backend/internal/features/restores/restoring/dto.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package restoring
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/features/databases/databases/mariadb"
|
||||
"databasus-backend/internal/features/databases/databases/mongodb"
|
||||
"databasus-backend/internal/features/databases/databases/mysql"
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type RestoreDatabaseCache struct {
|
||||
PostgresqlDatabase *postgresql.PostgresqlDatabase `json:"postgresqlDatabase,omitempty"`
|
||||
MysqlDatabase *mysql.MysqlDatabase `json:"mysqlDatabase,omitempty"`
|
||||
MariadbDatabase *mariadb.MariadbDatabase `json:"mariadbDatabase,omitempty"`
|
||||
MongodbDatabase *mongodb.MongodbDatabase `json:"mongodbDatabase,omitempty"`
|
||||
}
|
||||
|
||||
type RestoreToNodeRelation struct {
|
||||
NodeID uuid.UUID `json:"nodeId"`
|
||||
RestoreIDs []uuid.UUID `json:"restoreIds"`
|
||||
}
|
||||
|
||||
type RestoreNode struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
ThroughputMBs int `json:"throughputMBs"`
|
||||
LastHeartbeat time.Time `json:"lastHeartbeat"`
|
||||
}
|
||||
|
||||
type RestoreNodeStats struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
ActiveRestores int `json:"activeRestores"`
|
||||
}
|
||||
|
||||
type RestoreSubmitMessage struct {
|
||||
NodeID uuid.UUID `json:"nodeId"`
|
||||
RestoreID uuid.UUID `json:"restoreId"`
|
||||
IsCallNotifier bool `json:"isCallNotifier"`
|
||||
}
|
||||
|
||||
type RestoreCompletionMessage struct {
|
||||
NodeID uuid.UUID `json:"nodeId"`
|
||||
RestoreID uuid.UUID `json:"restoreId"`
|
||||
}
|
||||
88
backend/internal/features/restores/restoring/mocks.go
Normal file
88
backend/internal/features/restores/restoring/mocks.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package restoring
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/storages"
|
||||
)
|
||||
|
||||
type MockSuccessRestoreUsecase struct{}
|
||||
|
||||
func (uc *MockSuccessRestoreUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore restores_core.Restore,
|
||||
originalDB *databases.Database,
|
||||
restoringToDB *databases.Database,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
isExcludeExtensions bool,
|
||||
) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type MockFailedRestoreUsecase struct{}
|
||||
|
||||
func (uc *MockFailedRestoreUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore restores_core.Restore,
|
||||
originalDB *databases.Database,
|
||||
restoringToDB *databases.Database,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
isExcludeExtensions bool,
|
||||
) error {
|
||||
return errors.New("restore failed")
|
||||
}
|
||||
|
||||
type MockCaptureCredentialsRestoreUsecase struct {
|
||||
CalledChan chan *databases.Database
|
||||
ShouldSucceed bool
|
||||
}
|
||||
|
||||
func (uc *MockCaptureCredentialsRestoreUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore restores_core.Restore,
|
||||
originalDB *databases.Database,
|
||||
restoringToDB *databases.Database,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
isExcludeExtensions bool,
|
||||
) error {
|
||||
uc.CalledChan <- restoringToDB
|
||||
|
||||
if uc.ShouldSucceed {
|
||||
return nil
|
||||
}
|
||||
return errors.New("mock restore failed")
|
||||
}
|
||||
|
||||
type MockBlockingRestoreUsecase struct {
|
||||
StartedChan chan bool
|
||||
}
|
||||
|
||||
func (uc *MockBlockingRestoreUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore restores_core.Restore,
|
||||
originalDB *databases.Database,
|
||||
restoringToDB *databases.Database,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
isExcludeExtensions bool,
|
||||
) error {
|
||||
if uc.StartedChan != nil {
|
||||
uc.StartedChan <- true
|
||||
}
|
||||
|
||||
<-ctx.Done()
|
||||
|
||||
return ctx.Err()
|
||||
}
|
||||
649
backend/internal/features/restores/restoring/registry.go
Normal file
649
backend/internal/features/restores/restoring/registry.go
Normal file
@@ -0,0 +1,649 @@
|
||||
package restoring
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/valkey-io/valkey-go"
|
||||
)
|
||||
|
||||
const (
|
||||
nodeInfoKeyPrefix = "restore:node:"
|
||||
nodeInfoKeySuffix = ":info"
|
||||
nodeActiveRestoresPrefix = "restore:node:"
|
||||
nodeActiveRestoresSuffix = ":active_restores"
|
||||
restoreSubmitChannel = "restore:submit"
|
||||
restoreCompletionChannel = "restore:completion"
|
||||
|
||||
deadNodeThreshold = 2 * time.Minute
|
||||
cleanupTickerInterval = 1 * time.Second
|
||||
)
|
||||
|
||||
// RestoreNodesRegistry helps to sync restores scheduler and restore nodes.
|
||||
//
|
||||
// Features:
|
||||
// - Track node availability and load level
|
||||
// - Assign from scheduler to node restores needed to be processed
|
||||
// - Notify scheduler from node about restore completion
|
||||
//
|
||||
// Important things to remember:
|
||||
// - Nodes without heartbeat for more than 2 minutes are not included
|
||||
// in available nodes list and stats
|
||||
//
|
||||
// Cleanup dead nodes performed on 2 levels:
|
||||
// - List and stats functions do not return dead nodes
|
||||
// - Periodically dead nodes are cleaned up in cache (to not
|
||||
// accumulate too many dead nodes in cache)
|
||||
type RestoreNodesRegistry struct {
|
||||
client valkey.Client
|
||||
logger *slog.Logger
|
||||
timeout time.Duration
|
||||
pubsubRestores *cache_utils.PubSubManager
|
||||
pubsubCompletions *cache_utils.PubSubManager
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) Run(ctx context.Context) {
|
||||
wasAlreadyRun := r.hasRun.Load()
|
||||
|
||||
r.runOnce.Do(func() {
|
||||
r.hasRun.Store(true)
|
||||
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(cleanupTickerInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", r))
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) GetAvailableNodes() ([]RestoreNode, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
var allKeys []string
|
||||
cursor := uint64(0)
|
||||
pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix
|
||||
|
||||
for {
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return nil, fmt.Errorf("failed to scan node keys: %w", result.Error())
|
||||
}
|
||||
|
||||
scanResult, err := result.AsScanEntry()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse scan result: %w", err)
|
||||
}
|
||||
|
||||
allKeys = append(allKeys, scanResult.Elements...)
|
||||
|
||||
cursor = scanResult.Cursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(allKeys) == 0 {
|
||||
return []RestoreNode{}, nil
|
||||
}
|
||||
|
||||
keyDataMap, err := r.pipelineGetKeys(allKeys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pipeline get node keys: %w", err)
|
||||
}
|
||||
|
||||
threshold := time.Now().UTC().Add(-deadNodeThreshold)
|
||||
var nodes []RestoreNode
|
||||
|
||||
for key, data := range keyDataMap {
|
||||
// Skip if the key doesn't exist (data is empty)
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var node RestoreNode
|
||||
if err := json.Unmarshal(data, &node); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal node data", "key", key, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip nodes with zero/uninitialized heartbeat
|
||||
if node.LastHeartbeat.IsZero() {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.LastHeartbeat.Before(threshold) {
|
||||
continue
|
||||
}
|
||||
|
||||
nodes = append(nodes, node)
|
||||
}
|
||||
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) GetRestoreNodesStats() ([]RestoreNodeStats, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
var allKeys []string
|
||||
cursor := uint64(0)
|
||||
pattern := nodeActiveRestoresPrefix + "*" + nodeActiveRestoresSuffix
|
||||
|
||||
for {
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(100).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return nil, fmt.Errorf("failed to scan active restores keys: %w", result.Error())
|
||||
}
|
||||
|
||||
scanResult, err := result.AsScanEntry()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse scan result: %w", err)
|
||||
}
|
||||
|
||||
allKeys = append(allKeys, scanResult.Elements...)
|
||||
|
||||
cursor = scanResult.Cursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(allKeys) == 0 {
|
||||
return []RestoreNodeStats{}, nil
|
||||
}
|
||||
|
||||
keyDataMap, err := r.pipelineGetKeys(allKeys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pipeline get active restores keys: %w", err)
|
||||
}
|
||||
|
||||
var nodeInfoKeys []string
|
||||
nodeIDToStatsKey := make(map[string]string)
|
||||
for key := range keyDataMap {
|
||||
nodeID := r.extractNodeIDFromKey(key, nodeActiveRestoresPrefix, nodeActiveRestoresSuffix)
|
||||
nodeIDStr := nodeID.String()
|
||||
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeIDStr, nodeInfoKeySuffix)
|
||||
nodeInfoKeys = append(nodeInfoKeys, infoKey)
|
||||
nodeIDToStatsKey[infoKey] = key
|
||||
}
|
||||
|
||||
nodeInfoMap, err := r.pipelineGetKeys(nodeInfoKeys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pipeline get node info keys: %w", err)
|
||||
}
|
||||
|
||||
threshold := time.Now().UTC().Add(-deadNodeThreshold)
|
||||
var stats []RestoreNodeStats
|
||||
for infoKey, nodeData := range nodeInfoMap {
|
||||
// Skip if the info key doesn't exist (nodeData is empty)
|
||||
if len(nodeData) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var node RestoreNode
|
||||
if err := json.Unmarshal(nodeData, &node); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal node data", "key", infoKey, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip nodes with zero/uninitialized heartbeat
|
||||
if node.LastHeartbeat.IsZero() {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.LastHeartbeat.Before(threshold) {
|
||||
continue
|
||||
}
|
||||
|
||||
statsKey := nodeIDToStatsKey[infoKey]
|
||||
tasksData := keyDataMap[statsKey]
|
||||
count, err := r.parseIntFromBytes(tasksData)
|
||||
if err != nil {
|
||||
r.logger.Warn("Failed to parse active restores count", "key", statsKey, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
stat := RestoreNodeStats{
|
||||
ID: node.ID,
|
||||
ActiveRestores: int(count),
|
||||
}
|
||||
stats = append(stats, stat)
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) IncrementRestoresInProgress(nodeID uuid.UUID) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
key := fmt.Sprintf(
|
||||
"%s%s%s",
|
||||
nodeActiveRestoresPrefix,
|
||||
nodeID.String(),
|
||||
nodeActiveRestoresSuffix,
|
||||
)
|
||||
result := r.client.Do(ctx, r.client.B().Incr().Key(key).Build())
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to increment restores in progress for node %s: %w",
|
||||
nodeID,
|
||||
result.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) DecrementRestoresInProgress(nodeID uuid.UUID) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
key := fmt.Sprintf(
|
||||
"%s%s%s",
|
||||
nodeActiveRestoresPrefix,
|
||||
nodeID.String(),
|
||||
nodeActiveRestoresSuffix,
|
||||
)
|
||||
result := r.client.Do(ctx, r.client.B().Decr().Key(key).Build())
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to decrement restores in progress for node %s: %w",
|
||||
nodeID,
|
||||
result.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
newValue, err := result.AsInt64()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse decremented value for node %s: %w", nodeID, err)
|
||||
}
|
||||
|
||||
if newValue < 0 {
|
||||
setCtx, setCancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
r.client.Do(setCtx, r.client.B().Set().Key(key).Value("0").Build())
|
||||
setCancel()
|
||||
r.logger.Warn("Active restores counter went below 0, reset to 0", "nodeID", nodeID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) HearthbeatNodeInRegistry(
|
||||
now time.Time,
|
||||
restoreNode RestoreNode,
|
||||
) error {
|
||||
if now.IsZero() {
|
||||
return fmt.Errorf("cannot register node with zero heartbeat timestamp")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
restoreNode.LastHeartbeat = now
|
||||
|
||||
data, err := json.Marshal(restoreNode)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal restore node: %w", err)
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, restoreNode.ID.String(), nodeInfoKeySuffix)
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Set().Key(key).Value(string(data)).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf("failed to register node %s: %w", restoreNode.ID, result.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) UnregisterNodeFromRegistry(restoreNode RestoreNode) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, restoreNode.ID.String(), nodeInfoKeySuffix)
|
||||
counterKey := fmt.Sprintf(
|
||||
"%s%s%s",
|
||||
nodeActiveRestoresPrefix,
|
||||
restoreNode.ID.String(),
|
||||
nodeActiveRestoresSuffix,
|
||||
)
|
||||
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Del().Key(infoKey, counterKey).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf("failed to unregister node %s: %w", restoreNode.ID, result.Error())
|
||||
}
|
||||
|
||||
r.logger.Info("Unregistered node from registry", "nodeID", restoreNode.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) AssignRestoreToNode(
|
||||
targetNodeID uuid.UUID,
|
||||
restoreID uuid.UUID,
|
||||
isCallNotifier bool,
|
||||
) error {
|
||||
ctx := context.Background()
|
||||
|
||||
message := RestoreSubmitMessage{
|
||||
NodeID: targetNodeID,
|
||||
RestoreID: restoreID,
|
||||
IsCallNotifier: isCallNotifier,
|
||||
}
|
||||
|
||||
messageJSON, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal restore submit message: %w", err)
|
||||
}
|
||||
|
||||
err = r.pubsubRestores.Publish(ctx, restoreSubmitChannel, string(messageJSON))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to publish restore submit message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) SubscribeNodeForRestoresAssignment(
|
||||
nodeID uuid.UUID,
|
||||
handler func(restoreID uuid.UUID, isCallNotifier bool),
|
||||
) error {
|
||||
ctx := context.Background()
|
||||
|
||||
wrappedHandler := func(message string) {
|
||||
var msg RestoreSubmitMessage
|
||||
if err := json.Unmarshal([]byte(message), &msg); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal restore submit message", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if msg.NodeID != nodeID {
|
||||
return
|
||||
}
|
||||
|
||||
handler(msg.RestoreID, msg.IsCallNotifier)
|
||||
}
|
||||
|
||||
err := r.pubsubRestores.Subscribe(ctx, restoreSubmitChannel, wrappedHandler)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to subscribe to restore submit channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Subscribed to restore submit channel", "nodeID", nodeID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) UnsubscribeNodeForRestoresAssignments() error {
|
||||
err := r.pubsubRestores.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unsubscribe from restore submit channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Unsubscribed from restore submit channel")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) PublishRestoreCompletion(
|
||||
nodeID uuid.UUID,
|
||||
restoreID uuid.UUID,
|
||||
) error {
|
||||
ctx := context.Background()
|
||||
|
||||
message := RestoreCompletionMessage{
|
||||
NodeID: nodeID,
|
||||
RestoreID: restoreID,
|
||||
}
|
||||
|
||||
messageJSON, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal restore completion message: %w", err)
|
||||
}
|
||||
|
||||
err = r.pubsubCompletions.Publish(ctx, restoreCompletionChannel, string(messageJSON))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to publish restore completion message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) SubscribeForRestoresCompletions(
|
||||
handler func(nodeID uuid.UUID, restoreID uuid.UUID),
|
||||
) error {
|
||||
ctx := context.Background()
|
||||
|
||||
wrappedHandler := func(message string) {
|
||||
var msg RestoreCompletionMessage
|
||||
if err := json.Unmarshal([]byte(message), &msg); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal restore completion message", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
handler(msg.NodeID, msg.RestoreID)
|
||||
}
|
||||
|
||||
err := r.pubsubCompletions.Subscribe(ctx, restoreCompletionChannel, wrappedHandler)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to subscribe to restore completion channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Subscribed to restore completion channel")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) UnsubscribeForRestoresCompletions() error {
|
||||
err := r.pubsubCompletions.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unsubscribe from restore completion channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Unsubscribed from restore completion channel")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uuid.UUID {
|
||||
nodeIDStr := strings.TrimPrefix(key, prefix)
|
||||
nodeIDStr = strings.TrimSuffix(nodeIDStr, suffix)
|
||||
|
||||
nodeID, err := uuid.Parse(nodeIDStr)
|
||||
if err != nil {
|
||||
r.logger.Warn("Failed to parse node ID from key", "key", key, "error", err)
|
||||
return uuid.Nil
|
||||
}
|
||||
|
||||
return nodeID
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, error) {
|
||||
if len(keys) == 0 {
|
||||
return make(map[string][]byte), nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
commands := make([]valkey.Completed, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
commands = append(commands, r.client.B().Get().Key(key).Build())
|
||||
}
|
||||
|
||||
results := r.client.DoMulti(ctx, commands...)
|
||||
|
||||
keyDataMap := make(map[string][]byte, len(keys))
|
||||
for i, result := range results {
|
||||
if result.Error() != nil {
|
||||
r.logger.Warn("Failed to get key in pipeline", "key", keys[i], "error", result.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
data, err := result.AsBytes()
|
||||
if err != nil {
|
||||
r.logger.Warn("Failed to parse key data in pipeline", "key", keys[i], "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
keyDataMap[keys[i]] = data
|
||||
}
|
||||
|
||||
return keyDataMap, nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
|
||||
str := string(data)
|
||||
var count int64
|
||||
_, err := fmt.Sscanf(str, "%d", &count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to parse integer from bytes: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) cleanupDeadNodes() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
var allKeys []string
|
||||
cursor := uint64(0)
|
||||
pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix
|
||||
|
||||
for {
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf("failed to scan node keys: %w", result.Error())
|
||||
}
|
||||
|
||||
scanResult, err := result.AsScanEntry()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse scan result: %w", err)
|
||||
}
|
||||
|
||||
allKeys = append(allKeys, scanResult.Elements...)
|
||||
|
||||
cursor = scanResult.Cursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(allKeys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
keyDataMap, err := r.pipelineGetKeys(allKeys)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to pipeline get node keys: %w", err)
|
||||
}
|
||||
|
||||
threshold := time.Now().UTC().Add(-deadNodeThreshold)
|
||||
var deadNodeKeys []string
|
||||
|
||||
for key, data := range keyDataMap {
|
||||
// Skip if the key doesn't exist (data is empty)
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var node RestoreNode
|
||||
if err := json.Unmarshal(data, &node); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal node data during cleanup", "key", key, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip nodes with zero/uninitialized heartbeat
|
||||
if node.LastHeartbeat.IsZero() {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.LastHeartbeat.Before(threshold) {
|
||||
nodeID := node.ID.String()
|
||||
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeID, nodeInfoKeySuffix)
|
||||
statsKey := fmt.Sprintf(
|
||||
"%s%s%s",
|
||||
nodeActiveRestoresPrefix,
|
||||
nodeID,
|
||||
nodeActiveRestoresSuffix,
|
||||
)
|
||||
|
||||
deadNodeKeys = append(deadNodeKeys, infoKey, statsKey)
|
||||
r.logger.Info(
|
||||
"Marking node for cleanup",
|
||||
"nodeID", nodeID,
|
||||
"lastHeartbeat", node.LastHeartbeat,
|
||||
"threshold", threshold,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if len(deadNodeKeys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
delCtx, delCancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer delCancel()
|
||||
|
||||
result := r.client.Do(
|
||||
delCtx,
|
||||
r.client.B().Del().Key(deadNodeKeys...).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf("failed to delete dead node keys: %w", result.Error())
|
||||
}
|
||||
|
||||
deletedCount, err := result.AsInt64()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse deleted count: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Cleaned up dead nodes", "deletedKeysCount", deletedCount)
|
||||
return nil
|
||||
}
|
||||
1136
backend/internal/features/restores/restoring/registry_test.go
Normal file
1136
backend/internal/features/restores/restoring/registry_test.go
Normal file
File diff suppressed because it is too large
Load Diff
310
backend/internal/features/restores/restoring/restorer.go
Normal file
310
backend/internal/features/restores/restoring/restorer.go
Normal file
@@ -0,0 +1,310 @@
|
||||
package restoring
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/storages"
|
||||
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
)
|
||||
|
||||
const (
|
||||
heartbeatTickerInterval = 15 * time.Second
|
||||
restorerHealthcheckThreshold = 5 * time.Minute
|
||||
)
|
||||
|
||||
type RestorerNode struct {
|
||||
nodeID uuid.UUID
|
||||
|
||||
databaseService *databases.DatabaseService
|
||||
backupService *backups.BackupService
|
||||
fieldEncryptor util_encryption.FieldEncryptor
|
||||
restoreRepository *restores_core.RestoreRepository
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
storageService *storages.StorageService
|
||||
restoreNodesRegistry *RestoreNodesRegistry
|
||||
logger *slog.Logger
|
||||
restoreBackupUsecase restores_core.RestoreBackupUsecase
|
||||
cacheUtil *cache_utils.CacheUtil[RestoreDatabaseCache]
|
||||
restoreCancelManager *tasks_cancellation.TaskCancelManager
|
||||
|
||||
lastHeartbeat time.Time
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (n *RestorerNode) Run(ctx context.Context) {
|
||||
wasAlreadyRun := n.hasRun.Load()
|
||||
|
||||
n.runOnce.Do(func() {
|
||||
n.hasRun.Store(true)
|
||||
|
||||
n.lastHeartbeat = time.Now().UTC()
|
||||
|
||||
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
|
||||
|
||||
restoreNode := RestoreNode{
|
||||
ID: n.nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
}
|
||||
|
||||
if err := n.restoreNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), restoreNode); err != nil {
|
||||
n.logger.Error("Failed to register node in registry", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
restoreHandler := func(restoreID uuid.UUID, isCallNotifier bool) {
|
||||
n.MakeRestore(restoreID)
|
||||
if err := n.restoreNodesRegistry.PublishRestoreCompletion(n.nodeID, restoreID); err != nil {
|
||||
n.logger.Error(
|
||||
"Failed to publish restore completion",
|
||||
"error",
|
||||
err,
|
||||
"restoreID",
|
||||
restoreID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
err := n.restoreNodesRegistry.SubscribeNodeForRestoresAssignment(
|
||||
n.nodeID,
|
||||
restoreHandler,
|
||||
)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to subscribe to restore assignments", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := n.restoreNodesRegistry.UnsubscribeNodeForRestoresAssignments(); err != nil {
|
||||
n.logger.Error("Failed to unsubscribe from restore assignments", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ticker := time.NewTicker(heartbeatTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
n.logger.Info("Restore node started", "nodeID", n.nodeID, "throughput", throughputMBs)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
|
||||
|
||||
if err := n.restoreNodesRegistry.UnregisterNodeFromRegistry(restoreNode); err != nil {
|
||||
n.logger.Error("Failed to unregister node from registry", "error", err)
|
||||
}
|
||||
|
||||
return
|
||||
case <-ticker.C:
|
||||
n.sendHeartbeat(&restoreNode)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", n))
|
||||
}
|
||||
}
|
||||
|
||||
func (n *RestorerNode) IsRestorerRunning() bool {
|
||||
return n.lastHeartbeat.After(time.Now().UTC().Add(-restorerHealthcheckThreshold))
|
||||
}
|
||||
|
||||
func (n *RestorerNode) MakeRestore(restoreID uuid.UUID) {
|
||||
// Get and delete cached DB credentials atomically
|
||||
dbCache := n.cacheUtil.GetAndDelete(restoreID.String())
|
||||
|
||||
if dbCache == nil {
|
||||
// Cache miss - fail immediately
|
||||
restore, err := n.restoreRepository.FindByID(restoreID)
|
||||
if err != nil {
|
||||
n.logger.Error(
|
||||
"Failed to get restore by ID after cache miss",
|
||||
"restoreId",
|
||||
restoreID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
errMsg := "Database credentials expired or missing from cache (most likely due to instance restart)"
|
||||
restore.FailMessage = &errMsg
|
||||
restore.Status = restores_core.RestoreStatusFailed
|
||||
|
||||
if err := n.restoreRepository.Save(restore); err != nil {
|
||||
n.logger.Error("Failed to save restore after cache miss", "error", err)
|
||||
}
|
||||
|
||||
n.logger.Error("Restore failed: cache miss", "restoreId", restoreID)
|
||||
return
|
||||
}
|
||||
|
||||
restore, err := n.restoreRepository.FindByID(restoreID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get restore by ID", "restoreId", restoreID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
backup, err := n.backupService.GetBackup(restore.BackupID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get backup by ID", "backupId", restore.BackupID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
databaseID := backup.DatabaseID
|
||||
|
||||
database, err := n.databaseService.GetDatabaseByID(databaseID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get database by ID", "databaseId", databaseID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
backupConfig, err := n.backupConfigService.GetBackupConfigByDbId(databaseID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if backupConfig.StorageID == nil {
|
||||
n.logger.Error("Backup config storage ID is not defined")
|
||||
return
|
||||
}
|
||||
|
||||
storage, err := n.storageService.GetStorageByID(*backupConfig.StorageID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get storage by ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now().UTC()
|
||||
|
||||
// Create cancellable context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
n.restoreCancelManager.RegisterTask(restore.ID, cancel)
|
||||
defer n.restoreCancelManager.UnregisterTask(restore.ID)
|
||||
|
||||
// Create restoring database from cached credentials
|
||||
restoringToDB := &databases.Database{
|
||||
Type: database.Type,
|
||||
Postgresql: dbCache.PostgresqlDatabase,
|
||||
Mysql: dbCache.MysqlDatabase,
|
||||
Mariadb: dbCache.MariadbDatabase,
|
||||
Mongodb: dbCache.MongodbDatabase,
|
||||
}
|
||||
|
||||
if err := restoringToDB.PopulateDbData(n.logger, n.fieldEncryptor); err != nil {
|
||||
errMsg := fmt.Sprintf("failed to auto-detect database data: %v", err)
|
||||
restore.FailMessage = &errMsg
|
||||
restore.Status = restores_core.RestoreStatusFailed
|
||||
restore.RestoreDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
if err := n.restoreRepository.Save(restore); err != nil {
|
||||
n.logger.Error("Failed to save restore", "error", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
isExcludeExtensions := false
|
||||
if dbCache.PostgresqlDatabase != nil {
|
||||
isExcludeExtensions = dbCache.PostgresqlDatabase.IsExcludeExtensions
|
||||
}
|
||||
|
||||
err = n.restoreBackupUsecase.Execute(
|
||||
ctx,
|
||||
backupConfig,
|
||||
*restore,
|
||||
database,
|
||||
restoringToDB,
|
||||
backup,
|
||||
storage,
|
||||
isExcludeExtensions,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
|
||||
// Check if restore was cancelled
|
||||
isCancelled := strings.Contains(errMsg, "restore cancelled") ||
|
||||
strings.Contains(errMsg, "context canceled") ||
|
||||
errors.Is(err, context.Canceled)
|
||||
isShutdown := strings.Contains(errMsg, "shutdown")
|
||||
|
||||
if isCancelled && !isShutdown {
|
||||
n.logger.Warn("Restore was cancelled by user or system",
|
||||
"restoreId", restore.ID,
|
||||
"isCancelled", isCancelled,
|
||||
"isShutdown", isShutdown,
|
||||
)
|
||||
|
||||
restore.Status = restores_core.RestoreStatusCanceled
|
||||
restore.RestoreDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
if err := n.restoreRepository.Save(restore); err != nil {
|
||||
n.logger.Error("Failed to save cancelled restore", "error", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
n.logger.Error("Restore execution failed",
|
||||
"restoreId", restore.ID,
|
||||
"backupId", backup.ID,
|
||||
"databaseId", databaseID,
|
||||
"databaseType", database.Type,
|
||||
"storageId", storage.ID,
|
||||
"storageType", storage.Type,
|
||||
"error", err,
|
||||
"errorMessage", errMsg,
|
||||
)
|
||||
|
||||
restore.FailMessage = &errMsg
|
||||
restore.Status = restores_core.RestoreStatusFailed
|
||||
restore.RestoreDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
if err := n.restoreRepository.Save(restore); err != nil {
|
||||
n.logger.Error("Failed to save restore", "error", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
restore.Status = restores_core.RestoreStatusCompleted
|
||||
restore.RestoreDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
if err := n.restoreRepository.Save(restore); err != nil {
|
||||
n.logger.Error("Failed to save restore", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
n.logger.Info(
|
||||
"Restore completed successfully",
|
||||
"restoreId", restore.ID,
|
||||
"backupId", backup.ID,
|
||||
"durationMs", restore.RestoreDurationMs,
|
||||
)
|
||||
}
|
||||
|
||||
func (n *RestorerNode) sendHeartbeat(restoreNode *RestoreNode) {
|
||||
n.lastHeartbeat = time.Now().UTC()
|
||||
if err := n.restoreNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *restoreNode); err != nil {
|
||||
n.logger.Error("Failed to send heartbeat", "error", err)
|
||||
}
|
||||
}
|
||||
164
backend/internal/features/restores/restoring/restorer_test.go
Normal file
164
backend/internal/features/restores/restoring/restorer_test.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package restoring
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
)
|
||||
|
||||
func Test_MakeRestore_WhenCacheMissed_RestoreFails(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
|
||||
defer func() {
|
||||
backupRepo := backups_core.BackupRepository{}
|
||||
backupsList, _ := backupRepo.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backupsList {
|
||||
backupRepo.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
restoreRepo := restores_core.RestoreRepository{}
|
||||
restoresInProgress, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
|
||||
for _, restore := range restoresInProgress {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
restoresFailed, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusFailed)
|
||||
for _, restore := range restoresFailed {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
||||
|
||||
// Create restore but DON'T cache DB credentials
|
||||
// Also don't set embedded DB fields to avoid schema issues
|
||||
restore := &restores_core.Restore{
|
||||
BackupID: backup.ID,
|
||||
Status: restores_core.RestoreStatusInProgress,
|
||||
}
|
||||
err := restoreRepository.Save(restore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create restorer and execute restore (should fail due to cache miss)
|
||||
restorerNode := CreateTestRestorerNode()
|
||||
restorerNode.MakeRestore(restore.ID)
|
||||
|
||||
// Verify restore failed with appropriate error message
|
||||
updatedRestore, err := restoreRepository.FindByID(restore.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, restores_core.RestoreStatusFailed, updatedRestore.Status)
|
||||
assert.NotNil(t, updatedRestore.FailMessage)
|
||||
assert.Contains(
|
||||
t,
|
||||
*updatedRestore.FailMessage,
|
||||
"Database credentials expired or missing from cache",
|
||||
)
|
||||
}
|
||||
|
||||
func Test_MakeRestore_WhenTaskStarts_CacheDeletedImmediately(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
|
||||
defer func() {
|
||||
backupRepo := backups_core.BackupRepository{}
|
||||
backupsList, _ := backupRepo.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backupsList {
|
||||
backupRepo.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
restoreRepo := restores_core.RestoreRepository{}
|
||||
restoresInProgress, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
|
||||
for _, restore := range restoresInProgress {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
restoresFailed, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusFailed)
|
||||
for _, restore := range restoresFailed {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
restoresCompleted, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusCompleted)
|
||||
for _, restore := range restoresCompleted {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
||||
|
||||
// Create restore with cached DB credentials
|
||||
// Don't set embedded DB fields in the restore model itself
|
||||
restore := &restores_core.Restore{
|
||||
BackupID: backup.ID,
|
||||
Status: restores_core.RestoreStatusInProgress,
|
||||
}
|
||||
err := restoreRepository.Save(restore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Cache DB credentials separately
|
||||
dbCache := &RestoreDatabaseCache{
|
||||
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "test",
|
||||
Password: "test",
|
||||
Database: stringPtr("testdb"),
|
||||
Version: "16",
|
||||
},
|
||||
}
|
||||
restoreDatabaseCache.SetWithExpiration(restore.ID.String(), dbCache, 1*time.Hour)
|
||||
|
||||
// Verify cache exists before restore starts
|
||||
cachedDB := restoreDatabaseCache.Get(restore.ID.String())
|
||||
assert.NotNil(t, cachedDB, "Cache should exist before restore starts")
|
||||
|
||||
// Start restore (this will call GetAndDelete)
|
||||
restorerNode := CreateTestRestorerNode()
|
||||
restorerNode.MakeRestore(restore.ID)
|
||||
|
||||
// Verify cache was deleted immediately
|
||||
cachedDBAfter := restoreDatabaseCache.Get(restore.ID.String())
|
||||
assert.Nil(t, cachedDBAfter, "Cache should be deleted immediately when task starts")
|
||||
}
|
||||
410
backend/internal/features/restores/restoring/scheduler.go
Normal file
410
backend/internal/features/restores/restoring/scheduler.go
Normal file
@@ -0,0 +1,410 @@
|
||||
package restoring
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/storages"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
)
|
||||
|
||||
const (
|
||||
schedulerStartupDelay = 1 * time.Minute
|
||||
schedulerTickerInterval = 1 * time.Minute
|
||||
schedulerHealthcheckThreshold = 5 * time.Minute
|
||||
)
|
||||
|
||||
type RestoresScheduler struct {
|
||||
restoreRepository *restores_core.RestoreRepository
|
||||
backupService *backups.BackupService
|
||||
storageService *storages.StorageService
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
restoreNodesRegistry *RestoreNodesRegistry
|
||||
lastCheckTime time.Time
|
||||
logger *slog.Logger
|
||||
restoreToNodeRelations map[uuid.UUID]RestoreToNodeRelation
|
||||
restorerNode *RestorerNode
|
||||
cacheUtil *cache_utils.CacheUtil[RestoreDatabaseCache]
|
||||
completionSubscriptionID uuid.UUID
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (s *RestoresScheduler) Run(ctx context.Context) {
|
||||
wasAlreadyRun := s.hasRun.Load()
|
||||
|
||||
s.runOnce.Do(func() {
|
||||
s.hasRun.Store(true)
|
||||
|
||||
s.lastCheckTime = time.Now().UTC()
|
||||
|
||||
if config.GetEnv().IsManyNodesMode {
|
||||
// wait other nodes to start
|
||||
time.Sleep(schedulerStartupDelay)
|
||||
}
|
||||
|
||||
if err := s.failRestoresInProgress(); err != nil {
|
||||
s.logger.Error("Failed to fail restores in progress", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err := s.restoreNodesRegistry.SubscribeForRestoresCompletions(s.onRestoreCompleted)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to subscribe to restore completions", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := s.restoreNodesRegistry.UnsubscribeForRestoresCompletions(); err != nil {
|
||||
s.logger.Error("Failed to unsubscribe from restore completions", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(schedulerTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.checkDeadNodesAndFailRestores(); err != nil {
|
||||
s.logger.Error("Failed to check dead nodes and fail restores", "error", err)
|
||||
}
|
||||
|
||||
s.lastCheckTime = time.Now().UTC()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RestoresScheduler) IsSchedulerRunning() bool {
|
||||
return s.lastCheckTime.After(time.Now().UTC().Add(-schedulerHealthcheckThreshold))
|
||||
}
|
||||
|
||||
func (s *RestoresScheduler) failRestoresInProgress() error {
|
||||
restoresInProgress, err := s.restoreRepository.FindByStatus(
|
||||
restores_core.RestoreStatusInProgress,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, restore := range restoresInProgress {
|
||||
failMessage := "Restore failed due to application restart"
|
||||
restore.FailMessage = &failMessage
|
||||
restore.Status = restores_core.RestoreStatusFailed
|
||||
|
||||
if err := s.restoreRepository.Save(restore); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RestoresScheduler) StartRestore(restoreID uuid.UUID, dbCache *RestoreDatabaseCache) error {
|
||||
// If dbCache not provided, try to fetch from DB (for backward compatibility/testing)
|
||||
if dbCache == nil {
|
||||
restore, err := s.restoreRepository.FindByID(restoreID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to find restore by ID",
|
||||
"restoreId",
|
||||
restoreID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// Create cache DTO from restore (may be nil if not in DB)
|
||||
dbCache = &RestoreDatabaseCache{
|
||||
PostgresqlDatabase: restore.PostgresqlDatabase,
|
||||
MysqlDatabase: restore.MysqlDatabase,
|
||||
MariadbDatabase: restore.MariadbDatabase,
|
||||
MongodbDatabase: restore.MongodbDatabase,
|
||||
}
|
||||
}
|
||||
|
||||
// Cache database credentials with 1-hour expiration
|
||||
s.cacheUtil.SetWithExpiration(restoreID.String(), dbCache, 1*time.Hour)
|
||||
|
||||
leastBusyNodeID, err := s.calculateLeastBusyNode()
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to calculate least busy node",
|
||||
"restoreId",
|
||||
restoreID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.restoreNodesRegistry.IncrementRestoresInProgress(*leastBusyNodeID); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to increment restores in progress",
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
"restoreId",
|
||||
restoreID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.restoreNodesRegistry.AssignRestoreToNode(*leastBusyNodeID, restoreID, false); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to submit restore",
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
"restoreId",
|
||||
restoreID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
if decrementErr := s.restoreNodesRegistry.DecrementRestoresInProgress(*leastBusyNodeID); decrementErr != nil {
|
||||
s.logger.Error(
|
||||
"Failed to decrement restores in progress after submit failure",
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
"error",
|
||||
decrementErr,
|
||||
)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if relation, exists := s.restoreToNodeRelations[*leastBusyNodeID]; exists {
|
||||
relation.RestoreIDs = append(relation.RestoreIDs, restoreID)
|
||||
s.restoreToNodeRelations[*leastBusyNodeID] = relation
|
||||
} else {
|
||||
s.restoreToNodeRelations[*leastBusyNodeID] = RestoreToNodeRelation{
|
||||
NodeID: *leastBusyNodeID,
|
||||
RestoreIDs: []uuid.UUID{restoreID},
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Info(
|
||||
"Successfully triggered restore",
|
||||
"restoreId",
|
||||
restoreID,
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RestoresScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
|
||||
nodes, err := s.restoreNodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get available nodes: %w", err)
|
||||
}
|
||||
|
||||
if len(nodes) == 0 {
|
||||
return nil, fmt.Errorf("no nodes available")
|
||||
}
|
||||
|
||||
stats, err := s.restoreNodesRegistry.GetRestoreNodesStats()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get restore nodes stats: %w", err)
|
||||
}
|
||||
|
||||
statsMap := make(map[uuid.UUID]int)
|
||||
for _, stat := range stats {
|
||||
statsMap[stat.ID] = stat.ActiveRestores
|
||||
}
|
||||
|
||||
var bestNode *RestoreNode
|
||||
var bestScore float64 = -1
|
||||
|
||||
for i := range nodes {
|
||||
node := &nodes[i]
|
||||
|
||||
activeRestores := statsMap[node.ID]
|
||||
|
||||
var score float64
|
||||
if node.ThroughputMBs > 0 {
|
||||
score = float64(activeRestores) / float64(node.ThroughputMBs)
|
||||
} else {
|
||||
score = float64(activeRestores) * 1000
|
||||
}
|
||||
|
||||
if bestNode == nil || score < bestScore {
|
||||
bestNode = node
|
||||
bestScore = score
|
||||
}
|
||||
}
|
||||
|
||||
if bestNode == nil {
|
||||
return nil, fmt.Errorf("no suitable nodes available")
|
||||
}
|
||||
|
||||
return &bestNode.ID, nil
|
||||
}
|
||||
|
||||
func (s *RestoresScheduler) onRestoreCompleted(nodeID uuid.UUID, restoreID uuid.UUID) {
|
||||
// Verify this task is actually a restore (registry contains multiple task types)
|
||||
_, err := s.restoreRepository.FindByID(restoreID)
|
||||
if err != nil {
|
||||
// Not a restore task, ignore it
|
||||
return
|
||||
}
|
||||
|
||||
relation, exists := s.restoreToNodeRelations[nodeID]
|
||||
if !exists {
|
||||
s.logger.Warn(
|
||||
"Received completion for unknown node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"restoreId",
|
||||
restoreID,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
newRestoreIDs := make([]uuid.UUID, 0)
|
||||
found := false
|
||||
for _, id := range relation.RestoreIDs {
|
||||
if id == restoreID {
|
||||
found = true
|
||||
continue
|
||||
}
|
||||
newRestoreIDs = append(newRestoreIDs, id)
|
||||
}
|
||||
|
||||
if !found {
|
||||
s.logger.Warn(
|
||||
"Restore not found in node's restore list",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"restoreId",
|
||||
restoreID,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if len(newRestoreIDs) == 0 {
|
||||
delete(s.restoreToNodeRelations, nodeID)
|
||||
} else {
|
||||
relation.RestoreIDs = newRestoreIDs
|
||||
s.restoreToNodeRelations[nodeID] = relation
|
||||
}
|
||||
|
||||
if err := s.restoreNodesRegistry.DecrementRestoresInProgress(nodeID); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to decrement restores in progress",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"restoreId",
|
||||
restoreID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RestoresScheduler) checkDeadNodesAndFailRestores() error {
|
||||
nodes, err := s.restoreNodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get available nodes: %w", err)
|
||||
}
|
||||
|
||||
aliveNodeIDs := make(map[uuid.UUID]bool)
|
||||
for _, node := range nodes {
|
||||
aliveNodeIDs[node.ID] = true
|
||||
}
|
||||
|
||||
for nodeID, relation := range s.restoreToNodeRelations {
|
||||
if aliveNodeIDs[nodeID] {
|
||||
continue
|
||||
}
|
||||
|
||||
s.logger.Warn(
|
||||
"Node is dead, failing its restores",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"restoreCount",
|
||||
len(relation.RestoreIDs),
|
||||
)
|
||||
|
||||
for _, restoreID := range relation.RestoreIDs {
|
||||
restore, err := s.restoreRepository.FindByID(restoreID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to find restore for dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"restoreId",
|
||||
restoreID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
failMessage := "Restore failed due to node unavailability"
|
||||
restore.FailMessage = &failMessage
|
||||
restore.Status = restores_core.RestoreStatusFailed
|
||||
|
||||
if err := s.restoreRepository.Save(restore); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to save failed restore for dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"restoreId",
|
||||
restoreID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := s.restoreNodesRegistry.DecrementRestoresInProgress(nodeID); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to decrement restores in progress for dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"restoreId",
|
||||
restoreID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
s.logger.Info(
|
||||
"Failed restore due to dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"restoreId",
|
||||
restoreID,
|
||||
)
|
||||
}
|
||||
|
||||
delete(s.restoreToNodeRelations, nodeID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
856
backend/internal/features/restores/restoring/scheduler_test.go
Normal file
856
backend/internal/features/restores/restoring/scheduler_test.go
Normal file
@@ -0,0 +1,856 @@
|
||||
package restoring
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_CheckDeadNodesAndFailRestores_NodeDies_FailsRestoreAndCleansUpRegistry(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
var mockNodeID uuid.UUID
|
||||
|
||||
defer func() {
|
||||
backupRepo := backups_core.BackupRepository{}
|
||||
backups, _ := backupRepo.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepo.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
restoreRepo := restores_core.RestoreRepository{}
|
||||
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
|
||||
for _, restore := range restores {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
restores, _ = restoreRepo.FindByStatus(restores_core.RestoreStatusFailed)
|
||||
for _, restore := range restores {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
// Clean up mock node
|
||||
if mockNodeID != uuid.Nil {
|
||||
restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: mockNodeID})
|
||||
}
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
|
||||
// Create a test backup
|
||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
||||
|
||||
var err error
|
||||
// Register mock node without subscribing to restores (simulates node crash after registration)
|
||||
mockNodeID = uuid.New()
|
||||
err = CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create restore and assign to mock node
|
||||
restore := &restores_core.Restore{
|
||||
BackupID: backup.ID,
|
||||
Status: restores_core.RestoreStatusInProgress,
|
||||
}
|
||||
err = restoreRepository.Save(restore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Scheduler assigns restore to mock node
|
||||
err = GetRestoresScheduler().StartRestore(restore.ID, nil)
|
||||
assert.NoError(t, err)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify Valkey counter was incremented when restore was assigned
|
||||
stats, err := restoreNodesRegistry.GetRestoreNodesStats()
|
||||
assert.NoError(t, err)
|
||||
foundStat := false
|
||||
for _, stat := range stats {
|
||||
if stat.ID == mockNodeID {
|
||||
assert.Equal(t, 1, stat.ActiveRestores)
|
||||
foundStat = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, foundStat, "Node stats should be present")
|
||||
|
||||
// Simulate node death by setting heartbeat older than 2-minute threshold
|
||||
oldHeartbeat := time.Now().UTC().Add(-3 * time.Minute)
|
||||
err = UpdateNodeHeartbeatDirectly(mockNodeID, 100, oldHeartbeat)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Trigger dead node detection
|
||||
err = GetRestoresScheduler().checkDeadNodesAndFailRestores()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify restore was failed with appropriate error message
|
||||
failedRestore, err := restoreRepository.FindByID(restore.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, restores_core.RestoreStatusFailed, failedRestore.Status)
|
||||
assert.NotNil(t, failedRestore.FailMessage)
|
||||
assert.Contains(t, *failedRestore.FailMessage, "node unavailability")
|
||||
|
||||
// Verify Valkey counter was decremented after restore failed
|
||||
stats, err = restoreNodesRegistry.GetRestoreNodesStats()
|
||||
assert.NoError(t, err)
|
||||
for _, stat := range stats {
|
||||
if stat.ID == mockNodeID {
|
||||
assert.Equal(t, 0, stat.ActiveRestores)
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_OnRestoreCompleted_TaskIsNotRestore_SkipsProcessing(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
var mockNodeID uuid.UUID
|
||||
|
||||
defer func() {
|
||||
backupRepo := backups_core.BackupRepository{}
|
||||
backups, _ := backupRepo.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepo.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
restoreRepo := restores_core.RestoreRepository{}
|
||||
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
|
||||
for _, restore := range restores {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
// Clean up mock node
|
||||
if mockNodeID != uuid.Nil {
|
||||
restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: mockNodeID})
|
||||
}
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
|
||||
// Create a test backup
|
||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
||||
|
||||
// Register mock node
|
||||
mockNodeID = uuid.New()
|
||||
err := CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create restore and assign to the node
|
||||
restore := &restores_core.Restore{
|
||||
BackupID: backup.ID,
|
||||
Status: restores_core.RestoreStatusInProgress,
|
||||
}
|
||||
err = restoreRepository.Save(restore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = GetRestoresScheduler().StartRestore(restore.ID, nil)
|
||||
assert.NoError(t, err)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Get initial state of the registry
|
||||
initialStats, err := restoreNodesRegistry.GetRestoreNodesStats()
|
||||
assert.NoError(t, err)
|
||||
var initialActiveTasks int
|
||||
for _, stat := range initialStats {
|
||||
if stat.ID == mockNodeID {
|
||||
initialActiveTasks = stat.ActiveRestores
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, initialActiveTasks, "Should have 1 active task")
|
||||
|
||||
// Call onRestoreCompleted with a random UUID (not a restore ID)
|
||||
nonRestoreTaskID := uuid.New()
|
||||
GetRestoresScheduler().onRestoreCompleted(mockNodeID, nonRestoreTaskID)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify: Active tasks counter should remain the same (not decremented)
|
||||
stats, err := restoreNodesRegistry.GetRestoreNodesStats()
|
||||
assert.NoError(t, err)
|
||||
for _, stat := range stats {
|
||||
if stat.ID == mockNodeID {
|
||||
assert.Equal(t, initialActiveTasks, stat.ActiveRestores,
|
||||
"Active tasks should not change for non-restore task")
|
||||
}
|
||||
}
|
||||
|
||||
// Verify: restore should still be in progress (not modified)
|
||||
unchangedRestore, err := restoreRepository.FindByID(restore.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, restores_core.RestoreStatusInProgress, unchangedRestore.Status,
|
||||
"Restore status should not change for non-restore task completion")
|
||||
|
||||
// Verify: restoreToNodeRelations should still contain the node
|
||||
scheduler := GetRestoresScheduler()
|
||||
_, exists := scheduler.restoreToNodeRelations[mockNodeID]
|
||||
assert.True(t, exists, "Node should still be in restoreToNodeRelations")
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
|
||||
t.Run("Nodes with same throughput", func(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
node1ID := uuid.New()
|
||||
node2ID := uuid.New()
|
||||
node3ID := uuid.New()
|
||||
now := time.Now().UTC()
|
||||
|
||||
defer func() {
|
||||
// Clean up all mock nodes
|
||||
restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node1ID})
|
||||
restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node2ID})
|
||||
restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node3ID})
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
err := CreateMockNodeInRegistry(node1ID, 100, now)
|
||||
assert.NoError(t, err)
|
||||
err = CreateMockNodeInRegistry(node2ID, 100, now)
|
||||
assert.NoError(t, err)
|
||||
err = CreateMockNodeInRegistry(node3ID, 100, now)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for range 5 {
|
||||
err = restoreNodesRegistry.IncrementRestoresInProgress(node1ID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
for range 2 {
|
||||
err = restoreNodesRegistry.IncrementRestoresInProgress(node2ID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
for range 8 {
|
||||
err = restoreNodesRegistry.IncrementRestoresInProgress(node3ID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
leastBusyNodeID, err := GetRestoresScheduler().calculateLeastBusyNode()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, leastBusyNodeID)
|
||||
assert.Equal(t, node2ID, *leastBusyNodeID)
|
||||
})
|
||||
|
||||
t.Run("Nodes with different throughput", func(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
node100MBsID := uuid.New()
|
||||
node50MBsID := uuid.New()
|
||||
now := time.Now().UTC()
|
||||
|
||||
defer func() {
|
||||
// Clean up all mock nodes
|
||||
restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node100MBsID})
|
||||
restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node50MBsID})
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
err := CreateMockNodeInRegistry(node100MBsID, 100, now)
|
||||
assert.NoError(t, err)
|
||||
err = CreateMockNodeInRegistry(node50MBsID, 50, now)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for range 10 {
|
||||
err = restoreNodesRegistry.IncrementRestoresInProgress(node100MBsID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
err = restoreNodesRegistry.IncrementRestoresInProgress(node50MBsID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
leastBusyNodeID, err := GetRestoresScheduler().calculateLeastBusyNode()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, leastBusyNodeID)
|
||||
assert.Equal(t, node50MBsID, *leastBusyNodeID)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_FailRestoresInProgress_SchedulerStarts_UpdatesStatus(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backupRepo := backups_core.BackupRepository{}
|
||||
backups, _ := backupRepo.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepo.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
restoreRepo := restores_core.RestoreRepository{}
|
||||
|
||||
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
|
||||
for _, restore := range restores {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
|
||||
restores, _ = restoreRepo.FindByStatus(restores_core.RestoreStatusFailed)
|
||||
for _, restore := range restores {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
|
||||
restores, _ = restoreRepo.FindByStatus(restores_core.RestoreStatusCompleted)
|
||||
for _, restore := range restores {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
|
||||
// Create a test backup
|
||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
||||
|
||||
// Create two in-progress restores that should be failed on scheduler restart
|
||||
restore1 := &restores_core.Restore{
|
||||
BackupID: backup.ID,
|
||||
Status: restores_core.RestoreStatusInProgress,
|
||||
CreatedAt: time.Now().UTC().Add(-30 * time.Minute),
|
||||
}
|
||||
err := restoreRepository.Save(restore1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
restore2 := &restores_core.Restore{
|
||||
BackupID: backup.ID,
|
||||
Status: restores_core.RestoreStatusInProgress,
|
||||
CreatedAt: time.Now().UTC().Add(-15 * time.Minute),
|
||||
}
|
||||
err = restoreRepository.Save(restore2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create a completed restore to verify it's not affected by failRestoresInProgress
|
||||
completedRestore := &restores_core.Restore{
|
||||
BackupID: backup.ID,
|
||||
Status: restores_core.RestoreStatusCompleted,
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
}
|
||||
err = restoreRepository.Save(completedRestore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Trigger the scheduler's failRestoresInProgress logic
|
||||
// This should mark in-progress restores as failed
|
||||
err = GetRestoresScheduler().failRestoresInProgress()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify all restores exist and were processed correctly
|
||||
allRestores1, err := restoreRepository.FindByID(restore1.ID)
|
||||
assert.NoError(t, err)
|
||||
allRestores2, err := restoreRepository.FindByID(restore2.ID)
|
||||
assert.NoError(t, err)
|
||||
allRestores3, err := restoreRepository.FindByID(completedRestore.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var failedCount int
|
||||
var completedCount int
|
||||
|
||||
restoresToCheck := []*restores_core.Restore{allRestores1, allRestores2, allRestores3}
|
||||
for _, restore := range restoresToCheck {
|
||||
switch restore.Status {
|
||||
case restores_core.RestoreStatusFailed:
|
||||
failedCount++
|
||||
// Verify fail message indicates application restart
|
||||
assert.NotNil(t, restore.FailMessage)
|
||||
assert.Equal(t, "Restore failed due to application restart", *restore.FailMessage)
|
||||
case restores_core.RestoreStatusCompleted:
|
||||
completedCount++
|
||||
}
|
||||
}
|
||||
|
||||
// Verify correct number of restores in each state
|
||||
assert.Equal(t, 2, failedCount, "Should have 2 failed restores (originally in progress)")
|
||||
assert.Equal(t, 1, completedCount, "Should have 1 completed restore (unchanged)")
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_StartRestore_RestoreCompletes_DecrementsActiveTaskCount(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
// Start scheduler so it can handle task completions
|
||||
scheduler := CreateTestRestoresScheduler()
|
||||
schedulerCancel := StartSchedulerForTest(t, scheduler)
|
||||
defer schedulerCancel()
|
||||
|
||||
restorerNode := CreateTestRestorerNode()
|
||||
restorerNode.restoreBackupUsecase = &MockSuccessRestoreUsecase{}
|
||||
|
||||
cancel := StartRestorerNodeForTest(t, restorerNode)
|
||||
defer StopRestorerNodeForTest(t, cancel, restorerNode)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backupRepo := backups_core.BackupRepository{}
|
||||
backups, _ := backupRepo.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepo.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
restoreRepo := restores_core.RestoreRepository{}
|
||||
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusCompleted)
|
||||
for _, restore := range restores {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
|
||||
// Create a test backup
|
||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
||||
|
||||
// Get initial active task count
|
||||
stats, err := restoreNodesRegistry.GetRestoreNodesStats()
|
||||
assert.NoError(t, err)
|
||||
var initialActiveTasks int
|
||||
for _, stat := range stats {
|
||||
if stat.ID == restorerNode.nodeID {
|
||||
initialActiveTasks = stat.ActiveRestores
|
||||
break
|
||||
}
|
||||
}
|
||||
t.Logf("Initial active tasks: %d", initialActiveTasks)
|
||||
|
||||
// Create and start restore
|
||||
restore := &restores_core.Restore{
|
||||
BackupID: backup.ID,
|
||||
Status: restores_core.RestoreStatusInProgress,
|
||||
}
|
||||
err = restoreRepository.Save(restore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = scheduler.StartRestore(restore.ID, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Wait for restore to complete
|
||||
WaitForRestoreCompletion(t, restore.ID, 10*time.Second)
|
||||
|
||||
// Verify restore was completed
|
||||
completedRestore, err := restoreRepository.FindByID(restore.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, completedRestore.Status)
|
||||
|
||||
// Wait for active task count to decrease
|
||||
decreased := WaitForActiveTasksDecrease(
|
||||
t,
|
||||
restorerNode.nodeID,
|
||||
initialActiveTasks+1,
|
||||
10*time.Second,
|
||||
)
|
||||
assert.True(t, decreased, "Active task count should have decreased after restore completion")
|
||||
|
||||
// Verify final active task count equals initial count
|
||||
finalStats, err := restoreNodesRegistry.GetRestoreNodesStats()
|
||||
assert.NoError(t, err)
|
||||
for _, stat := range finalStats {
|
||||
if stat.ID == restorerNode.nodeID {
|
||||
t.Logf("Final active tasks: %d", stat.ActiveRestores)
|
||||
assert.Equal(t, initialActiveTasks, stat.ActiveRestores,
|
||||
"Active task count should return to initial value after restore completion")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_StartRestore_RestoreFails_DecrementsActiveTaskCount(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
// Start scheduler so it can handle task completions
|
||||
scheduler := CreateTestRestoresScheduler()
|
||||
schedulerCancel := StartSchedulerForTest(t, scheduler)
|
||||
defer schedulerCancel()
|
||||
|
||||
restorerNode := CreateTestRestorerNode()
|
||||
restorerNode.restoreBackupUsecase = &MockFailedRestoreUsecase{}
|
||||
|
||||
cancel := StartRestorerNodeForTest(t, restorerNode)
|
||||
defer StopRestorerNodeForTest(t, cancel, restorerNode)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backupRepo := backups_core.BackupRepository{}
|
||||
backups, _ := backupRepo.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepo.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
restoreRepo := restores_core.RestoreRepository{}
|
||||
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusFailed)
|
||||
for _, restore := range restores {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
|
||||
// Create a test backup
|
||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
||||
|
||||
// Get initial active task count
|
||||
stats, err := restoreNodesRegistry.GetRestoreNodesStats()
|
||||
assert.NoError(t, err)
|
||||
var initialActiveTasks int
|
||||
for _, stat := range stats {
|
||||
if stat.ID == restorerNode.nodeID {
|
||||
initialActiveTasks = stat.ActiveRestores
|
||||
break
|
||||
}
|
||||
}
|
||||
t.Logf("Initial active tasks: %d", initialActiveTasks)
|
||||
|
||||
// Create and start restore
|
||||
restore := &restores_core.Restore{
|
||||
BackupID: backup.ID,
|
||||
Status: restores_core.RestoreStatusInProgress,
|
||||
}
|
||||
err = restoreRepository.Save(restore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = scheduler.StartRestore(restore.ID, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Wait for restore to fail
|
||||
WaitForRestoreCompletion(t, restore.ID, 10*time.Second)
|
||||
|
||||
// Verify restore failed
|
||||
failedRestore, err := restoreRepository.FindByID(restore.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, restores_core.RestoreStatusFailed, failedRestore.Status)
|
||||
|
||||
// Wait for active task count to decrease
|
||||
decreased := WaitForActiveTasksDecrease(
|
||||
t,
|
||||
restorerNode.nodeID,
|
||||
initialActiveTasks+1,
|
||||
10*time.Second,
|
||||
)
|
||||
assert.True(t, decreased, "Active task count should have decreased after restore failure")
|
||||
|
||||
// Verify final active task count equals initial count
|
||||
finalStats, err := restoreNodesRegistry.GetRestoreNodesStats()
|
||||
assert.NoError(t, err)
|
||||
for _, stat := range finalStats {
|
||||
if stat.ID == restorerNode.nodeID {
|
||||
t.Logf("Final active tasks: %d", stat.ActiveRestores)
|
||||
assert.Equal(t, initialActiveTasks, stat.ActiveRestores,
|
||||
"Active task count should return to initial value after restore failure")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_StartRestore_CredentialsStoredEncryptedInCache(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
var mockNodeID uuid.UUID
|
||||
|
||||
defer func() {
|
||||
backupRepo := backups_core.BackupRepository{}
|
||||
backups, _ := backupRepo.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepo.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
restoreRepo := restores_core.RestoreRepository{}
|
||||
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
|
||||
for _, restore := range restores {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
// Clean up mock node
|
||||
if mockNodeID != uuid.Nil {
|
||||
restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: mockNodeID})
|
||||
}
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
|
||||
// Create a test backup
|
||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
||||
|
||||
// Register mock node so scheduler can assign restore to it
|
||||
mockNodeID = uuid.New()
|
||||
err := CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create restore with plaintext credentials
|
||||
plaintextPassword := "test_password_123"
|
||||
restore := &restores_core.Restore{
|
||||
BackupID: backup.ID,
|
||||
Status: restores_core.RestoreStatusInProgress,
|
||||
}
|
||||
err = restoreRepository.Save(restore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create PostgreSQL database credentials with plaintext password
|
||||
postgresDB := &postgresql.PostgresqlDatabase{
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "testuser",
|
||||
Password: plaintextPassword,
|
||||
Database: stringPtr("testdb"),
|
||||
Version: "16",
|
||||
}
|
||||
|
||||
// Encrypt password using FieldEncryptor (same as production flow)
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
err = postgresDB.EncryptSensitiveFields(database.ID, encryptor)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify password was encrypted (different from plaintext)
|
||||
assert.NotEqual(t, plaintextPassword, postgresDB.Password,
|
||||
"Password should be encrypted, not plaintext")
|
||||
|
||||
// Create cache with encrypted credentials
|
||||
dbCache := &RestoreDatabaseCache{
|
||||
PostgresqlDatabase: postgresDB,
|
||||
}
|
||||
|
||||
// Call StartRestore to cache credentials (do NOT start restore node)
|
||||
err = GetRestoresScheduler().StartRestore(restore.ID, dbCache)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Directly read from cache
|
||||
cachedData := restoreDatabaseCache.Get(restore.ID.String())
|
||||
assert.NotNil(t, cachedData, "Cache entry should exist")
|
||||
assert.NotNil(t, cachedData.PostgresqlDatabase, "PostgreSQL credentials should be cached")
|
||||
|
||||
// Verify password in cache is encrypted (not plaintext)
|
||||
assert.NotEqual(t, plaintextPassword, cachedData.PostgresqlDatabase.Password,
|
||||
"Cached password should be encrypted, not plaintext")
|
||||
assert.Equal(t, postgresDB.Password, cachedData.PostgresqlDatabase.Password,
|
||||
"Cached password should match the encrypted version")
|
||||
|
||||
// Verify other fields are present
|
||||
assert.Equal(t, config.GetEnv().TestLocalhost, cachedData.PostgresqlDatabase.Host)
|
||||
assert.Equal(t, 5432, cachedData.PostgresqlDatabase.Port)
|
||||
assert.Equal(t, "testuser", cachedData.PostgresqlDatabase.Username)
|
||||
assert.Equal(t, "testdb", *cachedData.PostgresqlDatabase.Database)
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_StartRestore_CredentialsRemovedAfterRestoreStarts(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
// Start scheduler so it can handle task assignments
|
||||
scheduler := CreateTestRestoresScheduler()
|
||||
schedulerCancel := StartSchedulerForTest(t, scheduler)
|
||||
defer schedulerCancel()
|
||||
|
||||
// Create mock restorer node with credential capture usecase
|
||||
restorerNode := CreateTestRestorerNode()
|
||||
calledChan := make(chan *databases.Database, 1)
|
||||
restorerNode.restoreBackupUsecase = &MockCaptureCredentialsRestoreUsecase{
|
||||
CalledChan: calledChan,
|
||||
ShouldSucceed: true,
|
||||
}
|
||||
|
||||
cancel := StartRestorerNodeForTest(t, restorerNode)
|
||||
defer StopRestorerNodeForTest(t, cancel, restorerNode)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backupRepo := backups_core.BackupRepository{}
|
||||
backups, _ := backupRepo.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepo.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
restoreRepo := restores_core.RestoreRepository{}
|
||||
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusCompleted)
|
||||
for _, restore := range restores {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
|
||||
// Create a test backup
|
||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
||||
|
||||
// Create restore with credentials
|
||||
plaintextPassword := "test_password_456"
|
||||
restore := &restores_core.Restore{
|
||||
BackupID: backup.ID,
|
||||
Status: restores_core.RestoreStatusInProgress,
|
||||
}
|
||||
err := restoreRepository.Save(restore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create PostgreSQL database credentials
|
||||
// Database field is nil to avoid PopulateDbData trying to connect
|
||||
postgresDB := &postgresql.PostgresqlDatabase{
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "testuser",
|
||||
Password: plaintextPassword,
|
||||
Database: nil,
|
||||
Version: "16",
|
||||
}
|
||||
|
||||
// Encrypt password (same as production flow)
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
err = postgresDB.EncryptSensitiveFields(database.ID, encryptor)
|
||||
assert.NoError(t, err)
|
||||
|
||||
encryptedPassword := postgresDB.Password
|
||||
|
||||
// Create cache with encrypted credentials
|
||||
dbCache := &RestoreDatabaseCache{
|
||||
PostgresqlDatabase: postgresDB,
|
||||
}
|
||||
|
||||
// Call StartRestore to cache credentials and trigger restore
|
||||
err = scheduler.StartRestore(restore.ID, dbCache)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Wait for mock usecase to be called (with timeout)
|
||||
var capturedDB *databases.Database
|
||||
select {
|
||||
case capturedDB = <-calledChan:
|
||||
t.Log("Mock usecase was called, credentials captured")
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Fatal("Timeout waiting for mock usecase to be called")
|
||||
}
|
||||
|
||||
// Verify cache is empty after restore starts (credentials were deleted)
|
||||
cacheAfterExecution := restoreDatabaseCache.Get(restore.ID.String())
|
||||
assert.Nil(t, cacheAfterExecution, "Cache should be empty after restore execution starts")
|
||||
|
||||
// Verify mock received valid credentials
|
||||
assert.NotNil(t, capturedDB, "Captured database should not be nil")
|
||||
assert.NotNil(t, capturedDB.Postgresql, "PostgreSQL credentials should be provided to usecase")
|
||||
assert.Equal(t, config.GetEnv().TestLocalhost, capturedDB.Postgresql.Host)
|
||||
assert.Equal(t, 5432, capturedDB.Postgresql.Port)
|
||||
assert.Equal(t, "testuser", capturedDB.Postgresql.Username)
|
||||
assert.NotEmpty(t, capturedDB.Postgresql.Password, "Password should be provided to usecase")
|
||||
|
||||
// Note: Password at this point may still be encrypted because PopulateDbData
|
||||
// is called after the mock captures it. The important thing is that credentials
|
||||
// were provided to the usecase despite cache being deleted.
|
||||
t.Logf("Encrypted password in cache: %s", encryptedPassword)
|
||||
t.Logf("Password received by usecase: %s", capturedDB.Postgresql.Password)
|
||||
|
||||
// Wait for restore to complete
|
||||
WaitForRestoreCompletion(t, restore.ID, 10*time.Second)
|
||||
|
||||
// Verify restore was completed
|
||||
completedRestore, err := restoreRepository.FindByID(restore.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, completedRestore.Status)
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
342
backend/internal/features/restores/restoring/testing.go
Normal file
342
backend/internal/features/restores/restoring/testing.go
Normal file
@@ -0,0 +1,342 @@
|
||||
package restoring
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/restores/usecases"
|
||||
"databasus-backend/internal/features/storages"
|
||||
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
func CreateTestRouter() *gin.Engine {
|
||||
router := workspaces_testing.CreateTestRouter(
|
||||
workspaces_controllers.GetWorkspaceController(),
|
||||
workspaces_controllers.GetMembershipController(),
|
||||
databases.GetDatabaseController(),
|
||||
backups_config.GetBackupConfigController(),
|
||||
)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func CreateTestRestorerNode() *RestorerNode {
|
||||
return &RestorerNode{
|
||||
nodeID: uuid.New(),
|
||||
databaseService: databases.GetDatabaseService(),
|
||||
backupService: backups.GetBackupService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
restoreRepository: restoreRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
restoreNodesRegistry: restoreNodesRegistry,
|
||||
logger: logger.GetLogger(),
|
||||
restoreBackupUsecase: usecases.GetRestoreBackupUsecase(),
|
||||
cacheUtil: restoreDatabaseCache,
|
||||
restoreCancelManager: tasks_cancellation.GetTaskCancelManager(),
|
||||
lastHeartbeat: time.Time{},
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func CreateTestRestorerNodeWithUsecase(usecase restores_core.RestoreBackupUsecase) *RestorerNode {
|
||||
return &RestorerNode{
|
||||
nodeID: uuid.New(),
|
||||
databaseService: databases.GetDatabaseService(),
|
||||
backupService: backups.GetBackupService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
restoreRepository: restoreRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
restoreNodesRegistry: restoreNodesRegistry,
|
||||
logger: logger.GetLogger(),
|
||||
restoreBackupUsecase: usecase,
|
||||
cacheUtil: restoreDatabaseCache,
|
||||
restoreCancelManager: tasks_cancellation.GetTaskCancelManager(),
|
||||
lastHeartbeat: time.Time{},
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func CreateTestRestoresScheduler() *RestoresScheduler {
|
||||
return &RestoresScheduler{
|
||||
restoreRepository,
|
||||
backups.GetBackupService(),
|
||||
storages.GetStorageService(),
|
||||
backups_config.GetBackupConfigService(),
|
||||
restoreNodesRegistry,
|
||||
time.Now().UTC(),
|
||||
logger.GetLogger(),
|
||||
make(map[uuid.UUID]RestoreToNodeRelation),
|
||||
restorerNode,
|
||||
restoreDatabaseCache,
|
||||
uuid.Nil,
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
// WaitForRestoreCompletion waits for a restore to be completed (or failed)
|
||||
func WaitForRestoreCompletion(
|
||||
t *testing.T,
|
||||
restoreID uuid.UUID,
|
||||
timeout time.Duration,
|
||||
) {
|
||||
deadline := time.Now().UTC().Add(timeout)
|
||||
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
restore, err := restoreRepository.FindByID(restoreID)
|
||||
if err != nil {
|
||||
t.Logf("WaitForRestoreCompletion: error finding restore: %v", err)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
t.Logf("WaitForRestoreCompletion: restore status: %s", restore.Status)
|
||||
|
||||
if restore.Status == restores_core.RestoreStatusCompleted ||
|
||||
restore.Status == restores_core.RestoreStatusFailed {
|
||||
t.Logf(
|
||||
"WaitForRestoreCompletion: restore finished with status %s",
|
||||
restore.Status,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Logf("WaitForRestoreCompletion: timeout waiting for restore to complete")
|
||||
}
|
||||
|
||||
// StartRestorerNodeForTest starts a RestorerNode in a goroutine for testing.
|
||||
// The node registers itself in the registry and subscribes to restore assignments.
|
||||
// Returns a context cancel function that should be deferred to stop the node.
|
||||
func StartRestorerNodeForTest(t *testing.T, restorerNode *RestorerNode) context.CancelFunc {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
restorerNode.Run(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Poll registry for node presence instead of fixed sleep
|
||||
deadline := time.Now().UTC().Add(5 * time.Second)
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
nodes, err := restoreNodesRegistry.GetAvailableNodes()
|
||||
if err == nil {
|
||||
for _, node := range nodes {
|
||||
if node.ID == restorerNode.nodeID {
|
||||
t.Logf("RestorerNode registered in registry: %s", restorerNode.nodeID)
|
||||
|
||||
return func() {
|
||||
cancel()
|
||||
select {
|
||||
case <-done:
|
||||
t.Log("RestorerNode stopped gracefully")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Log("RestorerNode stop timeout")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Fatalf("RestorerNode failed to register in registry within timeout")
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartSchedulerForTest starts the RestoresScheduler in a goroutine for testing.
|
||||
// The scheduler subscribes to task completions and manages restore lifecycle.
|
||||
// Returns a context cancel function that should be deferred to stop the scheduler.
|
||||
func StartSchedulerForTest(t *testing.T, scheduler *RestoresScheduler) context.CancelFunc {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
scheduler.Run(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Give scheduler time to subscribe to completions
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
t.Log("RestoresScheduler started")
|
||||
|
||||
return func() {
|
||||
cancel()
|
||||
select {
|
||||
case <-done:
|
||||
t.Log("RestoresScheduler stopped gracefully")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Log("RestoresScheduler stop timeout")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StopRestorerNodeForTest stops the RestorerNode by canceling its context.
|
||||
// It waits for the node to unregister from the registry.
|
||||
func StopRestorerNodeForTest(t *testing.T, cancel context.CancelFunc, restorerNode *RestorerNode) {
|
||||
cancel()
|
||||
|
||||
// Wait for node to unregister from registry
|
||||
deadline := time.Now().UTC().Add(2 * time.Second)
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
nodes, err := restoreNodesRegistry.GetAvailableNodes()
|
||||
if err == nil {
|
||||
found := false
|
||||
for _, node := range nodes {
|
||||
if node.ID == restorerNode.nodeID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Logf("RestorerNode unregistered from registry: %s", restorerNode.nodeID)
|
||||
return
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Logf("RestorerNode stop completed for %s", restorerNode.nodeID)
|
||||
}
|
||||
|
||||
func CreateMockNodeInRegistry(nodeID uuid.UUID, throughputMBs int, lastHeartbeat time.Time) error {
|
||||
restoreNode := RestoreNode{
|
||||
ID: nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: lastHeartbeat,
|
||||
}
|
||||
|
||||
return restoreNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, restoreNode)
|
||||
}
|
||||
|
||||
func UpdateNodeHeartbeatDirectly(
|
||||
nodeID uuid.UUID,
|
||||
throughputMBs int,
|
||||
lastHeartbeat time.Time,
|
||||
) error {
|
||||
restoreNode := RestoreNode{
|
||||
ID: nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: lastHeartbeat,
|
||||
}
|
||||
|
||||
return restoreNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, restoreNode)
|
||||
}
|
||||
|
||||
func GetNodeFromRegistry(nodeID uuid.UUID) (*RestoreNode, error) {
|
||||
nodes, err := restoreNodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
if node.ID == nodeID {
|
||||
return &node, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("node not found")
|
||||
}
|
||||
|
||||
// WaitForActiveTasksDecrease waits for the active task count to decrease below the initial count.
|
||||
// It polls the registry every 500ms until the count decreases or the timeout is reached.
|
||||
// Returns true if the count decreased, false if timeout was reached.
|
||||
func WaitForActiveTasksDecrease(
|
||||
t *testing.T,
|
||||
nodeID uuid.UUID,
|
||||
initialCount int,
|
||||
timeout time.Duration,
|
||||
) bool {
|
||||
deadline := time.Now().UTC().Add(timeout)
|
||||
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
stats, err := restoreNodesRegistry.GetRestoreNodesStats()
|
||||
if err != nil {
|
||||
t.Logf("WaitForActiveTasksDecrease: error getting node stats: %v", err)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, stat := range stats {
|
||||
if stat.ID == nodeID {
|
||||
t.Logf(
|
||||
"WaitForActiveTasksDecrease: current active tasks = %d (initial = %d)",
|
||||
stat.ActiveRestores,
|
||||
initialCount,
|
||||
)
|
||||
if stat.ActiveRestores < initialCount {
|
||||
t.Logf(
|
||||
"WaitForActiveTasksDecrease: active tasks decreased from %d to %d",
|
||||
initialCount,
|
||||
stat.ActiveRestores,
|
||||
)
|
||||
return true
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Logf("WaitForActiveTasksDecrease: timeout waiting for active tasks to decrease")
|
||||
return false
|
||||
}
|
||||
|
||||
// CreateTestRestore creates a test restore with the given backup and status
|
||||
func CreateTestRestore(
|
||||
t *testing.T,
|
||||
backup *backups_core.Backup,
|
||||
status restores_core.RestoreStatus,
|
||||
) *restores_core.Restore {
|
||||
restore := &restores_core.Restore{
|
||||
BackupID: backup.ID,
|
||||
Status: status,
|
||||
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "test",
|
||||
Password: "test",
|
||||
Database: stringPtr("testdb"),
|
||||
Version: "16",
|
||||
},
|
||||
}
|
||||
|
||||
err := restoreRepository.Save(restore)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test restore: %v", err)
|
||||
}
|
||||
|
||||
return restore
|
||||
}
|
||||
|
||||
func stringPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
@@ -7,10 +7,11 @@ import (
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/disk"
|
||||
"databasus-backend/internal/features/restores/enums"
|
||||
"databasus-backend/internal/features/restores/models"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/restores/restoring"
|
||||
"databasus-backend/internal/features/restores/usecases"
|
||||
"databasus-backend/internal/features/storages"
|
||||
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
users_models "databasus-backend/internal/features/users/models"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
@@ -25,7 +26,7 @@ import (
|
||||
|
||||
type RestoreService struct {
|
||||
backupService *backups.BackupService
|
||||
restoreRepository *RestoreRepository
|
||||
restoreRepository *restores_core.RestoreRepository
|
||||
storageService *storages.StorageService
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
restoreBackupUsecase *usecases.RestoreBackupUsecase
|
||||
@@ -35,6 +36,7 @@ type RestoreService struct {
|
||||
auditLogService *audit_logs.AuditLogService
|
||||
fieldEncryptor encryption.FieldEncryptor
|
||||
diskService *disk.DiskService
|
||||
taskCancelManager *tasks_cancellation.TaskCancelManager
|
||||
}
|
||||
|
||||
func (s *RestoreService) OnBeforeBackupRemove(backup *backups_core.Backup) error {
|
||||
@@ -44,7 +46,7 @@ func (s *RestoreService) OnBeforeBackupRemove(backup *backups_core.Backup) error
|
||||
}
|
||||
|
||||
for _, restore := range restores {
|
||||
if restore.Status == enums.RestoreStatusInProgress {
|
||||
if restore.Status == restores_core.RestoreStatusInProgress {
|
||||
return errors.New("restore is in progress, backup cannot be removed")
|
||||
}
|
||||
}
|
||||
@@ -61,7 +63,7 @@ func (s *RestoreService) OnBeforeBackupRemove(backup *backups_core.Backup) error
|
||||
func (s *RestoreService) GetRestores(
|
||||
user *users_models.User,
|
||||
backupID uuid.UUID,
|
||||
) ([]*models.Restore, error) {
|
||||
) ([]*restores_core.Restore, error) {
|
||||
backup, err := s.backupService.GetBackup(backupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -93,7 +95,7 @@ func (s *RestoreService) GetRestores(
|
||||
func (s *RestoreService) RestoreBackupWithAuth(
|
||||
user *users_models.User,
|
||||
backupID uuid.UUID,
|
||||
requestDTO RestoreBackupRequest,
|
||||
requestDTO restores_core.RestoreBackupRequest,
|
||||
) error {
|
||||
backup, err := s.backupService.GetBackup(backupID)
|
||||
if err != nil {
|
||||
@@ -134,11 +136,50 @@ func (s *RestoreService) RestoreBackupWithAuth(
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := s.RestoreBackup(backup, requestDTO); err != nil {
|
||||
s.logger.Error("Failed to restore backup", "error", err)
|
||||
// Validate no parallel restores for the same database
|
||||
if err := s.validateNoParallelRestores(backup.DatabaseID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create restore record with the request configuration
|
||||
restore := restores_core.Restore{
|
||||
ID: uuid.New(),
|
||||
Status: restores_core.RestoreStatusInProgress,
|
||||
BackupID: backup.ID,
|
||||
Backup: backup,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
RestoreDurationMs: 0,
|
||||
FailMessage: nil,
|
||||
PostgresqlDatabase: requestDTO.PostgresqlDatabase,
|
||||
MysqlDatabase: requestDTO.MysqlDatabase,
|
||||
MariadbDatabase: requestDTO.MariadbDatabase,
|
||||
MongodbDatabase: requestDTO.MongodbDatabase,
|
||||
}
|
||||
|
||||
if err := s.restoreRepository.Save(&restore); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Prepare database cache with credentials from the request
|
||||
dbCache := &restoring.RestoreDatabaseCache{
|
||||
PostgresqlDatabase: requestDTO.PostgresqlDatabase,
|
||||
MysqlDatabase: requestDTO.MysqlDatabase,
|
||||
MariadbDatabase: requestDTO.MariadbDatabase,
|
||||
MongodbDatabase: requestDTO.MongodbDatabase,
|
||||
}
|
||||
|
||||
// Trigger restore via scheduler
|
||||
scheduler := restoring.GetRestoresScheduler()
|
||||
if err := scheduler.StartRestore(restore.ID, dbCache); err != nil {
|
||||
// Mark restore as failed if we can't schedule it
|
||||
failMsg := fmt.Sprintf("Failed to schedule restore: %v", err)
|
||||
restore.FailMessage = &failMsg
|
||||
restore.Status = restores_core.RestoreStatusFailed
|
||||
if saveErr := s.restoreRepository.Save(&restore); saveErr != nil {
|
||||
s.logger.Error("Failed to save restore after scheduling error", "error", saveErr)
|
||||
}
|
||||
}()
|
||||
return err
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf(
|
||||
@@ -153,127 +194,9 @@ func (s *RestoreService) RestoreBackupWithAuth(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RestoreService) RestoreBackup(
|
||||
backup *backups_core.Backup,
|
||||
requestDTO RestoreBackupRequest,
|
||||
) error {
|
||||
if backup.Status != backups_core.BackupStatusCompleted {
|
||||
return errors.New("backup is not completed")
|
||||
}
|
||||
|
||||
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch database.Type {
|
||||
case databases.DatabaseTypePostgres:
|
||||
if requestDTO.PostgresqlDatabase == nil {
|
||||
return errors.New("postgresql database is required")
|
||||
}
|
||||
case databases.DatabaseTypeMysql:
|
||||
if requestDTO.MysqlDatabase == nil {
|
||||
return errors.New("mysql database is required")
|
||||
}
|
||||
case databases.DatabaseTypeMariadb:
|
||||
if requestDTO.MariadbDatabase == nil {
|
||||
return errors.New("mariadb database is required")
|
||||
}
|
||||
case databases.DatabaseTypeMongodb:
|
||||
if requestDTO.MongodbDatabase == nil {
|
||||
return errors.New("mongodb database is required")
|
||||
}
|
||||
}
|
||||
|
||||
restore := models.Restore{
|
||||
ID: uuid.New(),
|
||||
Status: enums.RestoreStatusInProgress,
|
||||
|
||||
BackupID: backup.ID,
|
||||
Backup: backup,
|
||||
|
||||
CreatedAt: time.Now().UTC(),
|
||||
RestoreDurationMs: 0,
|
||||
|
||||
FailMessage: nil,
|
||||
}
|
||||
|
||||
// Save the restore first
|
||||
if err := s.restoreRepository.Save(&restore); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Save the restore again to include the postgresql database
|
||||
if err := s.restoreRepository.Save(&restore); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
storage, err := s.storageService.GetStorageByID(backup.StorageID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(
|
||||
database.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
start := time.Now().UTC()
|
||||
|
||||
restoringToDB := &databases.Database{
|
||||
Type: database.Type,
|
||||
Postgresql: requestDTO.PostgresqlDatabase,
|
||||
Mysql: requestDTO.MysqlDatabase,
|
||||
Mariadb: requestDTO.MariadbDatabase,
|
||||
Mongodb: requestDTO.MongodbDatabase,
|
||||
}
|
||||
|
||||
if err := restoringToDB.PopulateDbData(s.logger, s.fieldEncryptor); err != nil {
|
||||
return fmt.Errorf("failed to auto-detect database data: %w", err)
|
||||
}
|
||||
|
||||
isExcludeExtensions := false
|
||||
if requestDTO.PostgresqlDatabase != nil {
|
||||
isExcludeExtensions = requestDTO.PostgresqlDatabase.IsExcludeExtensions
|
||||
}
|
||||
|
||||
err = s.restoreBackupUsecase.Execute(
|
||||
backupConfig,
|
||||
restore,
|
||||
database,
|
||||
restoringToDB,
|
||||
backup,
|
||||
storage,
|
||||
isExcludeExtensions,
|
||||
)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
restore.FailMessage = &errMsg
|
||||
restore.Status = enums.RestoreStatusFailed
|
||||
restore.RestoreDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
if err := s.restoreRepository.Save(&restore); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
restore.Status = enums.RestoreStatusCompleted
|
||||
restore.RestoreDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
if err := s.restoreRepository.Save(&restore); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RestoreService) validateVersionCompatibility(
|
||||
backupDatabase *databases.Database,
|
||||
requestDTO RestoreBackupRequest,
|
||||
requestDTO restores_core.RestoreBackupRequest,
|
||||
) error {
|
||||
// populate version
|
||||
if requestDTO.MariadbDatabase != nil {
|
||||
@@ -372,7 +295,7 @@ func (s *RestoreService) validateVersionCompatibility(
|
||||
|
||||
func (s *RestoreService) validateDiskSpace(
|
||||
backup *backups_core.Backup,
|
||||
requestDTO RestoreBackupRequest,
|
||||
requestDTO restores_core.RestoreBackupRequest,
|
||||
) error {
|
||||
// Only validate disk space for PostgreSQL when file-based restore is needed:
|
||||
// - CPU > 1 (parallel jobs require file)
|
||||
@@ -424,3 +347,71 @@ func (s *RestoreService) validateDiskSpace(
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RestoreService) validateNoParallelRestores(databaseID uuid.UUID) error {
|
||||
inProgressRestores, err := s.restoreRepository.FindInProgressRestoresByDatabaseID(databaseID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check for in-progress restores: %w", err)
|
||||
}
|
||||
|
||||
isInProgress := len(inProgressRestores) > 0
|
||||
if isInProgress {
|
||||
return errors.New(
|
||||
"another restore is already in progress for this database. Please wait for it to complete or cancel it before starting a new restore",
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RestoreService) CancelRestore(
|
||||
user *users_models.User,
|
||||
restoreID uuid.UUID,
|
||||
) error {
|
||||
restore, err := s.restoreRepository.FindByID(restoreID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
backup, err := s.backupService.GetBackup(restore.BackupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if database.WorkspaceID == nil {
|
||||
return errors.New("cannot cancel restore for database without workspace")
|
||||
}
|
||||
|
||||
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !canManage {
|
||||
return errors.New("insufficient permissions to cancel restore for this database")
|
||||
}
|
||||
|
||||
if restore.Status != restores_core.RestoreStatusInProgress {
|
||||
return errors.New("restore is not in progress")
|
||||
}
|
||||
|
||||
if err := s.taskCancelManager.CancelTask(restoreID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf(
|
||||
"Restore cancelled for database: %s (ID: %s)",
|
||||
database.Name,
|
||||
restoreID.String(),
|
||||
),
|
||||
&user.ID,
|
||||
database.WorkspaceID,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
51
backend/internal/features/restores/testing.go
Normal file
51
backend/internal/features/restores/testing.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package restores
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/restores/restoring"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
)
|
||||
|
||||
func CreateTestRouter() *gin.Engine {
|
||||
router := workspaces_testing.CreateTestRouter(
|
||||
workspaces_controllers.GetWorkspaceController(),
|
||||
workspaces_controllers.GetMembershipController(),
|
||||
databases.GetDatabaseController(),
|
||||
backups_config.GetBackupConfigController(),
|
||||
backups.GetBackupController(),
|
||||
GetRestoreController(),
|
||||
)
|
||||
|
||||
v1 := router.Group("/api/v1")
|
||||
backups.GetBackupController().RegisterPublicRoutes(v1)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func SetupMockRestoreNode(t *testing.T) (uuid.UUID, context.CancelFunc) {
|
||||
nodeID := uuid.New()
|
||||
err := restoring.CreateMockNodeInRegistry(
|
||||
nodeID,
|
||||
100,
|
||||
time.Now().UTC(),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create mock node: %v", err)
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
// Node will expire naturally from registry
|
||||
}
|
||||
|
||||
return nodeID, cleanup
|
||||
}
|
||||
@@ -24,7 +24,7 @@ import (
|
||||
"databasus-backend/internal/features/databases"
|
||||
mariadbtypes "databasus-backend/internal/features/databases/databases/mariadb"
|
||||
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
|
||||
"databasus-backend/internal/features/restores/models"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/storages"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/tools"
|
||||
@@ -36,10 +36,11 @@ type RestoreMariadbBackupUsecase struct {
|
||||
}
|
||||
|
||||
func (uc *RestoreMariadbBackupUsecase) Execute(
|
||||
parentCtx context.Context,
|
||||
originalDB *databases.Database,
|
||||
restoringToDB *databases.Database,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore models.Restore,
|
||||
restore restores_core.Restore,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
) error {
|
||||
@@ -79,6 +80,7 @@ func (uc *RestoreMariadbBackupUsecase) Execute(
|
||||
}
|
||||
|
||||
return uc.restoreFromStorage(
|
||||
parentCtx,
|
||||
originalDB,
|
||||
tools.GetMariadbExecutable(
|
||||
tools.MariadbExecutableMariadb,
|
||||
@@ -95,6 +97,7 @@ func (uc *RestoreMariadbBackupUsecase) Execute(
|
||||
}
|
||||
|
||||
func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
|
||||
parentCtx context.Context,
|
||||
database *databases.Database,
|
||||
mariadbBin string,
|
||||
args []string,
|
||||
@@ -103,7 +106,7 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
|
||||
storage *storages.Storage,
|
||||
mdbConfig *mariadbtypes.MariadbDatabase,
|
||||
) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
@@ -113,6 +116,9 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-parentCtx.Done():
|
||||
cancel()
|
||||
return
|
||||
case <-ticker.C:
|
||||
if config.IsShouldShutdown() {
|
||||
cancel()
|
||||
@@ -213,6 +219,15 @@ func (uc *RestoreMariadbBackupUsecase) executeMariadbRestore(
|
||||
waitErr := cmd.Wait()
|
||||
stderrOutput := <-stderrCh
|
||||
|
||||
// Check for cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if errors.Is(ctx.Err(), context.Canceled) {
|
||||
return fmt.Errorf("restore cancelled")
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("restore cancelled due to shutdown")
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
"databasus-backend/internal/features/databases"
|
||||
mongodbtypes "databasus-backend/internal/features/databases/databases/mongodb"
|
||||
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
|
||||
"databasus-backend/internal/features/restores/models"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/storages"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/tools"
|
||||
@@ -36,10 +36,11 @@ type RestoreMongodbBackupUsecase struct {
|
||||
}
|
||||
|
||||
func (uc *RestoreMongodbBackupUsecase) Execute(
|
||||
parentCtx context.Context,
|
||||
originalDB *databases.Database,
|
||||
restoringToDB *databases.Database,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore models.Restore,
|
||||
restore restores_core.Restore,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
) error {
|
||||
@@ -76,6 +77,7 @@ func (uc *RestoreMongodbBackupUsecase) Execute(
|
||||
args := uc.buildMongorestoreArgs(mdb, decryptedPassword, sourceDatabase)
|
||||
|
||||
return uc.restoreFromStorage(
|
||||
parentCtx,
|
||||
tools.GetMongodbExecutable(
|
||||
tools.MongodbExecutableMongorestore,
|
||||
config.GetEnv().EnvMode,
|
||||
@@ -122,12 +124,13 @@ func (uc *RestoreMongodbBackupUsecase) buildMongorestoreArgs(
|
||||
}
|
||||
|
||||
func (uc *RestoreMongodbBackupUsecase) restoreFromStorage(
|
||||
parentCtx context.Context,
|
||||
mongorestoreBin string,
|
||||
args []string,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), restoreTimeout)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, restoreTimeout)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
@@ -137,6 +140,9 @@ func (uc *RestoreMongodbBackupUsecase) restoreFromStorage(
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-parentCtx.Done():
|
||||
cancel()
|
||||
return
|
||||
case <-ticker.C:
|
||||
if config.IsShouldShutdown() {
|
||||
cancel()
|
||||
@@ -218,6 +224,15 @@ func (uc *RestoreMongodbBackupUsecase) executeMongoRestore(
|
||||
waitErr := cmd.Wait()
|
||||
stderrOutput := <-stderrCh
|
||||
|
||||
// Check for cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if errors.Is(ctx.Err(), context.Canceled) {
|
||||
return fmt.Errorf("restore cancelled")
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("restore cancelled due to shutdown")
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ import (
|
||||
"databasus-backend/internal/features/databases"
|
||||
mysqltypes "databasus-backend/internal/features/databases/databases/mysql"
|
||||
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
|
||||
"databasus-backend/internal/features/restores/models"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/storages"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/tools"
|
||||
@@ -36,10 +36,11 @@ type RestoreMysqlBackupUsecase struct {
|
||||
}
|
||||
|
||||
func (uc *RestoreMysqlBackupUsecase) Execute(
|
||||
parentCtx context.Context,
|
||||
originalDB *databases.Database,
|
||||
restoringToDB *databases.Database,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore models.Restore,
|
||||
restore restores_core.Restore,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
) error {
|
||||
@@ -78,6 +79,7 @@ func (uc *RestoreMysqlBackupUsecase) Execute(
|
||||
}
|
||||
|
||||
return uc.restoreFromStorage(
|
||||
parentCtx,
|
||||
originalDB,
|
||||
tools.GetMysqlExecutable(
|
||||
my.Version,
|
||||
@@ -94,6 +96,7 @@ func (uc *RestoreMysqlBackupUsecase) Execute(
|
||||
}
|
||||
|
||||
func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
|
||||
parentCtx context.Context,
|
||||
database *databases.Database,
|
||||
mysqlBin string,
|
||||
args []string,
|
||||
@@ -102,7 +105,7 @@ func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
|
||||
storage *storages.Storage,
|
||||
myConfig *mysqltypes.MysqlDatabase,
|
||||
) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
@@ -112,6 +115,9 @@ func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-parentCtx.Done():
|
||||
cancel()
|
||||
return
|
||||
case <-ticker.C:
|
||||
if config.IsShouldShutdown() {
|
||||
cancel()
|
||||
@@ -204,6 +210,15 @@ func (uc *RestoreMysqlBackupUsecase) executeMysqlRestore(
|
||||
waitErr := cmd.Wait()
|
||||
stderrOutput := <-stderrCh
|
||||
|
||||
// Check for cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if errors.Is(ctx.Err(), context.Canceled) {
|
||||
return fmt.Errorf("restore cancelled")
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("restore cancelled due to shutdown")
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
"databasus-backend/internal/features/databases"
|
||||
pgtypes "databasus-backend/internal/features/databases/databases/postgresql"
|
||||
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
|
||||
"databasus-backend/internal/features/restores/models"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/storages"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/tools"
|
||||
@@ -35,10 +35,11 @@ type RestorePostgresqlBackupUsecase struct {
|
||||
}
|
||||
|
||||
func (uc *RestorePostgresqlBackupUsecase) Execute(
|
||||
parentCtx context.Context,
|
||||
originalDB *databases.Database,
|
||||
restoringToDB *databases.Database,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore models.Restore,
|
||||
restore restores_core.Restore,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
isExcludeExtensions bool,
|
||||
@@ -73,6 +74,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
|
||||
|
||||
// All PostgreSQL backups are now custom format (-Fc)
|
||||
return uc.restoreCustomType(
|
||||
parentCtx,
|
||||
originalDB,
|
||||
pgBin,
|
||||
backup,
|
||||
@@ -84,6 +86,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
|
||||
|
||||
// restoreCustomType restores a backup in custom type (-Fc)
|
||||
func (uc *RestorePostgresqlBackupUsecase) restoreCustomType(
|
||||
parentCtx context.Context,
|
||||
originalDB *databases.Database,
|
||||
pgBin string,
|
||||
backup *backups_core.Backup,
|
||||
@@ -102,15 +105,24 @@ func (uc *RestorePostgresqlBackupUsecase) restoreCustomType(
|
||||
// If excluding extensions, we must use file-based restore (requires TOC file generation)
|
||||
// Also use file-based restore for parallel jobs (multiple CPUs)
|
||||
if isExcludeExtensions || pg.CpuCount > 1 {
|
||||
return uc.restoreViaFile(originalDB, pgBin, backup, storage, pg, isExcludeExtensions)
|
||||
return uc.restoreViaFile(
|
||||
parentCtx,
|
||||
originalDB,
|
||||
pgBin,
|
||||
backup,
|
||||
storage,
|
||||
pg,
|
||||
isExcludeExtensions,
|
||||
)
|
||||
}
|
||||
|
||||
// Single CPU without extension exclusion: stream directly via stdin
|
||||
return uc.restoreViaStdin(originalDB, pgBin, backup, storage, pg)
|
||||
return uc.restoreViaStdin(parentCtx, originalDB, pgBin, backup, storage, pg)
|
||||
}
|
||||
|
||||
// restoreViaStdin streams backup via stdin for single CPU restore
|
||||
func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
|
||||
parentCtx context.Context,
|
||||
originalDB *databases.Database,
|
||||
pgBin string,
|
||||
backup *backups_core.Backup,
|
||||
@@ -133,10 +145,10 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
|
||||
"--no-acl",
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// Monitor for shutdown and cancel context if needed
|
||||
// Monitor for shutdown and parent cancellation
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
@@ -145,6 +157,9 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-parentCtx.Done():
|
||||
cancel()
|
||||
return
|
||||
case <-ticker.C:
|
||||
if config.IsShouldShutdown() {
|
||||
cancel()
|
||||
@@ -296,6 +311,15 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
|
||||
stderrOutput := <-stderrCh
|
||||
copyErr := <-copyErrCh
|
||||
|
||||
// Check for cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if errors.Is(ctx.Err(), context.Canceled) {
|
||||
return fmt.Errorf("restore cancelled")
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
// Check for shutdown before finalizing
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("restore cancelled due to shutdown")
|
||||
@@ -307,6 +331,15 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
|
||||
}
|
||||
|
||||
if waitErr != nil {
|
||||
// Check for cancellation again
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if errors.Is(ctx.Err(), context.Canceled) {
|
||||
return fmt.Errorf("restore cancelled")
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("restore cancelled due to shutdown")
|
||||
}
|
||||
@@ -319,6 +352,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
|
||||
|
||||
// restoreViaFile downloads backup and uses parallel jobs for multi-CPU restore
|
||||
func (uc *RestorePostgresqlBackupUsecase) restoreViaFile(
|
||||
parentCtx context.Context,
|
||||
originalDB *databases.Database,
|
||||
pgBin string,
|
||||
backup *backups_core.Backup,
|
||||
@@ -354,6 +388,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaFile(
|
||||
}
|
||||
|
||||
return uc.restoreFromStorage(
|
||||
parentCtx,
|
||||
originalDB,
|
||||
pgBin,
|
||||
args,
|
||||
@@ -367,6 +402,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaFile(
|
||||
|
||||
// restoreFromStorage restores backup data from storage using pg_restore
|
||||
func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
|
||||
parentCtx context.Context,
|
||||
database *databases.Database,
|
||||
pgBin string,
|
||||
args []string,
|
||||
@@ -386,10 +422,10 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
|
||||
isExcludeExtensions,
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// Monitor for shutdown and cancel context if needed
|
||||
// Monitor for shutdown and parent cancellation
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
@@ -398,6 +434,9 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-parentCtx.Done():
|
||||
cancel()
|
||||
return
|
||||
case <-ticker.C:
|
||||
if config.IsShouldShutdown() {
|
||||
cancel()
|
||||
@@ -624,12 +663,30 @@ func (uc *RestorePostgresqlBackupUsecase) executePgRestore(
|
||||
waitErr := cmd.Wait()
|
||||
stderrOutput := <-stderrCh
|
||||
|
||||
// Check for cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if errors.Is(ctx.Err(), context.Canceled) {
|
||||
return fmt.Errorf("restore cancelled")
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
// Check for shutdown before finalizing
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("restore cancelled due to shutdown")
|
||||
}
|
||||
|
||||
if waitErr != nil {
|
||||
// Check for cancellation again
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if errors.Is(ctx.Err(), context.Canceled) {
|
||||
return fmt.Errorf("restore cancelled")
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("restore cancelled due to shutdown")
|
||||
}
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
package usecases
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/restores/models"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
usecases_mariadb "databasus-backend/internal/features/restores/usecases/mariadb"
|
||||
usecases_mongodb "databasus-backend/internal/features/restores/usecases/mongodb"
|
||||
usecases_mysql "databasus-backend/internal/features/restores/usecases/mysql"
|
||||
@@ -22,8 +23,9 @@ type RestoreBackupUsecase struct {
|
||||
}
|
||||
|
||||
func (uc *RestoreBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore models.Restore,
|
||||
restore restores_core.Restore,
|
||||
originalDB *databases.Database,
|
||||
restoringToDB *databases.Database,
|
||||
backup *backups_core.Backup,
|
||||
@@ -33,6 +35,7 @@ func (uc *RestoreBackupUsecase) Execute(
|
||||
switch originalDB.Type {
|
||||
case databases.DatabaseTypePostgres:
|
||||
return uc.restorePostgresqlBackupUsecase.Execute(
|
||||
ctx,
|
||||
originalDB,
|
||||
restoringToDB,
|
||||
backupConfig,
|
||||
@@ -43,6 +46,7 @@ func (uc *RestoreBackupUsecase) Execute(
|
||||
)
|
||||
case databases.DatabaseTypeMysql:
|
||||
return uc.restoreMysqlBackupUsecase.Execute(
|
||||
ctx,
|
||||
originalDB,
|
||||
restoringToDB,
|
||||
backupConfig,
|
||||
@@ -52,6 +56,7 @@ func (uc *RestoreBackupUsecase) Execute(
|
||||
)
|
||||
case databases.DatabaseTypeMariadb:
|
||||
return uc.restoreMariadbBackupUsecase.Execute(
|
||||
ctx,
|
||||
originalDB,
|
||||
restoringToDB,
|
||||
backupConfig,
|
||||
@@ -61,6 +66,7 @@ func (uc *RestoreBackupUsecase) Execute(
|
||||
)
|
||||
case databases.DatabaseTypeMongodb:
|
||||
return uc.restoreMongodbBackupUsecase.Execute(
|
||||
ctx,
|
||||
originalDB,
|
||||
restoringToDB,
|
||||
backupConfig,
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
azure_blob_storage "databasus-backend/internal/features/storages/models/azure_blob"
|
||||
ftp_storage "databasus-backend/internal/features/storages/models/ftp"
|
||||
@@ -902,6 +903,12 @@ func Test_StorageSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Skip Google Drive tests if external resources tests are disabled
|
||||
if tc.storageType == StorageTypeGoogleDrive &&
|
||||
config.GetEnv().IsSkipExternalResourcesTests {
|
||||
t.Skip("Skipping Google Drive storage test: IS_SKIP_EXTERNAL_RESOURCES_TESTS=true")
|
||||
}
|
||||
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
package storages
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var storageRepository = &StorageRepository{}
|
||||
@@ -27,6 +31,21 @@ func GetStorageController() *StorageController {
|
||||
return storageController
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(storageService)
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(storageService)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,7 +113,7 @@ func Test_Storage_BasicOperations(t *testing.T) {
|
||||
name: "NASStorage",
|
||||
storage: &nas_storage.NASStorage{
|
||||
StorageID: uuid.New(),
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: nasPort,
|
||||
Share: "backups",
|
||||
Username: "testuser",
|
||||
@@ -147,7 +147,7 @@ func Test_Storage_BasicOperations(t *testing.T) {
|
||||
name: "FTPStorage",
|
||||
storage: &ftp_storage.FTPStorage{
|
||||
StorageID: uuid.New(),
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: ftpPort,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
@@ -159,7 +159,7 @@ func Test_Storage_BasicOperations(t *testing.T) {
|
||||
name: "SFTPStorage",
|
||||
storage: &sftp_storage.SFTPStorage{
|
||||
StorageID: uuid.New(),
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: sftpPort,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
@@ -185,7 +185,9 @@ acl = private`, s3Container.accessKey, s3Container.secretKey, s3Container.endpoi
|
||||
|
||||
// Add Google Drive storage test only if environment variables are available
|
||||
env := config.GetEnv()
|
||||
if env.TestGoogleDriveClientID != "" && env.TestGoogleDriveClientSecret != "" &&
|
||||
if env.IsSkipExternalResourcesTests {
|
||||
t.Log("Skipping Google Drive storage test: IS_SKIP_EXTERNAL_RESOURCES_TESTS=true")
|
||||
} else if env.TestGoogleDriveClientID != "" && env.TestGoogleDriveClientSecret != "" &&
|
||||
env.TestGoogleDriveTokenJSON != "" {
|
||||
testCases = append(testCases, struct {
|
||||
name string
|
||||
@@ -297,7 +299,7 @@ func setupS3Container(ctx context.Context) (*S3Container, error) {
|
||||
secretKey := "testpassword"
|
||||
bucketName := "test-bucket"
|
||||
region := "us-east-1"
|
||||
endpoint := fmt.Sprintf("127.0.0.1:%s", env.TestMinioPort)
|
||||
endpoint := fmt.Sprintf("%s:%s", env.TestLocalhost, env.TestMinioPort)
|
||||
|
||||
// Create MinIO client and ensure bucket exists
|
||||
minioClient, err := minio.New(endpoint, &minio.Options{
|
||||
@@ -343,15 +345,21 @@ func setupAzuriteContainer(ctx context.Context) (*AzuriteContainer, error) {
|
||||
accountName := "devstoreaccount1"
|
||||
// this is real testing key for azurite, it's not a real key
|
||||
accountKey := "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="
|
||||
serviceURL := fmt.Sprintf("http://127.0.0.1:%s/%s", env.TestAzuriteBlobPort, accountName)
|
||||
serviceURL := fmt.Sprintf(
|
||||
"http://%s:%s/%s",
|
||||
env.TestLocalhost,
|
||||
env.TestAzuriteBlobPort,
|
||||
accountName,
|
||||
)
|
||||
containerNameKey := "test-container-key"
|
||||
containerNameStr := "test-container-connstr"
|
||||
|
||||
// Build explicit connection string for Azurite
|
||||
connectionString := fmt.Sprintf(
|
||||
"DefaultEndpointsProtocol=http;AccountName=%s;AccountKey=%s;BlobEndpoint=http://127.0.0.1:%s/%s",
|
||||
"DefaultEndpointsProtocol=http;AccountName=%s;AccountKey=%s;BlobEndpoint=http://%s:%s/%s",
|
||||
accountName,
|
||||
accountKey,
|
||||
env.TestLocalhost,
|
||||
env.TestAzuriteBlobPort,
|
||||
accountName,
|
||||
)
|
||||
|
||||
@@ -285,30 +285,30 @@ func (f *FTPStorage) ensureDirectory(conn *ftp.ServerConn, path string) error {
|
||||
}
|
||||
|
||||
parts := strings.Split(path, "/")
|
||||
currentPath := ""
|
||||
|
||||
currentDir, err := conn.CurrentDir()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get current directory: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.ChangeDir(currentDir)
|
||||
}()
|
||||
|
||||
for _, part := range parts {
|
||||
if part == "" || part == "." {
|
||||
continue
|
||||
}
|
||||
|
||||
if currentPath == "" {
|
||||
currentPath = part
|
||||
} else {
|
||||
currentPath = currentPath + "/" + part
|
||||
}
|
||||
|
||||
err := conn.ChangeDir(currentPath)
|
||||
err := conn.ChangeDir(part)
|
||||
if err != nil {
|
||||
err = conn.MakeDir(currentPath)
|
||||
err = conn.MakeDir(part)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create directory '%s': %w", currentPath, err)
|
||||
return fmt.Errorf("failed to create directory '%s': %w", part, err)
|
||||
}
|
||||
err = conn.ChangeDir(part)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to change into directory '%s': %w", part, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = conn.ChangeDirToParent()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to change to parent directory: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ func (s *HealthcheckService) IsHealthy() error {
|
||||
}
|
||||
}
|
||||
|
||||
if config.GetEnv().IsBackupNode {
|
||||
if config.GetEnv().IsProcessingNode {
|
||||
if !s.backuperNode.IsBackuperRunning() {
|
||||
return errors.New("backuper node is not running for more than 5 minutes")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
package task_cancellation
|
||||
|
||||
import (
|
||||
"context"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const taskCancelChannel = "task:cancel"
|
||||
|
||||
type TaskCancelManager struct {
|
||||
mu sync.RWMutex
|
||||
cancelFuncs map[uuid.UUID]context.CancelFunc
|
||||
pubsub *cache_utils.PubSubManager
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (m *TaskCancelManager) StartSubscription() {
|
||||
ctx := context.Background()
|
||||
|
||||
handler := func(message string) {
|
||||
taskID, err := uuid.Parse(message)
|
||||
if err != nil {
|
||||
m.logger.Error("Invalid task ID in cancel message", "message", message, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
cancelFunc, exists := m.cancelFuncs[taskID]
|
||||
if exists {
|
||||
cancelFunc()
|
||||
delete(m.cancelFuncs, taskID)
|
||||
m.logger.Info("Cancelled task via Pub/Sub", "taskID", taskID)
|
||||
}
|
||||
}
|
||||
|
||||
err := m.pubsub.Subscribe(ctx, taskCancelChannel, handler)
|
||||
if err != nil {
|
||||
m.logger.Error("Failed to subscribe to task cancel channel", "error", err)
|
||||
} else {
|
||||
m.logger.Info("Successfully subscribed to task cancel channel")
|
||||
}
|
||||
}
|
||||
|
||||
func (m *TaskCancelManager) RegisterTask(task uuid.UUID, cancelFunc context.CancelFunc) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.cancelFuncs[task] = cancelFunc
|
||||
m.logger.Debug("Registered task", "taskID", task)
|
||||
}
|
||||
|
||||
func (m *TaskCancelManager) CancelTask(taskID uuid.UUID) error {
|
||||
ctx := context.Background()
|
||||
|
||||
err := m.pubsub.Publish(ctx, taskCancelChannel, taskID.String())
|
||||
if err != nil {
|
||||
m.logger.Error("Failed to publish cancel message", "taskID", taskID, "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
m.logger.Info("Published task cancel message", "taskID", taskID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *TaskCancelManager) UnregisterTask(taskID uuid.UUID) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.cancelFuncs, taskID)
|
||||
m.logger.Debug("Unregistered task", "taskID", taskID)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package backups_cancellation
|
||||
package task_cancellation
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -10,41 +10,41 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_RegisterBackup_BackupRegisteredSuccessfully(t *testing.T) {
|
||||
manager := backupCancelManager
|
||||
func Test_RegisterTask_TaskRegisteredSuccessfully(t *testing.T) {
|
||||
manager := taskCancelManager
|
||||
|
||||
backupID := uuid.New()
|
||||
taskID := uuid.New()
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
manager.RegisterBackup(backupID, cancel)
|
||||
manager.RegisterTask(taskID, cancel)
|
||||
|
||||
manager.mu.RLock()
|
||||
_, exists := manager.cancelFuncs[backupID]
|
||||
_, exists := manager.cancelFuncs[taskID]
|
||||
manager.mu.RUnlock()
|
||||
assert.True(t, exists, "Backup should be registered")
|
||||
assert.True(t, exists, "Task should be registered")
|
||||
}
|
||||
|
||||
func Test_UnregisterBackup_BackupUnregisteredSuccessfully(t *testing.T) {
|
||||
manager := backupCancelManager
|
||||
func Test_UnregisterTask_TaskUnregisteredSuccessfully(t *testing.T) {
|
||||
manager := taskCancelManager
|
||||
|
||||
backupID := uuid.New()
|
||||
taskID := uuid.New()
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
manager.RegisterBackup(backupID, cancel)
|
||||
manager.UnregisterBackup(backupID)
|
||||
manager.RegisterTask(taskID, cancel)
|
||||
manager.UnregisterTask(taskID)
|
||||
|
||||
manager.mu.RLock()
|
||||
_, exists := manager.cancelFuncs[backupID]
|
||||
_, exists := manager.cancelFuncs[taskID]
|
||||
manager.mu.RUnlock()
|
||||
assert.False(t, exists, "Backup should be unregistered")
|
||||
assert.False(t, exists, "Task should be unregistered")
|
||||
}
|
||||
|
||||
func Test_CancelBackup_OnSameInstance_BackupCancelledViaPubSub(t *testing.T) {
|
||||
manager := backupCancelManager
|
||||
func Test_CancelTask_OnSameInstance_TaskCancelledViaPubSub(t *testing.T) {
|
||||
manager := taskCancelManager
|
||||
|
||||
backupID := uuid.New()
|
||||
taskID := uuid.New()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
cancelled := false
|
||||
@@ -57,11 +57,11 @@ func Test_CancelBackup_OnSameInstance_BackupCancelledViaPubSub(t *testing.T) {
|
||||
cancel()
|
||||
}
|
||||
|
||||
manager.RegisterBackup(backupID, wrappedCancel)
|
||||
manager.RegisterTask(taskID, wrappedCancel)
|
||||
manager.StartSubscription()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err := manager.CancelBackup(backupID)
|
||||
err := manager.CancelTask(taskID)
|
||||
assert.NoError(t, err, "Cancel should not return error")
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
@@ -74,11 +74,11 @@ func Test_CancelBackup_OnSameInstance_BackupCancelledViaPubSub(t *testing.T) {
|
||||
assert.Error(t, ctx.Err(), "Context should be cancelled")
|
||||
}
|
||||
|
||||
func Test_CancelBackup_FromDifferentInstance_BackupCancelledOnRunningInstance(t *testing.T) {
|
||||
manager1 := backupCancelManager
|
||||
manager2 := backupCancelManager
|
||||
func Test_CancelTask_FromDifferentInstance_TaskCancelledOnRunningInstance(t *testing.T) {
|
||||
manager1 := taskCancelManager
|
||||
manager2 := taskCancelManager
|
||||
|
||||
backupID := uuid.New()
|
||||
taskID := uuid.New()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
cancelled := false
|
||||
@@ -91,13 +91,13 @@ func Test_CancelBackup_FromDifferentInstance_BackupCancelledOnRunningInstance(t
|
||||
cancel()
|
||||
}
|
||||
|
||||
manager1.RegisterBackup(backupID, wrappedCancel)
|
||||
manager1.RegisterTask(taskID, wrappedCancel)
|
||||
|
||||
manager1.StartSubscription()
|
||||
manager2.StartSubscription()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err := manager2.CancelBackup(backupID)
|
||||
err := manager2.CancelTask(taskID)
|
||||
assert.NoError(t, err, "Cancel should not return error")
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
@@ -110,29 +110,29 @@ func Test_CancelBackup_FromDifferentInstance_BackupCancelledOnRunningInstance(t
|
||||
assert.Error(t, ctx.Err(), "Context should be cancelled")
|
||||
}
|
||||
|
||||
func Test_CancelBackup_WhenBackupDoesNotExist_NoErrorReturned(t *testing.T) {
|
||||
manager := backupCancelManager
|
||||
func Test_CancelTask_WhenTaskDoesNotExist_NoErrorReturned(t *testing.T) {
|
||||
manager := taskCancelManager
|
||||
|
||||
manager.StartSubscription()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
nonExistentID := uuid.New()
|
||||
err := manager.CancelBackup(nonExistentID)
|
||||
assert.NoError(t, err, "Cancelling non-existent backup should not error")
|
||||
err := manager.CancelTask(nonExistentID)
|
||||
assert.NoError(t, err, "Cancelling non-existent task should not error")
|
||||
}
|
||||
|
||||
func Test_CancelBackup_WithMultipleBackups_AllBackupsCancelled(t *testing.T) {
|
||||
manager := backupCancelManager
|
||||
func Test_CancelTask_WithMultipleTasks_AllTasksCancelled(t *testing.T) {
|
||||
manager := taskCancelManager
|
||||
|
||||
numBackups := 5
|
||||
backupIDs := make([]uuid.UUID, numBackups)
|
||||
contexts := make([]context.Context, numBackups)
|
||||
cancels := make([]context.CancelFunc, numBackups)
|
||||
cancelledFlags := make([]bool, numBackups)
|
||||
numTasks := 5
|
||||
taskIDs := make([]uuid.UUID, numTasks)
|
||||
contexts := make([]context.Context, numTasks)
|
||||
cancels := make([]context.CancelFunc, numTasks)
|
||||
cancelledFlags := make([]bool, numTasks)
|
||||
var mu sync.Mutex
|
||||
|
||||
for i := 0; i < numBackups; i++ {
|
||||
backupIDs[i] = uuid.New()
|
||||
for i := 0; i < numTasks; i++ {
|
||||
taskIDs[i] = uuid.New()
|
||||
contexts[i], cancels[i] = context.WithCancel(context.Background())
|
||||
|
||||
idx := i
|
||||
@@ -143,31 +143,31 @@ func Test_CancelBackup_WithMultipleBackups_AllBackupsCancelled(t *testing.T) {
|
||||
cancels[idx]()
|
||||
}
|
||||
|
||||
manager.RegisterBackup(backupIDs[i], wrappedCancel)
|
||||
manager.RegisterTask(taskIDs[i], wrappedCancel)
|
||||
}
|
||||
|
||||
manager.StartSubscription()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
for i := 0; i < numBackups; i++ {
|
||||
err := manager.CancelBackup(backupIDs[i])
|
||||
for i := 0; i < numTasks; i++ {
|
||||
err := manager.CancelTask(taskIDs[i])
|
||||
assert.NoError(t, err, "Cancel should not return error")
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
mu.Lock()
|
||||
for i := 0; i < numBackups; i++ {
|
||||
assert.True(t, cancelledFlags[i], "Backup %d should be cancelled", i)
|
||||
for i := 0; i < numTasks; i++ {
|
||||
assert.True(t, cancelledFlags[i], "Task %d should be cancelled", i)
|
||||
assert.Error(t, contexts[i].Err(), "Context %d should be cancelled", i)
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
func Test_CancelBackup_AfterUnregister_BackupNotCancelled(t *testing.T) {
|
||||
manager := backupCancelManager
|
||||
func Test_CancelTask_AfterUnregister_TaskNotCancelled(t *testing.T) {
|
||||
manager := taskCancelManager
|
||||
|
||||
backupID := uuid.New()
|
||||
taskID := uuid.New()
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
@@ -181,13 +181,13 @@ func Test_CancelBackup_AfterUnregister_BackupNotCancelled(t *testing.T) {
|
||||
cancel()
|
||||
}
|
||||
|
||||
manager.RegisterBackup(backupID, wrappedCancel)
|
||||
manager.RegisterTask(taskID, wrappedCancel)
|
||||
manager.StartSubscription()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
manager.UnregisterBackup(backupID)
|
||||
manager.UnregisterTask(taskID)
|
||||
|
||||
err := manager.CancelBackup(backupID)
|
||||
err := manager.CancelTask(taskID)
|
||||
assert.NoError(t, err, "Cancel should not return error")
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
42
backend/internal/features/tasks/cancellation/di.go
Normal file
42
backend/internal/features/tasks/cancellation/di.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package task_cancellation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/logger"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var taskCancelManager = &TaskCancelManager{
|
||||
sync.RWMutex{},
|
||||
make(map[uuid.UUID]context.CancelFunc),
|
||||
cache_utils.NewPubSubManager(),
|
||||
logger.GetLogger(),
|
||||
}
|
||||
|
||||
func GetTaskCancelManager() *TaskCancelManager {
|
||||
return taskCancelManager
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
taskCancelManager.StartSubscription()
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
159
backend/internal/features/test_once_protection.go
Normal file
159
backend/internal/features/test_once_protection.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package features
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/restores"
|
||||
"databasus-backend/internal/features/restores/restoring"
|
||||
"databasus-backend/internal/features/storages"
|
||||
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
)
|
||||
|
||||
// Test_SetupDependencies_CalledTwice_LogsWarning verifies SetupDependencies is idempotent
|
||||
func Test_SetupDependencies_CalledTwice_LogsWarning(t *testing.T) {
|
||||
// Call each SetupDependencies twice - should not panic, only log warnings
|
||||
audit_logs.SetupDependencies()
|
||||
audit_logs.SetupDependencies()
|
||||
|
||||
backups.SetupDependencies()
|
||||
backups.SetupDependencies()
|
||||
|
||||
backups_config.SetupDependencies()
|
||||
backups_config.SetupDependencies()
|
||||
|
||||
databases.SetupDependencies()
|
||||
databases.SetupDependencies()
|
||||
|
||||
healthcheck_config.SetupDependencies()
|
||||
healthcheck_config.SetupDependencies()
|
||||
|
||||
notifiers.SetupDependencies()
|
||||
notifiers.SetupDependencies()
|
||||
|
||||
restores.SetupDependencies()
|
||||
restores.SetupDependencies()
|
||||
|
||||
storages.SetupDependencies()
|
||||
storages.SetupDependencies()
|
||||
|
||||
task_cancellation.SetupDependencies()
|
||||
task_cancellation.SetupDependencies()
|
||||
|
||||
// If we reach here without panic, test passes
|
||||
t.Log("All SetupDependencies calls completed successfully (idempotent)")
|
||||
}
|
||||
|
||||
// Test_SetupDependencies_ConcurrentCalls_Safe verifies thread safety
|
||||
func Test_SetupDependencies_ConcurrentCalls_Safe(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Call SetupDependencies concurrently from 10 goroutines
|
||||
for range 10 {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
audit_logs.SetupDependencies()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
t.Log("Concurrent SetupDependencies calls completed successfully")
|
||||
}
|
||||
|
||||
// Test_BackgroundService_Run_CalledTwice_Panics verifies Run() panics on duplicate calls
|
||||
func Test_BackgroundService_Run_CalledTwice_Panics(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Create a test background service
|
||||
backgroundService := audit_logs.GetAuditLogBackgroundService()
|
||||
|
||||
// Start first Run() in goroutine
|
||||
go func() {
|
||||
backgroundService.Run(ctx)
|
||||
}()
|
||||
|
||||
// Give first call time to initialize
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Second call should panic
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
expectedMsg := "*audit_logs.AuditLogBackgroundService.Run() called multiple times"
|
||||
panicMsg := fmt.Sprintf("%v", r)
|
||||
if panicMsg == expectedMsg {
|
||||
t.Logf("Successfully caught panic: %v", r)
|
||||
} else {
|
||||
t.Errorf("Expected panic message '%s', got '%s'", expectedMsg, panicMsg)
|
||||
}
|
||||
} else {
|
||||
t.Error("Expected panic on second Run() call, but did not panic")
|
||||
}
|
||||
}()
|
||||
|
||||
backgroundService.Run(ctx)
|
||||
}
|
||||
|
||||
// Test_BackupsScheduler_Run_CalledTwice_Panics verifies scheduler panics on duplicate calls
|
||||
func Test_BackupsScheduler_Run_CalledTwice_Panics(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
scheduler := backuping.GetBackupsScheduler()
|
||||
|
||||
// Start first Run() in goroutine
|
||||
go func() {
|
||||
scheduler.Run(ctx)
|
||||
}()
|
||||
|
||||
// Give first call time to initialize
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Second call should panic
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Logf("Successfully caught panic: %v", r)
|
||||
} else {
|
||||
t.Error("Expected panic on second Run() call, but did not panic")
|
||||
}
|
||||
}()
|
||||
|
||||
scheduler.Run(ctx)
|
||||
}
|
||||
|
||||
// Test_RestoresScheduler_Run_CalledTwice_Panics verifies restore scheduler panics on duplicate calls
|
||||
func Test_RestoresScheduler_Run_CalledTwice_Panics(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
|
||||
scheduler := restoring.GetRestoresScheduler()
|
||||
|
||||
// Start first Run() in goroutine
|
||||
go func() {
|
||||
scheduler.Run(ctx)
|
||||
}()
|
||||
|
||||
// Give first call time to initialize
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Second call should panic
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Logf("Successfully caught panic: %v", r)
|
||||
} else {
|
||||
t.Error("Expected panic on second Run() call, but did not panic")
|
||||
}
|
||||
}()
|
||||
|
||||
scheduler.Run(ctx)
|
||||
}
|
||||
@@ -21,9 +21,7 @@ import (
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
mariadbtypes "databasus-backend/internal/features/databases/databases/mariadb"
|
||||
"databasus-backend/internal/features/restores"
|
||||
restores_enums "databasus-backend/internal/features/restores/enums"
|
||||
restores_models "databasus-backend/internal/features/restores/models"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
@@ -213,7 +211,7 @@ func testMariadbBackupRestoreForVersion(
|
||||
)
|
||||
|
||||
restore := waitForMariadbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var tableExists int
|
||||
err = newDB.Get(
|
||||
@@ -311,7 +309,7 @@ func testMariadbBackupRestoreWithEncryptionForVersion(
|
||||
)
|
||||
|
||||
restore := waitForMariadbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var tableExists int
|
||||
err = newDB.Get(
|
||||
@@ -418,7 +416,7 @@ func testMariadbBackupRestoreWithReadOnlyUserForVersion(
|
||||
)
|
||||
|
||||
restore := waitForMariadbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var tableExists int
|
||||
err = newDB.Get(
|
||||
@@ -506,7 +504,7 @@ func createMariadbRestoreViaAPI(
|
||||
version tools.MariadbVersion,
|
||||
token string,
|
||||
) {
|
||||
request := restores.RestoreBackupRequest{
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
MariadbDatabase: &mariadbtypes.MariadbDatabase{
|
||||
Host: host,
|
||||
Port: port,
|
||||
@@ -533,7 +531,7 @@ func waitForMariadbRestoreCompletion(
|
||||
backupID uuid.UUID,
|
||||
token string,
|
||||
timeout time.Duration,
|
||||
) *restores_models.Restore {
|
||||
) *restores_core.Restore {
|
||||
startTime := time.Now()
|
||||
pollInterval := 500 * time.Millisecond
|
||||
|
||||
@@ -542,7 +540,7 @@ func waitForMariadbRestoreCompletion(
|
||||
t.Fatalf("Timeout waiting for MariaDB restore completion after %v", timeout)
|
||||
}
|
||||
|
||||
var restoresList []*restores_models.Restore
|
||||
var restoresList []*restores_core.Restore
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
@@ -553,10 +551,10 @@ func waitForMariadbRestoreCompletion(
|
||||
)
|
||||
|
||||
for _, restore := range restoresList {
|
||||
if restore.Status == restores_enums.RestoreStatusCompleted {
|
||||
if restore.Status == restores_core.RestoreStatusCompleted {
|
||||
return restore
|
||||
}
|
||||
if restore.Status == restores_enums.RestoreStatusFailed {
|
||||
if restore.Status == restores_core.RestoreStatusFailed {
|
||||
failMsg := "unknown error"
|
||||
if restore.FailMessage != nil {
|
||||
failMsg = *restore.FailMessage
|
||||
@@ -607,7 +605,7 @@ func connectToMariadbContainer(
|
||||
dbName := "testdb"
|
||||
password := "rootpassword"
|
||||
username := "root"
|
||||
host := "127.0.0.1"
|
||||
host := config.GetEnv().TestLocalhost
|
||||
|
||||
portInt, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
|
||||
@@ -23,9 +23,7 @@ import (
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
mongodbtypes "databasus-backend/internal/features/databases/databases/mongodb"
|
||||
"databasus-backend/internal/features/restores"
|
||||
restores_enums "databasus-backend/internal/features/restores/enums"
|
||||
restores_models "databasus-backend/internal/features/restores/models"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
@@ -175,7 +173,7 @@ func testMongodbBackupRestoreForVersion(
|
||||
)
|
||||
|
||||
restore := waitForMongodbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
verifyMongodbDataIntegrity(t, container, newDBName)
|
||||
|
||||
@@ -254,7 +252,7 @@ func testMongodbBackupRestoreWithEncryptionForVersion(
|
||||
)
|
||||
|
||||
restore := waitForMongodbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
verifyMongodbDataIntegrity(t, container, newDBName)
|
||||
|
||||
@@ -342,7 +340,7 @@ func testMongodbBackupRestoreWithReadOnlyUserForVersion(
|
||||
)
|
||||
|
||||
restore := waitForMongodbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
verifyMongodbDataIntegrity(t, container, newDBName)
|
||||
|
||||
@@ -431,7 +429,7 @@ func createMongodbRestoreViaAPI(
|
||||
version tools.MongodbVersion,
|
||||
token string,
|
||||
) {
|
||||
request := restores.RestoreBackupRequest{
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
MongodbDatabase: &mongodbtypes.MongodbDatabase{
|
||||
Host: host,
|
||||
Port: port,
|
||||
@@ -461,7 +459,7 @@ func waitForMongodbRestoreCompletion(
|
||||
backupID uuid.UUID,
|
||||
token string,
|
||||
timeout time.Duration,
|
||||
) *restores_models.Restore {
|
||||
) *restores_core.Restore {
|
||||
startTime := time.Now()
|
||||
pollInterval := 500 * time.Millisecond
|
||||
|
||||
@@ -470,7 +468,7 @@ func waitForMongodbRestoreCompletion(
|
||||
t.Fatalf("Timeout waiting for MongoDB restore completion after %v", timeout)
|
||||
}
|
||||
|
||||
var restoresList []*restores_models.Restore
|
||||
var restoresList []*restores_core.Restore
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
@@ -481,10 +479,10 @@ func waitForMongodbRestoreCompletion(
|
||||
)
|
||||
|
||||
for _, restore := range restoresList {
|
||||
if restore.Status == restores_enums.RestoreStatusCompleted {
|
||||
if restore.Status == restores_core.RestoreStatusCompleted {
|
||||
return restore
|
||||
}
|
||||
if restore.Status == restores_enums.RestoreStatusFailed {
|
||||
if restore.Status == restores_core.RestoreStatusFailed {
|
||||
failMsg := "unknown error"
|
||||
if restore.FailMessage != nil {
|
||||
failMsg = *restore.FailMessage
|
||||
@@ -552,7 +550,7 @@ func connectToMongodbContainer(
|
||||
password := "rootpassword"
|
||||
username := "root"
|
||||
authDatabase := "admin"
|
||||
host := "127.0.0.1"
|
||||
host := config.GetEnv().TestLocalhost
|
||||
|
||||
portInt, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
@@ -560,8 +558,13 @@ func connectToMongodbContainer(
|
||||
}
|
||||
|
||||
uri := fmt.Sprintf(
|
||||
"mongodb://%s:%s@%s:%d/%s?authSource=%s",
|
||||
username, password, host, portInt, dbName, authDatabase,
|
||||
"mongodb://%s:%s@%s:%d/%s?authSource=%s&serverSelectionTimeoutMS=5000&connectTimeoutMS=5000",
|
||||
username,
|
||||
password,
|
||||
host,
|
||||
portInt,
|
||||
dbName,
|
||||
authDatabase,
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
|
||||
@@ -21,9 +21,7 @@ import (
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
mysqltypes "databasus-backend/internal/features/databases/databases/mysql"
|
||||
"databasus-backend/internal/features/restores"
|
||||
restores_enums "databasus-backend/internal/features/restores/enums"
|
||||
restores_models "databasus-backend/internal/features/restores/models"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
@@ -188,7 +186,7 @@ func testMysqlBackupRestoreForVersion(t *testing.T, mysqlVersion tools.MysqlVers
|
||||
)
|
||||
|
||||
restore := waitForMysqlRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var tableExists int
|
||||
err = newDB.Get(
|
||||
@@ -286,7 +284,7 @@ func testMysqlBackupRestoreWithEncryptionForVersion(
|
||||
)
|
||||
|
||||
restore := waitForMysqlRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var tableExists int
|
||||
err = newDB.Get(
|
||||
@@ -393,7 +391,7 @@ func testMysqlBackupRestoreWithReadOnlyUserForVersion(
|
||||
)
|
||||
|
||||
restore := waitForMysqlRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var tableExists int
|
||||
err = newDB.Get(
|
||||
@@ -481,7 +479,7 @@ func createMysqlRestoreViaAPI(
|
||||
version tools.MysqlVersion,
|
||||
token string,
|
||||
) {
|
||||
request := restores.RestoreBackupRequest{
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
MysqlDatabase: &mysqltypes.MysqlDatabase{
|
||||
Host: host,
|
||||
Port: port,
|
||||
@@ -508,7 +506,7 @@ func waitForMysqlRestoreCompletion(
|
||||
backupID uuid.UUID,
|
||||
token string,
|
||||
timeout time.Duration,
|
||||
) *restores_models.Restore {
|
||||
) *restores_core.Restore {
|
||||
startTime := time.Now()
|
||||
pollInterval := 500 * time.Millisecond
|
||||
|
||||
@@ -517,7 +515,7 @@ func waitForMysqlRestoreCompletion(
|
||||
t.Fatalf("Timeout waiting for MySQL restore completion after %v", timeout)
|
||||
}
|
||||
|
||||
var restoresList []*restores_models.Restore
|
||||
var restoresList []*restores_core.Restore
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
@@ -528,10 +526,10 @@ func waitForMysqlRestoreCompletion(
|
||||
)
|
||||
|
||||
for _, restore := range restoresList {
|
||||
if restore.Status == restores_enums.RestoreStatusCompleted {
|
||||
if restore.Status == restores_core.RestoreStatusCompleted {
|
||||
return restore
|
||||
}
|
||||
if restore.Status == restores_enums.RestoreStatusFailed {
|
||||
if restore.Status == restores_core.RestoreStatusFailed {
|
||||
failMsg := "unknown error"
|
||||
if restore.FailMessage != nil {
|
||||
failMsg = *restore.FailMessage
|
||||
@@ -579,7 +577,7 @@ func connectToMysqlContainer(version tools.MysqlVersion, port string) (*MysqlCon
|
||||
dbName := "testdb"
|
||||
password := "rootpassword"
|
||||
username := "root"
|
||||
host := "127.0.0.1"
|
||||
host := config.GetEnv().TestLocalhost
|
||||
|
||||
portInt, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
|
||||
@@ -23,8 +23,7 @@ import (
|
||||
"databasus-backend/internal/features/databases"
|
||||
pgtypes "databasus-backend/internal/features/databases/databases/postgresql"
|
||||
"databasus-backend/internal/features/restores"
|
||||
restores_enums "databasus-backend/internal/features/restores/enums"
|
||||
restores_models "databasus-backend/internal/features/restores/models"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
@@ -125,6 +124,10 @@ func Test_BackupAndRestorePostgresqlWithEncryption_RestoreIsSuccessful(t *testin
|
||||
}
|
||||
|
||||
func Test_BackupAndRestoreSupabase_PublicSchemaOnly_RestoreIsSuccessful(t *testing.T) {
|
||||
if config.GetEnv().IsSkipExternalResourcesTests {
|
||||
t.Skip("Skipping Supabase test: IS_SKIP_EXTERNAL_RESOURCES_TESTS is true")
|
||||
}
|
||||
|
||||
env := config.GetEnv()
|
||||
|
||||
if env.TestSupabaseHost == "" {
|
||||
@@ -212,7 +215,7 @@ func Test_BackupAndRestoreSupabase_PublicSchemaOnly_RestoreIsSuccessful(t *testi
|
||||
)
|
||||
|
||||
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var countAfterRestore int
|
||||
err = supabaseDB.Get(
|
||||
@@ -439,7 +442,7 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string, cp
|
||||
)
|
||||
|
||||
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var tableExists bool
|
||||
err = newDB.Get(
|
||||
@@ -555,7 +558,7 @@ func testSchemaSelectionAllSchemasForVersion(t *testing.T, pgVersion string, por
|
||||
)
|
||||
|
||||
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var publicTableExists bool
|
||||
err = newDB.Get(&publicTableExists, `
|
||||
@@ -689,7 +692,7 @@ func testBackupRestoreWithExcludeExtensionsForVersion(t *testing.T, pgVersion st
|
||||
)
|
||||
|
||||
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
// Verify the table was restored
|
||||
var tableExists bool
|
||||
@@ -829,7 +832,7 @@ func testBackupRestoreWithoutExcludeExtensionsForVersion(
|
||||
)
|
||||
|
||||
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
// Verify the extension was recovered
|
||||
var extensionExists bool
|
||||
@@ -956,7 +959,7 @@ func testBackupRestoreWithReadOnlyUserForVersion(t *testing.T, pgVersion string,
|
||||
)
|
||||
|
||||
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var tableExists bool
|
||||
err = newDB.Get(
|
||||
@@ -1076,7 +1079,7 @@ func testSchemaSelectionOnlySpecifiedSchemasForVersion(
|
||||
)
|
||||
|
||||
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var publicTableExists bool
|
||||
err = newDB.Get(&publicTableExists, `
|
||||
@@ -1190,7 +1193,7 @@ func testBackupRestoreWithEncryptionForVersion(t *testing.T, pgVersion string, p
|
||||
)
|
||||
|
||||
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var tableExists bool
|
||||
err = newDB.Get(
|
||||
@@ -1286,7 +1289,7 @@ func waitForRestoreCompletion(
|
||||
backupID uuid.UUID,
|
||||
token string,
|
||||
timeout time.Duration,
|
||||
) *restores_models.Restore {
|
||||
) *restores_core.Restore {
|
||||
startTime := time.Now()
|
||||
pollInterval := 500 * time.Millisecond
|
||||
|
||||
@@ -1295,7 +1298,7 @@ func waitForRestoreCompletion(
|
||||
t.Fatalf("Timeout waiting for restore completion after %v", timeout)
|
||||
}
|
||||
|
||||
var restores []*restores_models.Restore
|
||||
var restores []*restores_core.Restore
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
@@ -1306,10 +1309,10 @@ func waitForRestoreCompletion(
|
||||
)
|
||||
|
||||
for _, restore := range restores {
|
||||
if restore.Status == restores_enums.RestoreStatusCompleted {
|
||||
if restore.Status == restores_core.RestoreStatusCompleted {
|
||||
return restore
|
||||
}
|
||||
if restore.Status == restores_enums.RestoreStatusFailed {
|
||||
if restore.Status == restores_core.RestoreStatusFailed {
|
||||
failMsg := "unknown error"
|
||||
if restore.FailMessage != nil {
|
||||
failMsg = *restore.FailMessage
|
||||
@@ -1476,7 +1479,7 @@ func createRestoreWithCpuCountViaAPI(
|
||||
cpuCount int,
|
||||
token string,
|
||||
) {
|
||||
request := restores.RestoreBackupRequest{
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
PostgresqlDatabase: &pgtypes.PostgresqlDatabase{
|
||||
Host: host,
|
||||
Port: port,
|
||||
@@ -1509,7 +1512,7 @@ func createRestoreWithOptionsViaAPI(
|
||||
isExcludeExtensions bool,
|
||||
token string,
|
||||
) {
|
||||
request := restores.RestoreBackupRequest{
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
PostgresqlDatabase: &pgtypes.PostgresqlDatabase{
|
||||
Host: host,
|
||||
Port: port,
|
||||
@@ -1647,7 +1650,7 @@ func createSupabaseRestoreViaAPI(
|
||||
database string,
|
||||
token string,
|
||||
) {
|
||||
request := restores.RestoreBackupRequest{
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
PostgresqlDatabase: &pgtypes.PostgresqlDatabase{
|
||||
Host: host,
|
||||
Port: port,
|
||||
@@ -1755,7 +1758,7 @@ func connectToPostgresContainer(version string, port string) (*PostgresContainer
|
||||
dbName := "testdb"
|
||||
password := "testpassword"
|
||||
username := "testuser"
|
||||
host := "localhost"
|
||||
host := config.GetEnv().TestLocalhost
|
||||
|
||||
portInt, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
"databasus-backend/internal/features/restores/restoring"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
)
|
||||
|
||||
@@ -12,11 +13,15 @@ func TestMain(m *testing.M) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
backuperNode := backuping.CreateTestBackuperNode()
|
||||
cancel := backuping.StartBackuperNodeForTest(&testing.T{}, backuperNode)
|
||||
cancelBackup := backuping.StartBackuperNodeForTest(&testing.T{}, backuperNode)
|
||||
|
||||
restorerNode := restoring.CreateTestRestorerNode()
|
||||
cancelRestore := restoring.StartRestorerNodeForTest(&testing.T{}, restorerNode)
|
||||
|
||||
exitCode := m.Run()
|
||||
|
||||
backuping.StopBackuperNodeForTest(&testing.T{}, cancel, backuperNode)
|
||||
backuping.StopBackuperNodeForTest(&testing.T{}, cancelBackup, backuperNode)
|
||||
restoring.StopRestorerNodeForTest(&testing.T{}, cancelRestore, restorerNode)
|
||||
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
93
backend/internal/util/cache/cache_test.go
vendored
Normal file
93
backend/internal/util/cache/cache_test.go
vendored
Normal file
@@ -0,0 +1,93 @@
|
||||
package cache_utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_ClearAllCache_AfterClear_CacheIsEmpty(t *testing.T) {
|
||||
client := getCache()
|
||||
|
||||
// Arrange: Set up multiple cache entries with different prefixes
|
||||
testKeys := []struct {
|
||||
prefix string
|
||||
key string
|
||||
value string
|
||||
}{
|
||||
{"test:user:", "user1", "John Doe"},
|
||||
{"test:user:", "user2", "Jane Smith"},
|
||||
{"test:session:", "session1", "abc123"},
|
||||
{"test:session:", "session2", "def456"},
|
||||
{"test:data:", "item1", "value1"},
|
||||
}
|
||||
|
||||
// Set all test keys
|
||||
for _, tk := range testKeys {
|
||||
cacheUtil := NewCacheUtil[string](client, tk.prefix)
|
||||
cacheUtil.Set(tk.key, &tk.value)
|
||||
}
|
||||
|
||||
// Verify keys were set correctly before clearing
|
||||
for _, tk := range testKeys {
|
||||
cacheUtil := NewCacheUtil[string](client, tk.prefix)
|
||||
retrieved := cacheUtil.Get(tk.key)
|
||||
assert.NotNil(t, retrieved, "Key %s should exist before clearing", tk.prefix+tk.key)
|
||||
assert.Equal(t, tk.value, *retrieved, "Retrieved value should match set value")
|
||||
}
|
||||
|
||||
// Act: Clear all cache
|
||||
err := ClearAllCache()
|
||||
|
||||
// Assert: No error returned
|
||||
assert.NoError(t, err, "ClearAllCache should not return an error")
|
||||
|
||||
// Assert: All keys should be deleted
|
||||
for _, tk := range testKeys {
|
||||
cacheUtil := NewCacheUtil[string](client, tk.prefix)
|
||||
retrieved := cacheUtil.Get(tk.key)
|
||||
assert.Nil(t, retrieved, "Key %s should be deleted after clearing", tk.prefix+tk.key)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SetWithExpiration_SetsCorrectTTL(t *testing.T) {
|
||||
client := getCache()
|
||||
|
||||
// Create a cache utility
|
||||
testPrefix := "test:ttl:"
|
||||
cacheUtil := NewCacheUtil[string](client, testPrefix)
|
||||
|
||||
// Set a value with 1-hour expiration
|
||||
testKey := "key1"
|
||||
testValue := "test value"
|
||||
oneHour := 1 * time.Hour
|
||||
|
||||
cacheUtil.SetWithExpiration(testKey, &testValue, oneHour)
|
||||
|
||||
// Verify the value was set
|
||||
retrieved := cacheUtil.Get(testKey)
|
||||
assert.NotNil(t, retrieved, "Value should be stored")
|
||||
assert.Equal(t, testValue, *retrieved, "Retrieved value should match")
|
||||
|
||||
// Check the TTL using Valkey TTL command
|
||||
ctx, cancel := context.WithTimeout(context.Background(), DefaultCacheTimeout)
|
||||
defer cancel()
|
||||
|
||||
fullKey := testPrefix + testKey
|
||||
ttlResult := client.Do(ctx, client.B().Ttl().Key(fullKey).Build())
|
||||
assert.NoError(t, ttlResult.Error(), "TTL command should not error")
|
||||
|
||||
ttlSeconds, err := ttlResult.AsInt64()
|
||||
assert.NoError(t, err, "TTL should be retrievable as int64")
|
||||
|
||||
// TTL should be approximately 1 hour (3600 seconds)
|
||||
// Allow for a small margin (within 10 seconds of 3600)
|
||||
expectedTTL := int64(3600)
|
||||
assert.GreaterOrEqual(t, ttlSeconds, expectedTTL-10, "TTL should be close to 1 hour")
|
||||
assert.LessOrEqual(t, ttlSeconds, expectedTTL, "TTL should not exceed 1 hour")
|
||||
|
||||
// Clean up
|
||||
cacheUtil.Invalidate(testKey)
|
||||
}
|
||||
37
backend/internal/util/cache/utils.go
vendored
37
backend/internal/util/cache/utils.go
vendored
@@ -67,6 +67,43 @@ func (c *CacheUtil[T]) Set(key string, item *T) {
|
||||
c.client.Do(ctx, c.client.B().Set().Key(fullKey).Value(string(data)).Ex(c.expiry).Build())
|
||||
}
|
||||
|
||||
func (c *CacheUtil[T]) SetWithExpiration(key string, item *T, expiry time.Duration) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
|
||||
defer cancel()
|
||||
|
||||
data, err := json.Marshal(item)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
fullKey := c.prefix + key
|
||||
c.client.Do(ctx, c.client.B().Set().Key(fullKey).Value(string(data)).Ex(expiry).Build())
|
||||
}
|
||||
|
||||
func (c *CacheUtil[T]) GetAndDelete(key string) *T {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
|
||||
defer cancel()
|
||||
|
||||
fullKey := c.prefix + key
|
||||
result := c.client.Do(ctx, c.client.B().Getdel().Key(fullKey).Build())
|
||||
|
||||
if result.Error() != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := result.AsBytes()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var item T
|
||||
if err := json.Unmarshal(data, &item); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &item
|
||||
}
|
||||
|
||||
func (c *CacheUtil[T]) Invalidate(key string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
|
||||
defer cancel()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user