Compare commits

...

41 Commits

Author SHA1 Message Date
Rostislav Dugin
d4763f26b2 Merge pull request #296 from databasus/develop
Develop
2026-01-19 19:27:03 +03:00
Rostislav Dugin
0e389ba16b FIX (backups): Allow parallel backups for different DBs 2026-01-19 19:26:03 +03:00
Rostislav Dugin
594a3294c6 FEATURE (limits): Add max backup size limit and total backups size limit 2026-01-19 19:26:03 +03:00
Rostislav Dugin
4e4a323cf1 FEATURE (config): Suggest read-only user creation when DB config changed 2026-01-19 19:26:03 +03:00
Rostislav Dugin
7d9ecf697b FIX (backups): Do not allow 2 parallel backups for the same DB 2026-01-19 19:26:03 +03:00
Rostislav Dugin
755c420157 Merge pull request #294 from databasus/develop
FIX (mysql \ mariadb): Add escaping underscoped DB names over heath c…
2026-01-19 12:07:18 +03:00
Rostislav Dugin
ff73627287 FIX (mysql \ mariadb): Add escaping underscoped DB names over heath check 2026-01-19 11:34:37 +03:00
Rostislav Dugin
9c9ab00ace Merge pull request #292 from databasus/develop
FIX (postgresql): Do not throw an error over read-only user creation …
2026-01-18 23:08:55 +03:00
Rostislav Dugin
7366e21a1a FIX (postgresql): Do not throw an error over read-only user creation if there are no public schema in DB 2026-01-18 22:57:47 +03:00
Rostislav Dugin
a327d1aa57 Merge pull request #290 from databasus/develop
FIX (ftp): Add support of nested folders
2026-01-18 18:34:45 +03:00
Rostislav Dugin
f152b16ea3 FIX (ftp): Add support of nested folders 2026-01-18 18:34:13 +03:00
Databasus
85dbe80d3d Merge pull request #288 from databasus/develop
FIX (email): Add following RFC 2047 for emails
2026-01-18 17:59:17 +03:00
Rostislav Dugin
edf4028fd1 FIX (email): Add following RFC 2047 for emails 2026-01-18 17:58:31 +03:00
Databasus
8d85c45a90 Merge pull request #287 from databasus/develop
FIX (tests): Allow to skip external network tests in CI CD
2026-01-18 15:46:49 +03:00
Rostislav Dugin
d9c176d19a FIX (tests): Allow to skip external network tests in CI CD 2026-01-18 15:45:49 +03:00
Databasus
7a6f72a456 Merge pull request #286 from databasus/develop
FIX (ci): Add cleanup to build and push steps
2026-01-18 15:09:13 +03:00
Rostislav Dugin
9a1471b88b FIX (ci): Add cleanup to build and push steps 2026-01-18 15:08:09 +03:00
Databasus
386ea1d708 Merge pull request #285 from databasus/develop
FIX (commit messages): Allow to use backstashes in messages x3
2026-01-18 14:58:10 +03:00
Rostislav Dugin
a4b23936ee FIX (commit messages): Allow to use backstashes in messages x3 2026-01-18 14:57:45 +03:00
Databasus
b36aa9d48b Merge pull request #284 from databasus/develop
FIX (commit messages): Allow to use backstashes in messages x2
2026-01-18 14:49:58 +03:00
Rostislav Dugin
13cb8e5bd2 FIX (commit messages): Allow to use backstashes in messages x2 2026-01-18 14:49:18 +03:00
Databasus
2db4b6e075 Merge pull request #283 from databasus/develop
FIX (commit messages): Allow to use backstashes in messages
2026-01-18 14:38:34 +03:00
Rostislav Dugin
f2b0b2bf1f FIX (commit messages): Allow to use backstashes in messages 2026-01-18 14:38:12 +03:00
Databasus
7142ce295e Merge pull request #282 from databasus/develop
Develop
2026-01-18 14:01:59 +03:00
Rostislav Dugin
04621b9b2d FEATURE (ci \ cd): Adjust CI \ CD to run heavy jobs on self hosted performant runner 2026-01-18 13:55:08 +03:00
Rostislav Dugin
bd329a68cf FEATURE (restores): Do not allow to make 2 parallel restores for single DB 2026-01-17 22:50:35 +03:00
Rostislav Dugin
f957abc9db FEATURE (restores): Add cancellation of restore process 2026-01-17 22:35:47 +03:00
Rostislav Dugin
c0fd6be1a9 Merge pull request #280 from databasus/develop
FEATURE (restores): Add support of multiple restores nodes
2026-01-17 13:59:36 +03:00
Rostislav Dugin
c39bd34d5e FEATURE (restores): Add support of multiple restores nodes 2026-01-17 13:59:06 +03:00
Rostislav Dugin
27bec15a29 Merge pull request #278 from databasus/develop
FIX (backups): Extend filtering lists to detect from-image DB access
2026-01-16 10:03:45 +03:00
Rostislav Dugin
d98baa0656 FIX (backups): Extend filtering lists to detect from-image DB access 2026-01-16 10:03:09 +03:00
Rostislav Dugin
4344f5ea5e Merge pull request #273 from databasus/develop
FIX (ci \ cd): Make DB files in CI \ CD executable
2026-01-15 22:17:06 +03:00
Rostislav Dugin
7c6afa5b88 FIX (ci \ cd): Make DB files in CI \ CD executable 2026-01-15 22:16:45 +03:00
Rostislav Dugin
dbac799e1b Merge pull request #272 from databasus/develop
FIX (backups): Add backups failure logging when it is expected
2026-01-15 22:02:39 +03:00
Rostislav Dugin
7ee3817089 FIX (backups): Add backups failure logging when it is expected 2026-01-15 22:01:53 +03:00
Rostislav Dugin
bae6f7f007 Merge pull request #271 from databasus/develop
Develop
2026-01-15 21:19:55 +03:00
Rostislav Dugin
55dc087ddd FIX (containers): Do not allow to backup internal DB from inside containers, instead give link to FAQ with manual how to backup Databasus in proper way 2026-01-15 21:18:37 +03:00
Rostislav Dugin
c94d0db637 FIX (ci \ cd): Remove caches and use assets from repo to avoid flucky tests over CI 2026-01-15 21:03:43 +03:00
Rostislav Dugin
a1adef2261 !REFACTOR (tasks): Move tasks cancellation and tracking to separate package from backuping to use for restores 2026-01-15 21:03:05 +03:00
Rostislav Dugin
4602dc3f88 Merge pull request #267 from databasus/develop
FIX (mysql): Enable allowCleartextPasswords over SSL
2026-01-14 18:13:46 +03:00
Rostislav Dugin
cbbfc5ea8f FIX (mysql): Enable allowCleartextPasswords over SSL 2026-01-14 18:11:49 +03:00
114 changed files with 11725 additions and 2813 deletions


@@ -9,25 +9,26 @@ on:
jobs:
lint-backend:
runs-on: ubuntu-latest
runs-on: self-hosted
container:
image: golang:1.24.9
volumes:
- /runner-cache/go-pkg:/go/pkg/mod
- /runner-cache/go-build:/root/.cache/go-build
- /runner-cache/golangci-lint:/root/.cache/golangci-lint
- /runner-cache/apt-archives:/var/cache/apt/archives
steps:
- name: Check out code
uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.24.9"
- name: Configure Git for container
run: |
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Cache Go modules
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
~/.cache/go-build
key: ${{ runner.os }}-go-${{ hashFiles('backend/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Download Go modules
run: |
cd backend
go mod download
- name: Install golangci-lint
run: |
@@ -63,8 +64,6 @@ jobs:
uses: actions/setup-node@v4
with:
node-version: "20"
cache: "npm"
cache-dependency-path: frontend/package-lock.json
- name: Install dependencies
run: |
@@ -93,8 +92,6 @@ jobs:
uses: actions/setup-node@v4
with:
node-version: "20"
cache: "npm"
cache-dependency-path: frontend/package-lock.json
- name: Install dependencies
run: |
@@ -107,44 +104,32 @@ jobs:
npm run test
test-backend:
runs-on: ubuntu-latest
runs-on: self-hosted
needs: [lint-backend]
container:
image: golang:1.24.9
options: --privileged -v /var/run/docker.sock:/var/run/docker.sock --add-host=host.docker.internal:host-gateway
volumes:
- /runner-cache/go-pkg:/go/pkg/mod
- /runner-cache/go-build:/root/.cache/go-build
- /runner-cache/apt-archives:/var/cache/apt/archives
steps:
- name: Free up disk space
- name: Install Docker CLI
run: |
echo "Disk space before cleanup:"
df -h
# Remove unnecessary pre-installed software
sudo rm -rf /usr/share/dotnet
sudo rm -rf /usr/local/lib/android
sudo rm -rf /opt/ghc
sudo rm -rf /opt/hostedtoolcache/CodeQL
sudo rm -rf /usr/local/share/boost
sudo rm -rf /usr/share/swift
# Clean apt cache
sudo apt-get clean
# Clean docker images (if any pre-installed)
docker system prune -af --volumes || true
echo "Disk space after cleanup:"
df -h
apt-get update -qq
apt-get install -y -qq docker.io docker-compose netcat-openbsd wget
- name: Check out code
uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.24.9"
- name: Configure Git for container
run: |
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Cache Go modules
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
~/.cache/go-build
key: ${{ runner.os }}-go-${{ hashFiles('backend/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Download Go modules
run: |
cd backend
go mod download
- name: Create .env file for testing
run: |
@@ -156,14 +141,16 @@ jobs:
DEV_DB_PASSWORD=Q1234567
#app
ENV_MODE=development
# db
DATABASE_DSN=host=localhost user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
DATABASE_URL=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disable
# db - using 172.17.0.1 to access host from container
DATABASE_DSN=host=172.17.0.1 user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
DATABASE_URL=postgres://postgres:Q1234567@172.17.0.1:5437/databasus?sslmode=disable
# migrations
GOOSE_DRIVER=postgres
GOOSE_DBSTRING=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disable
GOOSE_DBSTRING=postgres://postgres:Q1234567@172.17.0.1:5437/databasus?sslmode=disable
GOOSE_MIGRATION_DIR=./migrations
# testing
# testing
TEST_LOCALHOST=172.17.0.1
IS_SKIP_EXTERNAL_RESOURCES_TESTS=true
# to get Google Drive env variables: add storage in UI and copy data from added storage here
TEST_GOOGLE_DRIVE_CLIENT_ID=${{ secrets.TEST_GOOGLE_DRIVE_CLIENT_ID }}
TEST_GOOGLE_DRIVE_CLIENT_SECRET=${{ secrets.TEST_GOOGLE_DRIVE_CLIENT_SECRET }}
@@ -221,12 +208,14 @@ jobs:
TEST_MONGODB_60_PORT=27060
TEST_MONGODB_70_PORT=27070
TEST_MONGODB_82_PORT=27082
# Valkey (cache)
VALKEY_HOST=localhost
# Valkey (cache) - using 172.17.0.1
VALKEY_HOST=172.17.0.1
VALKEY_PORT=6379
VALKEY_USERNAME=
VALKEY_PASSWORD=
VALKEY_IS_SSL=false
# Host for test databases (container -> host)
TEST_DB_HOST=172.17.0.1
EOF
- name: Start test containers
@@ -244,25 +233,25 @@ jobs:
timeout 60 bash -c 'until docker exec dev-valkey valkey-cli ping 2>/dev/null | grep -q PONG; do sleep 2; done'
echo "Valkey is ready!"
# Wait for test databases
timeout 60 bash -c 'until nc -z localhost 5000; do sleep 2; done'
timeout 60 bash -c 'until nc -z localhost 5001; do sleep 2; done'
timeout 60 bash -c 'until nc -z localhost 5002; do sleep 2; done'
timeout 60 bash -c 'until nc -z localhost 5003; do sleep 2; done'
timeout 60 bash -c 'until nc -z localhost 5004; do sleep 2; done'
timeout 60 bash -c 'until nc -z localhost 5005; do sleep 2; done'
# Wait for test databases (using 172.17.0.1 from container)
timeout 60 bash -c 'until nc -z 172.17.0.1 5000; do sleep 2; done'
timeout 60 bash -c 'until nc -z 172.17.0.1 5001; do sleep 2; done'
timeout 60 bash -c 'until nc -z 172.17.0.1 5002; do sleep 2; done'
timeout 60 bash -c 'until nc -z 172.17.0.1 5003; do sleep 2; done'
timeout 60 bash -c 'until nc -z 172.17.0.1 5004; do sleep 2; done'
timeout 60 bash -c 'until nc -z 172.17.0.1 5005; do sleep 2; done'
# Wait for MinIO
timeout 60 bash -c 'until nc -z localhost 9000; do sleep 2; done'
timeout 60 bash -c 'until nc -z 172.17.0.1 9000; do sleep 2; done'
# Wait for Azurite
timeout 60 bash -c 'until nc -z localhost 10000; do sleep 2; done'
timeout 60 bash -c 'until nc -z 172.17.0.1 10000; do sleep 2; done'
# Wait for FTP
timeout 60 bash -c 'until nc -z localhost 7007; do sleep 2; done'
timeout 60 bash -c 'until nc -z 172.17.0.1 7007; do sleep 2; done'
# Wait for SFTP
timeout 60 bash -c 'until nc -z localhost 7008; do sleep 2; done'
timeout 60 bash -c 'until nc -z 172.17.0.1 7008; do sleep 2; done'
# Wait for MySQL containers
echo "Waiting for MySQL 5.7..."
@@ -321,67 +310,66 @@ jobs:
mkdir -p databasus-data/backups
mkdir -p databasus-data/temp
- name: Cache PostgreSQL client tools
id: cache-postgres
uses: actions/cache@v4
with:
path: /usr/lib/postgresql
key: postgres-clients-12-18-v1
- name: Cache MySQL client tools
id: cache-mysql
uses: actions/cache@v4
with:
path: backend/tools/mysql
key: mysql-clients-57-80-84-9-v1
- name: Cache MariaDB client tools
id: cache-mariadb
uses: actions/cache@v4
with:
path: backend/tools/mariadb
key: mariadb-clients-106-121-v1
- name: Cache MongoDB Database Tools
id: cache-mongodb
uses: actions/cache@v4
with:
path: backend/tools/mongodb
key: mongodb-database-tools-100.10.0-v1
- name: Install MySQL dependencies
- name: Install database client dependencies
run: |
sudo apt-get update -qq
sudo apt-get install -y -qq libncurses6
sudo ln -sf /usr/lib/x86_64-linux-gnu/libncurses.so.6 /usr/lib/x86_64-linux-gnu/libncurses.so.5
sudo ln -sf /usr/lib/x86_64-linux-gnu/libtinfo.so.6 /usr/lib/x86_64-linux-gnu/libtinfo.so.5
apt-get update -qq
apt-get install -y -qq libncurses6 libpq5
ln -sf /usr/lib/x86_64-linux-gnu/libncurses.so.6 /usr/lib/x86_64-linux-gnu/libncurses.so.5 || true
ln -sf /usr/lib/x86_64-linux-gnu/libtinfo.so.6 /usr/lib/x86_64-linux-gnu/libtinfo.so.5 || true
- name: Install PostgreSQL, MySQL, MariaDB and MongoDB client tools
if: steps.cache-postgres.outputs.cache-hit != 'true' || steps.cache-mysql.outputs.cache-hit != 'true' || steps.cache-mariadb.outputs.cache-hit != 'true' || steps.cache-mongodb.outputs.cache-hit != 'true'
run: |
chmod +x backend/tools/download_linux.sh
cd backend/tools
./download_linux.sh
- name: Setup PostgreSQL symlinks (when using cache)
if: steps.cache-postgres.outputs.cache-hit == 'true'
- name: Setup PostgreSQL, MySQL and MariaDB client tools from pre-built assets
run: |
cd backend/tools
mkdir -p postgresql
# Create directory structure
mkdir -p postgresql mysql mariadb mongodb/bin
# Copy PostgreSQL client tools (12-18) from pre-built assets
for version in 12 13 14 15 16 17 18; do
version_dir="postgresql/postgresql-$version"
mkdir -p "$version_dir/bin"
pg_bin_dir="/usr/lib/postgresql/$version/bin"
if [ -d "$pg_bin_dir" ]; then
ln -sf "$pg_bin_dir/pg_dump" "$version_dir/bin/pg_dump"
ln -sf "$pg_bin_dir/pg_dumpall" "$version_dir/bin/pg_dumpall"
ln -sf "$pg_bin_dir/psql" "$version_dir/bin/psql"
ln -sf "$pg_bin_dir/pg_restore" "$version_dir/bin/pg_restore"
ln -sf "$pg_bin_dir/createdb" "$version_dir/bin/createdb"
ln -sf "$pg_bin_dir/dropdb" "$version_dir/bin/dropdb"
fi
mkdir -p postgresql/postgresql-$version
cp -r ../../assets/tools/x64/postgresql/postgresql-$version/bin postgresql/postgresql-$version/
done
# Copy MySQL client tools (5.7, 8.0, 8.4, 9) from pre-built assets
for version in 5.7 8.0 8.4 9; do
mkdir -p mysql/mysql-$version
cp -r ../../assets/tools/x64/mysql/mysql-$version/bin mysql/mysql-$version/
done
# Copy MariaDB client tools (10.6, 12.1) from pre-built assets
for version in 10.6 12.1; do
mkdir -p mariadb/mariadb-$version
cp -r ../../assets/tools/x64/mariadb/mariadb-$version/bin mariadb/mariadb-$version/
done
# Make all binaries executable
chmod +x postgresql/*/bin/*
chmod +x mysql/*/bin/*
chmod +x mariadb/*/bin/*
echo "Pre-built client tools setup complete"
- name: Install MongoDB Database Tools
run: |
cd backend/tools
# MongoDB Database Tools must be downloaded (not in pre-built assets)
# They are backward compatible - single version supports all servers (4.0-8.0)
MONGODB_TOOLS_URL="https://fastdl.mongodb.org/tools/db/mongodb-database-tools-debian12-x86_64-100.10.0.deb"
echo "Downloading MongoDB Database Tools..."
wget -q "$MONGODB_TOOLS_URL" -O /tmp/mongodb-database-tools.deb
echo "Installing MongoDB Database Tools..."
dpkg -i /tmp/mongodb-database-tools.deb || apt-get install -f -y --no-install-recommends
# Create symlinks to tools directory
ln -sf /usr/bin/mongodump mongodb/bin/mongodump
ln -sf /usr/bin/mongorestore mongodb/bin/mongorestore
rm -f /tmp/mongodb-database-tools.deb
echo "MongoDB Database Tools installed successfully"
- name: Verify MariaDB client tools exist
run: |
cd backend/tools
@@ -426,10 +414,28 @@ jobs:
if: always()
run: |
cd backend
# Stop and remove containers (keeping images for next run)
docker compose -f docker-compose.yml.example down -v
# Clean up all data directories created by docker-compose
echo "Cleaning up data directories..."
rm -rf pgdata || true
rm -rf valkey-data || true
rm -rf mysqldata || true
rm -rf mariadbdata || true
rm -rf temp/nas || true
rm -rf databasus-data || true
# Also clean root-level databasus-data if exists
cd ..
rm -rf databasus-data || true
echo "Cleanup complete"
determine-version:
runs-on: ubuntu-latest
runs-on: self-hosted
container:
image: node:20
needs: [test-backend, test-frontend]
if: ${{ github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, '[skip-release]') }}
outputs:
@@ -442,10 +448,9 @@ jobs:
with:
fetch-depth: 0
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "20"
- name: Configure Git for container
run: |
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Install semver
run: npm install -g semver
@@ -459,6 +464,7 @@ jobs:
- name: Analyze commits and determine version bump
id: version_bump
shell: bash
run: |
CURRENT_VERSION="${{ steps.current_version.outputs.current_version }}"
LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
@@ -478,7 +484,7 @@ jobs:
HAS_FIX=false
HAS_BREAKING=false
# Analyze each commit
# Analyze each commit - USE PROCESS SUBSTITUTION to avoid subshell variable scope issues
while IFS= read -r commit; do
if [[ "$commit" =~ ^FEATURE ]]; then
HAS_FEATURE=true
@@ -496,7 +502,7 @@ jobs:
HAS_BREAKING=true
echo "Found BREAKING CHANGE: $commit"
fi
done <<< "$COMMITS"
done < <(printf '%s\n' "$COMMITS")
# Determine version bump
if [ "$HAS_BREAKING" = true ]; then
@@ -522,10 +528,15 @@ jobs:
fi
build-only:
runs-on: ubuntu-latest
runs-on: self-hosted
needs: [test-backend, test-frontend]
if: ${{ github.ref == 'refs/heads/main' && contains(github.event.head_commit.message, '[skip-release]') }}
steps:
- name: Clean workspace
run: |
sudo rm -rf "$GITHUB_WORKSPACE"/* || true
sudo rm -rf "$GITHUB_WORKSPACE"/.* || true
- name: Check out code
uses: actions/checkout@v4
@@ -554,12 +565,17 @@ jobs:
databasus/databasus:${{ github.sha }}
build-and-push:
runs-on: ubuntu-latest
runs-on: self-hosted
needs: [determine-version]
if: ${{ needs.determine-version.outputs.should_release == 'true' }}
permissions:
contents: write
steps:
- name: Clean workspace
run: |
sudo rm -rf "$GITHUB_WORKSPACE"/* || true
sudo rm -rf "$GITHUB_WORKSPACE"/.* || true
- name: Check out code
uses: actions/checkout@v4
@@ -589,21 +605,33 @@ jobs:
databasus/databasus:${{ github.sha }}
release:
runs-on: ubuntu-latest
runs-on: self-hosted
container:
image: node:20
needs: [determine-version, build-and-push]
if: ${{ needs.determine-version.outputs.should_release == 'true' }}
permissions:
contents: write
pull-requests: write
steps:
- name: Clean workspace
run: |
rm -rf "$GITHUB_WORKSPACE"/* || true
rm -rf "$GITHUB_WORKSPACE"/.* || true
- name: Check out code
uses: actions/checkout@v4
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Configure Git for container
run: |
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Generate changelog
id: changelog
shell: bash
run: |
NEW_VERSION="${{ needs.determine-version.outputs.new_version }}"
LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
@@ -623,6 +651,7 @@ jobs:
FIXES=""
REFACTORS=""
# USE PROCESS SUBSTITUTION to avoid subshell variable scope issues
while IFS= read -r line; do
if [ -n "$line" ]; then
COMMIT_MSG=$(echo "$line" | cut -d'|' -f1)
@@ -656,7 +685,7 @@ jobs:
fi
fi
fi
done <<< "$COMMITS"
done < <(printf '%s\n' "$COMMITS")
# Build changelog sections
if [ -n "$FEATURES" ]; then
@@ -695,16 +724,33 @@ jobs:
prerelease: false
publish-helm-chart:
runs-on: ubuntu-latest
runs-on: self-hosted
container:
image: alpine:3.19
volumes:
- /runner-cache/apk-cache:/etc/apk/cache
needs: [determine-version, build-and-push]
if: ${{ needs.determine-version.outputs.should_release == 'true' }}
permissions:
contents: read
packages: write
steps:
- name: Clean workspace
run: |
rm -rf "$GITHUB_WORKSPACE"/* || true
rm -rf "$GITHUB_WORKSPACE"/.* || true
- name: Install dependencies
run: |
apk add --no-cache git bash curl
- name: Check out code
uses: actions/checkout@v4
- name: Configure Git for container
run: |
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Set up Helm
uses: azure/setup-helm@v4
with:

AGENTS.md (Normal file, 1595 changed lines)

File diff suppressed because it is too large


@@ -1,152 +0,0 @@
---
description:
globs:
alwaysApply: true
---
Always place private methods at the bottom of the file
**This rule applies to ALL Go files including tests, services, controllers, repositories, etc.**
In Go, exported (public) functions/methods start with uppercase letters, while unexported (private) ones start with lowercase letters.
## Structure Order:
1. Type definitions and constants
2. Public methods/functions (uppercase)
3. Private methods/functions (lowercase)
## Examples:
### Service with methods:
```go
type UserService struct {
repository *UserRepository
}
// Public methods first
func (s *UserService) CreateUser(user *User) error {
if err := s.validateUser(user); err != nil {
return err
}
return s.repository.Save(user)
}
func (s *UserService) GetUser(id uuid.UUID) (*User, error) {
return s.repository.FindByID(id)
}
// Private methods at the bottom
func (s *UserService) validateUser(user *User) error {
if user.Name == "" {
return errors.New("name is required")
}
return nil
}
```
### Package-level functions:
```go
package utils
// Public functions first
func ProcessData(data []byte) (Result, error) {
cleaned := sanitizeInput(data)
return parseData(cleaned)
}
func ValidateInput(input string) bool {
return isValidFormat(input) && checkLength(input)
}
// Private functions at the bottom
func sanitizeInput(data []byte) []byte {
// implementation
}
func parseData(data []byte) (Result, error) {
// implementation
}
func isValidFormat(input string) bool {
// implementation
}
func checkLength(input string) bool {
// implementation
}
```
### Test files:
```go
package user_test
// Public test functions first
func Test_CreateUser_ValidInput_UserCreated(t *testing.T) {
user := createTestUser()
result, err := service.CreateUser(user)
assert.NoError(t, err)
assert.NotNil(t, result)
}
func Test_GetUser_ExistingUser_ReturnsUser(t *testing.T) {
user := createTestUser()
// test implementation
}
// Private helper functions at the bottom
func createTestUser() *User {
return &User{
Name: "Test User",
Email: "test@example.com",
}
}
func setupTestDatabase() *Database {
// setup implementation
}
```
### Controller example:
```go
type ProjectController struct {
service *ProjectService
}
// Public HTTP handlers first
func (c *ProjectController) CreateProject(ctx *gin.Context) {
var request CreateProjectRequest
if err := ctx.ShouldBindJSON(&request); err != nil {
c.handleError(ctx, err)
return
}
// handler logic
}
func (c *ProjectController) GetProject(ctx *gin.Context) {
projectID := c.extractProjectID(ctx)
// handler logic
}
// Private helper methods at the bottom
func (c *ProjectController) handleError(ctx *gin.Context, err error) {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
}
func (c *ProjectController) extractProjectID(ctx *gin.Context) uuid.UUID {
return uuid.MustParse(ctx.Param("projectId"))
}
```
## Key Points:
- **Exported/Public** = starts with uppercase letter (CreateUser, GetProject)
- **Unexported/Private** = starts with lowercase letter (validateUser, handleError)
- This improves code readability by showing the public API first
- Private helpers are implementation details, so they go at the bottom
- Apply this rule consistently across ALL Go files in the project


@@ -1,45 +0,0 @@
---
description:
globs:
alwaysApply: true
---
## Comment Guidelines
1. **No obvious comments** - Don't state what the code already clearly shows
2. **Functions and variables should have meaningful names** - Code should be self-documenting
3. **Comments for unclear code only** - Only add comments when code logic isn't immediately clear
## Key Principles:
- **Code should tell a story** - Use descriptive variable and function names
- **Comments explain WHY, not WHAT** - The code shows what happens, comments explain business logic or complex decisions
- **Prefer refactoring over commenting** - If code needs explaining, consider making it clearer instead
- **API documentation is required** - Swagger comments for all HTTP endpoints are mandatory
- **Complex algorithms deserve comments** - Mathematical formulas, business rules, or non-obvious optimizations
Examples of useless comments:
1.
```sql
// Create projects table
CREATE TABLE projects (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
name TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
```
2.
```go
// Create test project
project := CreateTestProject(projectName, user, router)
```
3.
```go
// CreateValidLogItems creates valid log items for testing
func CreateValidLogItems(count int, uniqueID string) []logs_receiving.LogItemRequestDTO {
```
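By contrast, a comment is worth keeping when it records the WHY behind a non-obvious decision. A small hedged sketch of such a comment (the retention rule and names below are invented for illustration, not taken from the codebase):
```go
package logs

import "time"

// Retention is 35 days rather than 30: billing disputes can be raised up to
// 30 days after the invoice date, plus a 5-day buffer for support escalations.
const logRetentionDays = 35

// isExpired reports whether a log record is older than the retention window.
func isExpired(createdAt time.Time) bool {
	return time.Now().UTC().Sub(createdAt) > logRetentionDays*24*time.Hour
}
```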


@@ -1,133 +0,0 @@
---
description:
globs:
alwaysApply: true
---
1. When we write a controller:
- we combine all routes into a single controller
- we name handler methods using the .WhatWeDo concept (not generic "handlers")
2. We use gin and *gin.Context for all routes.
Example:
func (c *TasksController) GetAvailableTasks(ctx *gin.Context) ...
3. We document all routes with Swagger in the following format:
package audit_logs
import (
"net/http"
user_models "databasus-backend/internal/features/users/models"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
type AuditLogController struct {
auditLogService *AuditLogService
}
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
// All audit log endpoints require authentication (handled in main.go)
auditRoutes := router.Group("/audit-logs")
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
}
// GetGlobalAuditLogs
// @Summary Get global audit logs (ADMIN only)
// @Description Retrieve all audit logs across the system
// @Tags audit-logs
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
// @Success 200 {object} GetAuditLogsResponse
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /audit-logs/global [get]
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
user, isOk := ctx.MustGet("user").(*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
if err != nil {
if err.Error() == "only administrators can view global audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
ctx.JSON(http.StatusOK, response)
}
// GetUserAuditLogs
// @Summary Get user audit logs
// @Description Retrieve audit logs for a specific user
// @Tags audit-logs
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param userId path string true "User ID"
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
// @Success 200 {object} GetAuditLogsResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /audit-logs/users/{userId} [get]
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
user, isOk := ctx.MustGet("user").(*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
userIDStr := ctx.Param("userId")
targetUserID, err := uuid.Parse(userIDStr)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
if err != nil {
if err.Error() == "insufficient permissions to view user audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
ctx.JSON(http.StatusOK, response)
}


@@ -1,671 +0,0 @@
---
alwaysApply: false
---
This is example of CRUD:
------ backend/internal/features/audit_logs/controller.go ------
```
package audit_logs
import (
"net/http"
user_models "databasus-backend/internal/features/users/models"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
type AuditLogController struct {
auditLogService *AuditLogService
}
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
// All audit log endpoints require authentication (handled in main.go)
auditRoutes := router.Group("/audit-logs")
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
}
// GetGlobalAuditLogs
// @Summary Get global audit logs (ADMIN only)
// @Description Retrieve all audit logs across the system
// @Tags audit-logs
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
// @Success 200 {object} GetAuditLogsResponse
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /audit-logs/global [get]
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
user, isOk := ctx.MustGet("user").(*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
if err != nil {
if err.Error() == "only administrators can view global audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
ctx.JSON(http.StatusOK, response)
}
// GetUserAuditLogs
// @Summary Get user audit logs
// @Description Retrieve audit logs for a specific user
// @Tags audit-logs
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param userId path string true "User ID"
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
// @Success 200 {object} GetAuditLogsResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /audit-logs/users/{userId} [get]
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
user, isOk := ctx.MustGet("user").(*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
userIDStr := ctx.Param("userId")
targetUserID, err := uuid.Parse(userIDStr)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
if err != nil {
if err.Error() == "insufficient permissions to view user audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
ctx.JSON(http.StatusOK, response)
}
```
------ backend/internal/features/audit_logs/controller_test.go ------
```
package audit_logs
import (
"fmt"
"net/http"
"testing"
"time"
user_enums "databasus-backend/internal/features/users/enums"
users_middleware "databasus-backend/internal/features/users/middleware"
users_services "databasus-backend/internal/features/users/services"
users_testing "databasus-backend/internal/features/users/testing"
"databasus-backend/internal/storage"
test_utils "databasus-backend/internal/util/testing"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_GetGlobalAuditLogs_AdminSucceedsAndMemberGetsForbidden(t *testing.T) {
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
memberUser := users_testing.CreateTestUser(user_enums.UserRoleMember)
router := createRouter()
service := GetAuditLogService()
projectID := uuid.New()
// Create test logs
createAuditLog(service, "Test log with user", &adminUser.UserID, nil)
createAuditLog(service, "Test log with project", nil, &projectID)
createAuditLog(service, "Test log standalone", nil, nil)
// Test ADMIN can access global logs
var response GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router,
"/api/v1/audit-logs/global?limit=10", "Bearer "+adminUser.Token, http.StatusOK, &response)
assert.GreaterOrEqual(t, len(response.AuditLogs), 3)
assert.GreaterOrEqual(t, response.Total, int64(3))
messages := extractMessages(response.AuditLogs)
assert.Contains(t, messages, "Test log with user")
assert.Contains(t, messages, "Test log with project")
assert.Contains(t, messages, "Test log standalone")
// Test MEMBER cannot access global logs
resp := test_utils.MakeGetRequest(t, router, "/api/v1/audit-logs/global",
"Bearer "+memberUser.Token, http.StatusForbidden)
assert.Contains(t, string(resp.Body), "only administrators can view global audit logs")
}
func Test_GetUserAuditLogs_PermissionsEnforcedCorrectly(t *testing.T) {
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
router := createRouter()
service := GetAuditLogService()
projectID := uuid.New()
// Create test logs for different users
createAuditLog(service, "Test log user1 first", &user1.UserID, nil)
createAuditLog(service, "Test log user1 second", &user1.UserID, &projectID)
createAuditLog(service, "Test log user2 first", &user2.UserID, nil)
createAuditLog(service, "Test log user2 second", &user2.UserID, &projectID)
createAuditLog(service, "Test project log", nil, &projectID)
// Test ADMIN can view any user's logs
var user1Response GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router,
fmt.Sprintf("/api/v1/audit-logs/users/%s?limit=10", user1.UserID.String()),
"Bearer "+adminUser.Token, http.StatusOK, &user1Response)
assert.Equal(t, 2, len(user1Response.AuditLogs))
messages := extractMessages(user1Response.AuditLogs)
assert.Contains(t, messages, "Test log user1 first")
assert.Contains(t, messages, "Test log user1 second")
// Test user can view own logs
var ownLogsResponse GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router,
fmt.Sprintf("/api/v1/audit-logs/users/%s", user2.UserID.String()),
"Bearer "+user2.Token, http.StatusOK, &ownLogsResponse)
assert.Equal(t, 2, len(ownLogsResponse.AuditLogs))
// Test user cannot view other user's logs
resp := test_utils.MakeGetRequest(t, router,
fmt.Sprintf("/api/v1/audit-logs/users/%s", user1.UserID.String()),
"Bearer "+user2.Token, http.StatusForbidden)
assert.Contains(t, string(resp.Body), "insufficient permissions")
}
func Test_FilterAuditLogsByTime_ReturnsOnlyLogsBeforeDate(t *testing.T) {
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
router := createRouter()
service := GetAuditLogService()
db := storage.GetDb()
baseTime := time.Now().UTC()
// Create logs with different timestamps
createTimedLog(db, &adminUser.UserID, "Test old log", baseTime.Add(-2*time.Hour))
createTimedLog(db, &adminUser.UserID, "Test recent log", baseTime.Add(-30*time.Minute))
createAuditLog(service, "Test current log", &adminUser.UserID, nil)
// Test filtering - get logs before 1 hour ago
beforeTime := baseTime.Add(-1 * time.Hour)
var filteredResponse GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router,
fmt.Sprintf("/api/v1/audit-logs/global?beforeDate=%s", beforeTime.Format(time.RFC3339)),
"Bearer "+adminUser.Token, http.StatusOK, &filteredResponse)
// Verify only old log is returned
messages := extractMessages(filteredResponse.AuditLogs)
assert.Contains(t, messages, "Test old log")
assert.NotContains(t, messages, "Test recent log")
assert.NotContains(t, messages, "Test current log")
// Test without filter - should get all logs
var allResponse GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router, "/api/v1/audit-logs/global",
"Bearer "+adminUser.Token, http.StatusOK, &allResponse)
assert.GreaterOrEqual(t, len(allResponse.AuditLogs), 3)
}
func createRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
SetupDependencies()
v1 := router.Group("/api/v1")
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
GetAuditLogController().RegisterRoutes(protected.(*gin.RouterGroup))
return router
}
```
------ backend/internal/features/audit_logs/di.go ------
```
package audit_logs
import (
users_services "databasus-backend/internal/features/users/services"
"databasus-backend/internal/util/logger"
)
var auditLogRepository = &AuditLogRepository{}
var auditLogService = &AuditLogService{
auditLogRepository: auditLogRepository,
logger: logger.GetLogger(),
}
var auditLogController = &AuditLogController{
auditLogService: auditLogService,
}
func GetAuditLogService() *AuditLogService {
return auditLogService
}
func GetAuditLogController() *AuditLogController {
return auditLogController
}
func SetupDependencies() {
users_services.GetUserService().SetAuditLogWriter(auditLogService)
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
}
```
------ backend/internal/features/audit_logs/dto.go ------
```
package audit_logs
import "time"
type GetAuditLogsRequest struct {
Limit int `form:"limit" json:"limit"`
Offset int `form:"offset" json:"offset"`
BeforeDate *time.Time `form:"beforeDate" json:"beforeDate"`
}
type GetAuditLogsResponse struct {
AuditLogs []*AuditLog `json:"auditLogs"`
Total int64 `json:"total"`
Limit int `json:"limit"`
Offset int `json:"offset"`
}
```
------ backend/internal/features/audit_logs/models.go ------
```
package audit_logs
import (
"time"
"github.com/google/uuid"
)
type AuditLog struct {
ID uuid.UUID `json:"id" gorm:"column:id"`
UserID *uuid.UUID `json:"userId" gorm:"column:user_id"`
ProjectID *uuid.UUID `json:"projectId" gorm:"column:project_id"`
Message string `json:"message" gorm:"column:message"`
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
}
func (AuditLog) TableName() string {
return "audit_logs"
}
```
------ backend/internal/features/audit_logs/repository.go ------
```
package audit_logs
import (
"databasus-backend/internal/storage"
"time"
"github.com/google/uuid"
)
type AuditLogRepository struct{}
func (r *AuditLogRepository) Create(auditLog *AuditLog) error {
if auditLog.ID == uuid.Nil {
auditLog.ID = uuid.New()
}
return storage.GetDb().Create(auditLog).Error
}
func (r *AuditLogRepository) GetGlobal(limit, offset int, beforeDate *time.Time) ([]*AuditLog, error) {
var auditLogs []*AuditLog
query := storage.GetDb().Order("created_at DESC")
if beforeDate != nil {
query = query.Where("created_at < ?", *beforeDate)
}
err := query.
Limit(limit).
Offset(offset).
Find(&auditLogs).Error
return auditLogs, err
}
func (r *AuditLogRepository) GetByUser(
userID uuid.UUID,
limit, offset int,
beforeDate *time.Time,
) ([]*AuditLog, error) {
var auditLogs []*AuditLog
query := storage.GetDb().
Where("user_id = ?", userID).
Order("created_at DESC")
if beforeDate != nil {
query = query.Where("created_at < ?", *beforeDate)
}
err := query.
Limit(limit).
Offset(offset).
Find(&auditLogs).Error
return auditLogs, err
}
func (r *AuditLogRepository) GetByProject(
projectID uuid.UUID,
limit, offset int,
beforeDate *time.Time,
) ([]*AuditLog, error) {
var auditLogs []*AuditLog
query := storage.GetDb().
Where("project_id = ?", projectID).
Order("created_at DESC")
if beforeDate != nil {
query = query.Where("created_at < ?", *beforeDate)
}
err := query.
Limit(limit).
Offset(offset).
Find(&auditLogs).Error
return auditLogs, err
}
func (r *AuditLogRepository) CountGlobal(beforeDate *time.Time) (int64, error) {
var count int64
query := storage.GetDb().Model(&AuditLog{})
if beforeDate != nil {
query = query.Where("created_at < ?", *beforeDate)
}
err := query.Count(&count).Error
return count, err
}
```
------ backend/internal/features/audit_logs/service.go ------
```
package audit_logs
import (
"errors"
"log/slog"
"time"
user_enums "databasus-backend/internal/features/users/enums"
user_models "databasus-backend/internal/features/users/models"
"github.com/google/uuid"
)
type AuditLogService struct {
auditLogRepository *AuditLogRepository
logger *slog.Logger
}
func (s *AuditLogService) WriteAuditLog(
message string,
userID *uuid.UUID,
projectID *uuid.UUID,
) {
auditLog := &AuditLog{
UserID: userID,
ProjectID: projectID,
Message: message,
CreatedAt: time.Now().UTC(),
}
err := s.auditLogRepository.Create(auditLog)
if err != nil {
s.logger.Error("failed to create audit log", "error", err)
return
}
}
func (s *AuditLogService) CreateAuditLog(auditLog *AuditLog) error {
return s.auditLogRepository.Create(auditLog)
}
func (s *AuditLogService) GetGlobalAuditLogs(
user *user_models.User,
request *GetAuditLogsRequest,
) (*GetAuditLogsResponse, error) {
if user.Role != user_enums.UserRoleAdmin {
return nil, errors.New("only administrators can view global audit logs")
}
limit := request.Limit
if limit <= 0 || limit > 1000 {
limit = 100
}
offset := max(request.Offset, 0)
auditLogs, err := s.auditLogRepository.GetGlobal(limit, offset, request.BeforeDate)
if err != nil {
return nil, err
}
total, err := s.auditLogRepository.CountGlobal(request.BeforeDate)
if err != nil {
return nil, err
}
return &GetAuditLogsResponse{
AuditLogs: auditLogs,
Total: total,
Limit: limit,
Offset: offset,
}, nil
}
func (s *AuditLogService) GetUserAuditLogs(
targetUserID uuid.UUID,
user *user_models.User,
request *GetAuditLogsRequest,
) (*GetAuditLogsResponse, error) {
// Users can view their own logs, ADMIN can view any user's logs
if user.Role != user_enums.UserRoleAdmin && user.ID != targetUserID {
return nil, errors.New("insufficient permissions to view user audit logs")
}
limit := request.Limit
if limit <= 0 || limit > 1000 {
limit = 100
}
offset := max(request.Offset, 0)
auditLogs, err := s.auditLogRepository.GetByUser(targetUserID, limit, offset, request.BeforeDate)
if err != nil {
return nil, err
}
return &GetAuditLogsResponse{
AuditLogs: auditLogs,
Total: int64(len(auditLogs)),
Limit: limit,
Offset: offset,
}, nil
}
func (s *AuditLogService) GetProjectAuditLogs(
projectID uuid.UUID,
request *GetAuditLogsRequest,
) (*GetAuditLogsResponse, error) {
limit := request.Limit
if limit <= 0 || limit > 1000 {
limit = 100
}
offset := max(request.Offset, 0)
auditLogs, err := s.auditLogRepository.GetByProject(projectID, limit, offset, request.BeforeDate)
if err != nil {
return nil, err
}
return &GetAuditLogsResponse{
AuditLogs: auditLogs,
Total: int64(len(auditLogs)),
Limit: limit,
Offset: offset,
}, nil
}
```
------ backend/internal/features/audit_logs/service_test.go ------
```
package audit_logs
import (
"testing"
"time"
user_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
)
func Test_AuditLogs_ProjectSpecificLogs(t *testing.T) {
service := GetAuditLogService()
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
project1ID, project2ID := uuid.New(), uuid.New()
// Create test logs for projects
createAuditLog(service, "Test project1 log first", &user1.UserID, &project1ID)
createAuditLog(service, "Test project1 log second", &user2.UserID, &project1ID)
createAuditLog(service, "Test project2 log first", &user1.UserID, &project2ID)
createAuditLog(service, "Test project2 log second", &user2.UserID, &project2ID)
createAuditLog(service, "Test no project log", &user1.UserID, nil)
request := &GetAuditLogsRequest{Limit: 10, Offset: 0}
// Test project 1 logs
project1Response, err := service.GetProjectAuditLogs(project1ID, request)
assert.NoError(t, err)
assert.Equal(t, 2, len(project1Response.AuditLogs))
messages := extractMessages(project1Response.AuditLogs)
assert.Contains(t, messages, "Test project1 log first")
assert.Contains(t, messages, "Test project1 log second")
for _, log := range project1Response.AuditLogs {
assert.Equal(t, &project1ID, log.ProjectID)
}
// Test project 2 logs
project2Response, err := service.GetProjectAuditLogs(project2ID, request)
assert.NoError(t, err)
assert.Equal(t, 2, len(project2Response.AuditLogs))
messages2 := extractMessages(project2Response.AuditLogs)
assert.Contains(t, messages2, "Test project2 log first")
assert.Contains(t, messages2, "Test project2 log second")
// Test pagination
limitedResponse, err := service.GetProjectAuditLogs(project1ID,
&GetAuditLogsRequest{Limit: 1, Offset: 0})
assert.NoError(t, err)
assert.Equal(t, 1, len(limitedResponse.AuditLogs))
assert.Equal(t, 1, limitedResponse.Limit)
// Test beforeDate filter
beforeTime := time.Now().UTC().Add(-1 * time.Minute)
filteredResponse, err := service.GetProjectAuditLogs(project1ID,
&GetAuditLogsRequest{Limit: 10, BeforeDate: &beforeTime})
assert.NoError(t, err)
for _, log := range filteredResponse.AuditLogs {
assert.True(t, log.CreatedAt.Before(beforeTime))
}
}
func createAuditLog(service *AuditLogService, message string, userID, projectID *uuid.UUID) {
service.WriteAuditLog(message, userID, projectID)
}
func extractMessages(logs []*AuditLog) []string {
messages := make([]string, len(logs))
for i, log := range logs {
messages[i] = log.Message
}
return messages
}
func createTimedLog(db *gorm.DB, userID *uuid.UUID, message string, createdAt time.Time) {
log := &AuditLog{
ID: uuid.New(),
UserID: userID,
Message: message,
CreatedAt: createdAt,
}
db.Create(log)
}
```


@@ -1,74 +0,0 @@
---
description:
globs:
alwaysApply: true
---
For DI files, use the implicit (positional) field declaration style (especially
for controllers, services, repositories, use cases, etc., not simple
data structures).
So, instead of:
var orderController = &OrderController{
orderService: orderService,
botUserService: bot_users.GetBotUserService(),
botService: bots.GetBotService(),
userService: users.GetUserService(),
}
Use:
var orderController = &OrderController{
orderService,
bot_users.GetBotUserService(),
bots.GetBotService(),
users.GetUserService(),
}
This is needed to avoid forgetting to update the DI wiring
when we add a new dependency.
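The reason the positional style catches this: a positional composite literal must supply a value for every field, so adding a dependency to the struct makes the DI file fail to compile until it is wired. A minimal hedged sketch (types invented for illustration):
```go
package di_example

type OrderService struct{}
type BotService struct{}

type OrderController struct {
	orderService *OrderService
	botService   *BotService
}

// Positional literal: if OrderController later gains, say, a userService field,
// this declaration stops compiling ("too few values in struct literal") until
// the new dependency is added here as well.
var orderController = &OrderController{
	&OrderService{},
	&BotService{},
}
```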
---
Please force such usage if the file looks like this (see some
services\controllers\repos definitions and getters):
var orderBackgroundService = &OrderBackgroundService{
orderService: orderService,
orderPaymentRepository: orderPaymentRepository,
botService: bots.GetBotService(),
paymentSettingsService: payment_settings.GetPaymentSettingsService(),
orderSubscriptionListeners: []OrderSubscriptionListener{},
}
var orderController = &OrderController{
orderService: orderService,
botUserService: bot_users.GetBotUserService(),
botService: bots.GetBotService(),
userService: users.GetUserService(),
}
func GetUniquePaymentRepository() *repositories.UniquePaymentRepository {
return uniquePaymentRepository
}
func GetOrderPaymentRepository() *repositories.OrderPaymentRepository {
return orderPaymentRepository
}
func GetOrderService() *OrderService {
return orderService
}
func GetOrderController() *OrderController {
return orderController
}
func GetOrderBackgroundService() *OrderBackgroundService {
return orderBackgroundService
}
func GetOrderRepository() *repositories.OrderRepository {
return orderRepository
}


@@ -1,27 +0,0 @@
---
description:
globs:
alwaysApply: true
---
When writing migrations:
- write them for PostgreSQL
- for PRIMARY UUID keys use gen_random_uuid()
- for time use TIMESTAMPTZ (timestamp with zone)
- split table, constraint and index declarations (table first, then the others one by one)
- format SQL in a pretty way (add spaces, align column types), with constraints split across lines. Example:
CREATE TABLE marketplace_info (
bot_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
title TEXT NOT NULL,
description TEXT NOT NULL,
short_description TEXT NOT NULL,
tutorial_url TEXT,
info_order BIGINT NOT NULL DEFAULT 0,
is_published BOOLEAN NOT NULL DEFAULT FALSE
);
ALTER TABLE marketplace_info_images
ADD CONSTRAINT fk_marketplace_info_images_bot_id
FOREIGN KEY (bot_id)
REFERENCES marketplace_info (bot_id);


@@ -1,12 +0,0 @@
---
description:
globs:
alwaysApply: true
---
When applying changes, do not forget to refactor old code.
You can shorten it, make it more readable, improve code quality, etc.
Common logic can be extracted to functions, constants, files, etc.
After each large change of more than ~50-100 lines of code, always run `make lint` (from the backend root folder) and, if you changed the frontend, run `npm run format` (from the frontend root folder).


@@ -1,147 +0,0 @@
---
description:
globs:
alwaysApply: true
---
After writing tests, always launch them and verify that they pass.
## Test Naming Format
Use these naming patterns:
- `Test_WhatWeDo_WhatWeExpect`
- `Test_WhatWeDo_WhichConditions_WhatWeExpect`
## Examples from Real Codebase:
- `Test_CreateApiKey_WhenUserIsProjectOwner_ApiKeyCreated`
- `Test_UpdateProject_WhenUserIsProjectAdmin_ProjectUpdated`
- `Test_DeleteApiKey_WhenUserIsProjectMember_ReturnsForbidden`
- `Test_GetProjectAuditLogs_WithDifferentUserRoles_EnforcesPermissionsCorrectly`
- `Test_ProjectLifecycleE2E_CompletesSuccessfully`
## Testing Philosophy
**Prefer Controllers Over Unit Tests:**
- Test through HTTP endpoints via controllers whenever possible
- Avoid testing repositories, services in isolation - test via API instead
- Only use unit tests for complex model logic when no API exists
- Name test files `controller_test.go` or `service_test.go`, not `integration_test.go`
**Extract Common Logic to Testing Utilities:**
- Create `testing.go` or `testing/testing.go` files for shared test utilities
- Extract helpers for router creation, user setup, and model creation (created through the API, not just as structs)
- Reuse common patterns across different test files
**Refactor Existing Tests:**
- When working with existing tests, always look for opportunities to refactor and improve
- Extract repetitive setup code to common utilities
- Simplify complex tests by breaking them into smaller, focused tests
- Replace inline test data creation with reusable helper functions
- Consolidate similar test patterns across different test files
- Make tests more readable and maintainable for other developers
## Testing Utilities Structure
**Create `testing.go` or `testing/testing.go` files with common utilities:**
```go
package projects_testing
// CreateTestRouter creates unified router for all controllers
func CreateTestRouter(controllers ...ControllerInterface) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
v1 := router.Group("/api/v1")
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
for _, controller := range controllers {
if routerGroup, ok := protected.(*gin.RouterGroup); ok {
controller.RegisterRoutes(routerGroup)
}
}
return router
}
// CreateTestProjectViaAPI creates project through HTTP API
func CreateTestProjectViaAPI(name string, owner *users_dto.SignInResponseDTO, router *gin.Engine) (*projects_models.Project, string) {
request := projects_dto.CreateProjectRequestDTO{Name: name}
w := MakeAPIRequest(router, "POST", "/api/v1/projects", "Bearer "+owner.Token, request)
// Handle response...
return project, owner.Token
}
// AddMemberToProject adds member via API call
func AddMemberToProject(project *projects_models.Project, member *users_dto.SignInResponseDTO, role users_enums.ProjectRole, ownerToken string, router *gin.Engine) {
// Implementation...
}
```
## Controller Test Examples
**Permission-based testing:**
```go
func Test_CreateApiKey_WhenUserIsProjectOwner_ApiKeyCreated(t *testing.T) {
router := CreateApiKeyTestRouter(GetProjectController(), GetMembershipController())
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
project, _ := projects_testing.CreateTestProjectViaAPI("Test Project", owner, router)
request := CreateApiKeyRequestDTO{Name: "Test API Key"}
var response ApiKey
test_utils.MakePostRequestAndUnmarshal(t, router, "/api/v1/projects/api-keys/"+project.ID.String(), "Bearer "+owner.Token, request, http.StatusOK, &response)
assert.Equal(t, "Test API Key", response.Name)
assert.NotEmpty(t, response.Token)
}
```
**Cross-project security testing:**
```go
func Test_UpdateApiKey_WithApiKeyFromDifferentProject_ReturnsBadRequest(t *testing.T) {
router := CreateApiKeyTestRouter(GetProjectController(), GetMembershipController())
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
project1, _ := projects_testing.CreateTestProjectViaAPI("Project 1", owner1, router)
project2, _ := projects_testing.CreateTestProjectViaAPI("Project 2", owner2, router)
apiKey := CreateTestApiKey("Cross Project Key", project1.ID, owner1.Token, router)
// Try to update via different project endpoint
request := UpdateApiKeyRequestDTO{Name: &"Hacked Key"}
resp := test_utils.MakePutRequest(t, router, "/api/v1/projects/api-keys/"+project2.ID.String()+"/"+apiKey.ID.String(), "Bearer "+owner2.Token, request, http.StatusBadRequest)
assert.Contains(t, string(resp.Body), "API key does not belong to this project")
}
```
**E2E lifecycle testing:**
```go
func Test_ProjectLifecycleE2E_CompletesSuccessfully(t *testing.T) {
router := projects_testing.CreateTestRouter(GetProjectController(), GetMembershipController())
// 1. Create project
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
project := projects_testing.CreateTestProject("E2E Project", owner, router)
// 2. Add member
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
projects_testing.AddMemberToProject(project, member, users_enums.ProjectRoleMember, owner.Token, router)
// 3. Promote to admin
projects_testing.ChangeMemberRole(project, member.UserID, users_enums.ProjectRoleAdmin, owner.Token, router)
// 4. Transfer ownership
projects_testing.TransferProjectOwnership(project, member.UserID, owner.Token, router)
// 5. Verify new owner can manage project
finalProject := projects_testing.GetProject(project.ID, member.Token, router)
assert.Equal(t, project.ID, finalProject.ID)
}
```


@@ -1,6 +0,0 @@
---
description:
globs:
alwaysApply: true
---
Always use time.Now().UTC() instead of time.Now()
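A short hedged illustration of the rule (names invented): normalizing to UTC at creation time keeps stored timestamps independent of the server's local time zone.
```go
package example

import "time"

type AuditEntry struct {
	Message   string
	CreatedAt time.Time
}

// NewAuditEntry stamps the entry in UTC so ordering and comparisons
// do not depend on where the server happens to run.
func NewAuditEntry(message string) AuditEntry {
	return AuditEntry{
		Message:   message,
		CreatedAt: time.Now().UTC(), // not time.Now()
	}
}
```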


@@ -2,8 +2,13 @@
DEV_DB_NAME=databasus
DEV_DB_USERNAME=postgres
DEV_DB_PASSWORD=Q1234567
#app
# app
ENV_MODE=development
# logging
SHOW_DB_INSTALLATION_VERIFICATION_LOGS=true
# tests
TEST_LOCALHOST=localhost
IS_SKIP_EXTERNAL_RESOURCES_TESTS=false
# db
DATABASE_DSN=host=dev-db user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
DATABASE_URL=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable


@@ -16,7 +16,6 @@ import (
"databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/features/backups/backups/backuping"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_download "databasus-backend/internal/features/backups/backups/download"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -26,8 +25,10 @@ import (
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/restores"
"databasus-backend/internal/features/restores/restoring"
"databasus-backend/internal/features/storages"
system_healthcheck "databasus-backend/internal/features/system/healthcheck"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
users_controllers "databasus-backend/internal/features/users/controllers"
users_middleware "databasus-backend/internal/features/users/middleware"
users_services "databasus-backend/internal/features/users/services"
@@ -59,6 +60,8 @@ func main() {
cache_utils.TestCacheConnection()
if config.GetEnv().IsPrimaryNode {
log.Info("Clearing cache...")
err := cache_utils.ClearAllCache()
if err != nil {
log.Error("Failed to clear cache", "error", err)
@@ -239,7 +242,7 @@ func setUpDependencies() {
notifiers.SetupDependencies()
storages.SetupDependencies()
backups_config.SetupDependencies()
backups_cancellation.SetupDependencies()
task_cancellation.SetupDependencies()
}
func runBackgroundTasks(log *slog.Logger) {
@@ -257,20 +260,24 @@ func runBackgroundTasks(log *slog.Logger) {
cancel()
}()
err := files_utils.CleanFolder(config.GetEnv().TempFolder)
if err != nil {
log.Error("Failed to clean temp folder", "error", err)
}
if config.GetEnv().IsPrimaryNode {
log.Info("Starting primary node background tasks...")
err := files_utils.CleanFolder(config.GetEnv().TempFolder)
if err != nil {
log.Error("Failed to clean temp folder", "error", err)
}
go runWithPanicLogging(log, "backup background service", func() {
backuping.GetBackupsScheduler().Run(ctx)
})
go runWithPanicLogging(log, "backup cleaner background service", func() {
backuping.GetBackupCleaner().Run(ctx)
})
go runWithPanicLogging(log, "restore background service", func() {
restores.GetRestoreBackgroundService().Run(ctx)
restoring.GetRestoresScheduler().Run(ctx)
})
go runWithPanicLogging(log, "healthcheck attempt background service", func() {
@@ -284,18 +291,30 @@ func runBackgroundTasks(log *slog.Logger) {
go runWithPanicLogging(log, "download token cleanup background service", func() {
backups_download.GetDownloadTokenBackgroundService().Run(ctx)
})
go runWithPanicLogging(log, "backup nodes registry background service", func() {
backuping.GetBackupNodesRegistry().Run(ctx)
})
go runWithPanicLogging(log, "restore nodes registry background service", func() {
restoring.GetRestoreNodesRegistry().Run(ctx)
})
} else {
log.Info("Skipping primary node tasks as not primary node")
}
if config.GetEnv().IsBackupNode {
if config.GetEnv().IsProcessingNode {
log.Info("Starting backup node background tasks...")
go runWithPanicLogging(log, "backup node", func() {
backuping.GetBackuperNode().Run(ctx)
})
go runWithPanicLogging(log, "restore node", func() {
restoring.GetRestorerNode().Run(ctx)
})
} else {
log.Info("Skipping backup node tasks as not backup node")
log.Info("Skipping backup/restore node tasks as not backup node")
}
}

View File

@@ -9,7 +9,6 @@ import (
"strings"
"sync"
"github.com/google/uuid"
"github.com/ilyakaznacheev/cleanenv"
"github.com/joho/godotenv"
)
@@ -30,10 +29,14 @@ type EnvVariables struct {
MariadbInstallDir string `env:"MARIADB_INSTALL_DIR"`
MongodbInstallDir string `env:"MONGODB_INSTALL_DIR"`
NodeID string
TestLocalhost string `env:"TEST_LOCALHOST"`
ShowDbInstallationVerificationLogs bool `env:"SHOW_DB_INSTALLATION_VERIFICATION_LOGS"`
IsSkipExternalResourcesTests bool `env:"IS_SKIP_EXTERNAL_RESOURCES_TESTS"`
IsManyNodesMode bool `env:"IS_MANY_NODES_MODE"`
IsPrimaryNode bool `env:"IS_PRIMARY_NODE"`
IsBackupNode bool `env:"IS_BACKUP_NODE"`
IsProcessingNode bool `env:"IS_PROCESSING_NODE"`
NodeNetworkThroughputMBs int `env:"NODE_NETWORK_THROUGHPUT_MBPS"`
DataFolder string
@@ -169,6 +172,16 @@ func loadEnvVariables() {
os.Exit(1)
}
// Set default value for ShowDbInstallationVerificationLogs if not defined
if os.Getenv("SHOW_DB_INSTALLATION_VERIFICATION_LOGS") == "" {
env.ShowDbInstallationVerificationLogs = true
}
// Set default value for IsSkipExternalResourcesTests if not defined
if os.Getenv("IS_SKIP_EXTERNAL_RESOURCES_TESTS") == "" {
env.IsSkipExternalResourcesTests = false
}
for _, arg := range os.Args {
if strings.Contains(arg, "test") {
env.IsTesting = true
@@ -192,25 +205,48 @@ func loadEnvVariables() {
log.Info("ENV_MODE loaded", "mode", env.EnvMode)
env.PostgresesInstallDir = filepath.Join(backendRoot, "tools", "postgresql")
tools.VerifyPostgresesInstallation(log, env.EnvMode, env.PostgresesInstallDir)
tools.VerifyPostgresesInstallation(
log,
env.EnvMode,
env.PostgresesInstallDir,
env.ShowDbInstallationVerificationLogs,
)
env.MysqlInstallDir = filepath.Join(backendRoot, "tools", "mysql")
tools.VerifyMysqlInstallation(log, env.EnvMode, env.MysqlInstallDir)
tools.VerifyMysqlInstallation(
log,
env.EnvMode,
env.MysqlInstallDir,
env.ShowDbInstallationVerificationLogs,
)
env.MariadbInstallDir = filepath.Join(backendRoot, "tools", "mariadb")
tools.VerifyMariadbInstallation(log, env.EnvMode, env.MariadbInstallDir)
tools.VerifyMariadbInstallation(
log,
env.EnvMode,
env.MariadbInstallDir,
env.ShowDbInstallationVerificationLogs,
)
env.MongodbInstallDir = filepath.Join(backendRoot, "tools", "mongodb")
tools.VerifyMongodbInstallation(log, env.EnvMode, env.MongodbInstallDir)
tools.VerifyMongodbInstallation(
log,
env.EnvMode,
env.MongodbInstallDir,
env.ShowDbInstallationVerificationLogs,
)
env.NodeID = uuid.New().String()
if env.NodeNetworkThroughputMBs == 0 {
env.NodeNetworkThroughputMBs = 125 // 1 Gbit/s
}
if !env.IsManyNodesMode {
env.IsPrimaryNode = true
env.IsBackupNode = true
env.IsProcessingNode = true
}
if env.TestLocalhost == "" {
env.TestLocalhost = "localhost"
}
// Valkey

View File

@@ -2,34 +2,50 @@ package audit_logs
import (
"context"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
)
type AuditLogBackgroundService struct {
auditLogService *AuditLogService
logger *slog.Logger
runOnce sync.Once
hasRun atomic.Bool
}
func (s *AuditLogBackgroundService) Run(ctx context.Context) {
s.logger.Info("Starting audit log cleanup background service")
wasAlreadyRun := s.hasRun.Load()
if ctx.Err() != nil {
return
}
s.runOnce.Do(func() {
s.hasRun.Store(true)
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
s.logger.Info("Starting audit log cleanup background service")
for {
select {
case <-ctx.Done():
if ctx.Err() != nil {
return
case <-ticker.C:
if err := s.cleanOldAuditLogs(); err != nil {
s.logger.Error("Failed to clean old audit logs", "error", err)
}
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.cleanOldAuditLogs(); err != nil {
s.logger.Error("Failed to clean old audit logs", "error", err)
}
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
}
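
The same run-once guard appears in several background services in this changeset (AuditLogBackgroundService above, and BackupCleaner, BackuperNode, and BackupNodesRegistry below): a `sync.Once` around the loop, an `atomic.Bool` recording that it has started, and a panic when `Run` is called a second time. A minimal, self-contained sketch of the pattern, using a hypothetical `Worker` type in place of the real services:

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

// Worker is a hypothetical stand-in for the background services in this diff.
type Worker struct {
	runOnce sync.Once
	hasRun  atomic.Bool
}

// Run executes the loop exactly once; a second call panics so a wiring
// mistake (two goroutines starting the same service) is caught early.
func (w *Worker) Run(ctx context.Context) {
	wasAlreadyRun := w.hasRun.Load()
	w.runOnce.Do(func() {
		w.hasRun.Store(true)
		ticker := time.NewTicker(time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				// periodic work goes here
			}
		}
	})
	if wasAlreadyRun {
		panic(fmt.Sprintf("%T.Run() called multiple times", w))
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	w := &Worker{}
	w.Run(ctx) // blocks until ctx is cancelled; a second Run(ctx) would panic
}
```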

View File

@@ -1,6 +1,9 @@
package audit_logs
import (
"sync"
"sync/atomic"
users_services "databasus-backend/internal/features/users/services"
"databasus-backend/internal/util/logger"
)
@@ -14,8 +17,10 @@ var auditLogController = &AuditLogController{
auditLogService,
}
var auditLogBackgroundService = &AuditLogBackgroundService{
auditLogService,
logger.GetLogger(),
auditLogService: auditLogService,
logger: logger.GetLogger(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
func GetAuditLogService() *AuditLogService {
@@ -30,8 +35,23 @@ func GetAuditLogBackgroundService() *AuditLogBackgroundService {
return auditLogBackgroundService
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
users_services.GetUserService().SetAuditLogWriter(auditLogService)
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
users_services.GetUserService().SetAuditLogWriter(auditLogService)
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}

View File

@@ -2,22 +2,30 @@ package backuping
import (
"context"
"databasus-backend/internal/config"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/storages"
workspaces_services "databasus-backend/internal/features/workspaces/services"
util_encryption "databasus-backend/internal/util/encryption"
"errors"
"fmt"
"log/slog"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"databasus-backend/internal/config"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/storages"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
workspaces_services "databasus-backend/internal/features/workspaces/services"
util_encryption "databasus-backend/internal/util/encryption"
)
const (
heartbeatTickerInterval = 15 * time.Second
backuperHeathcheckThreshold = 5 * time.Minute
)
type BackuperNode struct {
@@ -28,77 +36,93 @@ type BackuperNode struct {
backupConfigService *backups_config.BackupConfigService
storageService *storages.StorageService
notificationSender backups_core.NotificationSender
backupCancelManager *backups_cancellation.BackupCancelManager
nodesRegistry *BackupNodesRegistry
backupCancelManager *tasks_cancellation.TaskCancelManager
backupNodesRegistry *BackupNodesRegistry
logger *slog.Logger
createBackupUseCase backups_core.CreateBackupUsecase
nodeID uuid.UUID
lastHeartbeat time.Time
runOnce sync.Once
hasRun atomic.Bool
}
func (n *BackuperNode) Run(ctx context.Context) {
n.lastHeartbeat = time.Now().UTC()
wasAlreadyRun := n.hasRun.Load()
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
n.runOnce.Do(func() {
n.hasRun.Store(true)
backupNode := BackupNode{
ID: n.nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: time.Now().UTC(),
}
n.lastHeartbeat = time.Now().UTC()
if err := n.nodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
n.logger.Error("Failed to register node in registry", "error", err)
panic(err)
}
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
n.MakeBackup(backupID, isCallNotifier)
if err := n.nodesRegistry.PublishBackupCompletion(n.nodeID.String(), backupID); err != nil {
n.logger.Error(
"Failed to publish backup completion",
"error",
err,
"backupID",
backupID,
)
backupNode := BackupNode{
ID: n.nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: time.Now().UTC(),
}
}
if err := n.nodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID.String(), backupHandler); err != nil {
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
panic(err)
}
defer func() {
if err := n.nodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
n.logger.Error("Failed to register node in registry", "error", err)
panic(err)
}
}()
ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
go func() {
n.MakeBackup(backupID, isCallNotifier)
if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil {
n.logger.Error(
"Failed to publish backup completion",
"error",
err,
"backupID",
backupID,
)
}
}()
}
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
for {
select {
case <-ctx.Done():
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
if err := n.nodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
n.logger.Error("Failed to unregister node from registry", "error", err)
err := n.backupNodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID, backupHandler)
if err != nil {
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
panic(err)
}
defer func() {
if err := n.backupNodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
}
}()
return
case <-ticker.C:
n.sendHeartbeat(&backupNode)
ticker := time.NewTicker(heartbeatTickerInterval)
defer ticker.Stop()
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
for {
select {
case <-ctx.Done():
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
if err := n.backupNodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
n.logger.Error("Failed to unregister node from registry", "error", err)
}
return
case <-ticker.C:
n.sendHeartbeat(&backupNode)
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", n))
}
}
func (n *BackuperNode) IsBackuperRunning() bool {
return n.lastHeartbeat.After(time.Now().UTC().Add(-5 * time.Minute))
return n.lastHeartbeat.After(time.Now().UTC().Add(-backuperHeathcheckThreshold))
}
func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
@@ -135,21 +159,41 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
start := time.Now().UTC()
ctx, cancel := context.WithCancel(context.Background())
n.backupCancelManager.RegisterTask(backup.ID, cancel)
defer n.backupCancelManager.UnregisterTask(backup.ID)
backupProgressListener := func(
completedMBs float64,
) {
backup.BackupSizeMb = completedMBs
backup.BackupDurationMs = time.Since(start).Milliseconds()
// Check size limit (0 = unlimited)
if backupConfig.MaxBackupSizeMB > 0 &&
completedMBs > float64(backupConfig.MaxBackupSizeMB) {
errMsg := fmt.Sprintf(
"backup size (%.2f MB) exceeded maximum allowed size (%d MB)",
completedMBs,
backupConfig.MaxBackupSizeMB,
)
backup.Status = backups_core.BackupStatusFailed
backup.IsSkipRetry = true
backup.FailMessage = &errMsg
if err := n.backupRepository.Save(backup); err != nil {
n.logger.Error("Failed to save backup with size exceeded error", "error", err)
}
cancel() // Cancel the backup context
return
}
if err := n.backupRepository.Save(backup); err != nil {
n.logger.Error("Failed to update backup progress", "error", err)
}
}
ctx, cancel := context.WithCancel(context.Background())
n.backupCancelManager.RegisterBackup(backup.ID, cancel)
defer n.backupCancelManager.UnregisterBackup(backup.ID)
backupMetadata, err := n.createBackupUseCase.Execute(
ctx,
backup.ID,
@@ -159,8 +203,42 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
backupProgressListener,
)
if err != nil {
// Check if backup was already marked as failed by progress listener (e.g., size limit exceeded)
// If so, skip error handling to avoid overwriting the status
currentBackup, fetchErr := n.backupRepository.FindByID(backup.ID)
if fetchErr == nil && currentBackup.Status == backups_core.BackupStatusFailed {
n.logger.Warn(
"Backup already marked as failed by progress listener, skipping error handling",
"backupId",
backup.ID,
"failMessage",
*currentBackup.FailMessage,
)
// Still call notification for size limit failures
n.SendBackupNotification(
backupConfig,
currentBackup,
backups_config.NotificationBackupFailed,
currentBackup.FailMessage,
)
return
}
errMsg := err.Error()
// Log detailed error information for debugging
n.logger.Error("Backup execution failed",
"backupId", backup.ID,
"databaseId", databaseID,
"databaseType", database.Type,
"storageId", storage.ID,
"storageType", storage.Type,
"error", err,
"errorMessage", errMsg,
)
// Check if backup was cancelled (not due to shutdown)
isCancelled := strings.Contains(errMsg, "backup cancelled") ||
strings.Contains(errMsg, "context canceled") ||
@@ -168,6 +246,12 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
isShutdown := strings.Contains(errMsg, "shutdown")
if isCancelled && !isShutdown {
n.logger.Warn("Backup was cancelled by user or system",
"backupId", backup.ID,
"isCancelled", isCancelled,
"isShutdown", isShutdown,
)
backup.Status = backups_core.BackupStatusCanceled
backup.BackupDurationMs = time.Since(start).Milliseconds()
backup.BackupSizeMb = 0
@@ -337,8 +421,7 @@ func (n *BackuperNode) SendBackupNotification(
func (n *BackuperNode) sendHeartbeat(backupNode *BackupNode) {
n.lastHeartbeat = time.Now().UTC()
backupNode.LastHeartbeat = time.Now().UTC()
if err := n.nodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *backupNode); err != nil {
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *backupNode); err != nil {
n.logger.Error("Failed to send heartbeat", "error", err)
}
}
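
The new size limit is enforced from inside the progress listener: once the reported size crosses MaxBackupSizeMB, the backup is saved as failed with IsSkipRetry set and the backup context is cancelled so the running dump stops. A simplified sketch of that control flow, with hypothetical names (`runWithSizeLimit`, `produce`, `report`) standing in for the real services:

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

// runWithSizeLimit is a hypothetical, simplified version of the flow above:
// the progress listener cancels the shared context once the limit is crossed.
func runWithSizeLimit(maxSizeMB float64, produce func(ctx context.Context, report func(float64)) error) error {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	var limitErr error
	report := func(completedMBs float64) {
		if maxSizeMB > 0 && completedMBs > maxSizeMB {
			limitErr = fmt.Errorf("backup size (%.2f MB) exceeded maximum allowed size (%.0f MB)", completedMBs, maxSizeMB)
			cancel() // stop the running backup; the producer should observe ctx.Err()
		}
	}

	err := produce(ctx, report)
	if limitErr != nil {
		return limitErr // the limit error takes precedence over the cancellation error
	}
	return err
}

func main() {
	err := runWithSizeLimit(5, func(ctx context.Context, report func(float64)) error {
		for _, size := range []float64{1, 3, 10} {
			report(size)
			if ctx.Err() != nil {
				return errors.New("backup cancelled")
			}
		}
		return nil
	})
	fmt.Println(err) // backup size (10.00 MB) exceeded maximum allowed size (5 MB)
}
```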

View File

@@ -1,13 +1,10 @@
package backuping
import (
"context"
"errors"
"strings"
"testing"
"time"
common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -18,7 +15,6 @@ import (
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
cache_utils "databasus-backend/internal/util/cache"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
@@ -158,35 +154,120 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
})
}
type CreateFailedBackupUsecase struct {
}
func Test_BackupSizeLimits(t *testing.T) {
cache_utils.ClearAllCache()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
func (uc *CreateFailedBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
backupProgressListener(10)
return nil, errors.New("backup failed")
}
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
type CreateSuccessBackupUsecase struct{}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
func (uc *CreateSuccessBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
backupProgressListener(10)
return &common.BackupMetadata{
EncryptionSalt: nil,
EncryptionIV: nil,
Encryption: backups_config.BackupEncryptionNone,
}, nil
t.Run("UnlimitedSize_MaxBackupSizeMBIsZero_BackupCompletes", func(t *testing.T) {
// Enable backups with unlimited size (0)
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backupConfig.MaxBackupSizeMB = 0 // unlimited
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backuperNode := CreateTestBackuperNode()
backuperNode.createBackupUseCase = &CreateLargeBackupUsecase{}
// Create a backup record
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backuperNode.MakeBackup(backup.ID, false)
// Verify backup completed successfully even with large size
updatedBackup, err := backupRepository.FindByID(backup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
assert.Equal(t, float64(10000), updatedBackup.BackupSizeMb)
assert.Nil(t, updatedBackup.FailMessage)
})
t.Run("SizeExceeded_BackupFailedWithIsSkipRetry", func(t *testing.T) {
// Enable backups with 5 MB limit
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backupConfig.MaxBackupSizeMB = 5
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backuperNode := CreateTestBackuperNode()
backuperNode.createBackupUseCase = &CreateProgressiveBackupUsecase{}
// Create a backup record
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backuperNode.MakeBackup(backup.ID, false)
// Verify backup was marked as failed with IsSkipRetry=true
updatedBackup, err := backupRepository.FindByID(backup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusFailed, updatedBackup.Status)
assert.True(t, updatedBackup.IsSkipRetry)
assert.NotNil(t, updatedBackup.FailMessage)
assert.Contains(t, *updatedBackup.FailMessage, "exceeded maximum allowed size")
assert.Contains(t, *updatedBackup.FailMessage, "10.00 MB")
assert.Contains(t, *updatedBackup.FailMessage, "5 MB")
assert.Greater(t, updatedBackup.BackupSizeMb, float64(5))
})
t.Run("SizeWithinLimit_BackupCompletes", func(t *testing.T) {
// Enable backups with 100 MB limit
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backupConfig.MaxBackupSizeMB = 100
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backuperNode := CreateTestBackuperNode()
backuperNode.createBackupUseCase = &CreateMediumBackupUsecase{}
// Create a backup record
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backuperNode.MakeBackup(backup.ID, false)
// Verify backup completed successfully
updatedBackup, err := backupRepository.FindByID(backup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
assert.Equal(t, float64(50), updatedBackup.BackupSizeMb)
assert.Nil(t, updatedBackup.FailMessage)
})
}

View File

@@ -0,0 +1,242 @@
package backuping
import (
"context"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/storages"
util_encryption "databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/period"
)
const (
cleanerTickerInterval = 1 * time.Minute
)
type BackupCleaner struct {
backupRepository *backups_core.BackupRepository
storageService *storages.StorageService
backupConfigService *backups_config.BackupConfigService
fieldEncryptor util_encryption.FieldEncryptor
logger *slog.Logger
backupRemoveListeners []backups_core.BackupRemoveListener
runOnce sync.Once
hasRun atomic.Bool
}
func (c *BackupCleaner) Run(ctx context.Context) {
wasAlreadyRun := c.hasRun.Load()
c.runOnce.Do(func() {
c.hasRun.Store(true)
if ctx.Err() != nil {
return
}
ticker := time.NewTicker(cleanerTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := c.cleanOldBackups(); err != nil {
c.logger.Error("Failed to clean old backups", "error", err)
}
if err := c.cleanExceededBackups(); err != nil {
c.logger.Error("Failed to clean exceeded backups", "error", err)
}
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", c))
}
}
func (c *BackupCleaner) DeleteBackup(backup *backups_core.Backup) error {
for _, listener := range c.backupRemoveListeners {
if err := listener.OnBeforeBackupRemove(backup); err != nil {
return err
}
}
storage, err := c.storageService.GetStorageByID(backup.StorageID)
if err != nil {
return err
}
err = storage.DeleteFile(c.fieldEncryptor, backup.ID)
if err != nil {
// we do not return an error here because cleanup is sometimes performed
// before an unavailable storage is removed or reconfigured - therefore we
// should proceed even if deletion fails. Some S3 or other storage may be
// temporarily unreachable, and that should not block the cleanup
c.logger.Error("Failed to delete backup file", "error", err)
}
return c.backupRepository.DeleteByID(backup.ID)
}
func (c *BackupCleaner) AddBackupRemoveListener(listener backups_core.BackupRemoveListener) {
c.backupRemoveListeners = append(c.backupRemoveListeners, listener)
}
func (c *BackupCleaner) cleanOldBackups() error {
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
backupStorePeriod := backupConfig.StorePeriod
if backupStorePeriod == period.PeriodForever {
continue
}
storeDuration := backupStorePeriod.ToDuration()
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
oldBackups, err := c.backupRepository.FindBackupsBeforeDate(
backupConfig.DatabaseID,
dateBeforeBackupsShouldBeDeleted,
)
if err != nil {
c.logger.Error(
"Failed to find old backups for database",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
continue
}
for _, backup := range oldBackups {
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
continue
}
c.logger.Info(
"Deleted old backup",
"backupId",
backup.ID,
"databaseId",
backupConfig.DatabaseID,
)
}
}
return nil
}
func (c *BackupCleaner) cleanExceededBackups() error {
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
if backupConfig.MaxBackupsTotalSizeMB <= 0 {
continue
}
if err := c.cleanExceededBackupsForDatabase(
backupConfig.DatabaseID,
backupConfig.MaxBackupsTotalSizeMB,
); err != nil {
c.logger.Error(
"Failed to clean exceeded backups for database",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
continue
}
}
return nil
}
func (c *BackupCleaner) cleanExceededBackupsForDatabase(
databaseID uuid.UUID,
limitperDbMB int64,
) error {
for {
backupsTotalSizeMB, err := c.backupRepository.GetTotalSizeByDatabase(databaseID)
if err != nil {
return err
}
if backupsTotalSizeMB <= float64(limitperDbMB) {
break
}
oldestBackups, err := c.backupRepository.FindOldestByDatabaseExcludingInProgress(
databaseID,
1,
)
if err != nil {
return err
}
if len(oldestBackups) == 0 {
c.logger.Warn(
"No backups to delete but still over limit",
"databaseId",
databaseID,
"totalSizeMB",
backupsTotalSizeMB,
"limitMB",
limitperDbMB,
)
break
}
backup := oldestBackups[0]
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error(
"Failed to delete exceeded backup",
"backupId",
backup.ID,
"databaseId",
databaseID,
"error",
err,
)
return err
}
c.logger.Info(
"Deleted exceeded backup",
"backupId",
backup.ID,
"databaseId",
databaseID,
"backupSizeMB",
backup.BackupSizeMb,
"totalSizeMB",
backupsTotalSizeMB,
"limitMB",
limitperDbMB,
)
}
return nil
}

View File

@@ -0,0 +1,491 @@
package backuping
import (
"testing"
"time"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/storage"
"databasus-backend/internal/util/period"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_CleanOldBackups_DeletesBackupsOlderThanStorePeriod(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
// Create backup interval
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Create backups with different ages
now := time.Now().UTC()
oldBackup1 := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
CreatedAt: now.Add(-10 * 24 * time.Hour), // 10 days old
}
oldBackup2 := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
CreatedAt: now.Add(-8 * 24 * time.Hour), // 8 days old
}
recentBackup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
CreatedAt: now.Add(-3 * 24 * time.Hour), // 3 days old
}
err = backupRepository.Save(oldBackup1)
assert.NoError(t, err)
err = backupRepository.Save(oldBackup2)
assert.NoError(t, err)
err = backupRepository.Save(recentBackup)
assert.NoError(t, err)
// Run cleanup
cleaner := GetBackupCleaner()
err = cleaner.cleanOldBackups()
assert.NoError(t, err)
// Verify old backups deleted, recent backup remains
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Equal(t, 1, len(remainingBackups))
assert.Equal(t, recentBackup.ID, remainingBackups[0].ID)
}
func Test_CleanOldBackups_SkipsDatabaseWithForeverStorePeriod(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
// Create backup interval
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodForever,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Create very old backup
oldBackup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
CreatedAt: time.Now().UTC().Add(-365 * 24 * time.Hour), // 1 year old
}
err = backupRepository.Save(oldBackup)
assert.NoError(t, err)
// Run cleanup
cleaner := GetBackupCleaner()
err = cleaner.cleanOldBackups()
assert.NoError(t, err)
// Verify backup still exists
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Equal(t, 1, len(remainingBackups))
assert.Equal(t, oldBackup.ID, remainingBackups[0].ID)
}
func Test_CleanExceededBackups_WhenUnderLimit_NoBackupsDeleted(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
// Create backup interval
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 100, // 100 MB limit
BackupIntervalID: interval.ID,
BackupInterval: interval,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Create 3 backups totaling 50MB (under limit)
for i := 0; i < 3; i++ {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 16.67,
CreatedAt: time.Now().UTC().Add(-time.Duration(i) * time.Hour),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
}
// Run cleanup
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
assert.NoError(t, err)
// Verify all backups remain
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Equal(t, 3, len(remainingBackups))
}
func Test_CleanExceededBackups_WhenOverLimit_DeletesOldestBackups(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
// Create backup interval
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 30, // 30 MB limit
BackupIntervalID: interval.ID,
BackupInterval: interval,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Create 5 backups of 10MB each (total 50MB, over 30MB limit)
now := time.Now().UTC()
var backupIDs []uuid.UUID
for i := 0; i < 5; i++ {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
CreatedAt: now.Add(-time.Duration(4-i) * time.Hour), // Oldest first
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backupIDs = append(backupIDs, backup.ID)
}
// Run cleanup
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
assert.NoError(t, err)
// Verify 2 oldest backups deleted, 3 newest remain
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Equal(t, 3, len(remainingBackups))
// Check that the newest 3 backups remain
remainingIDs := make(map[uuid.UUID]bool)
for _, backup := range remainingBackups {
remainingIDs[backup.ID] = true
}
assert.False(t, remainingIDs[backupIDs[0]]) // Oldest deleted
assert.False(t, remainingIDs[backupIDs[1]]) // 2nd oldest deleted
assert.True(t, remainingIDs[backupIDs[2]]) // 3rd remains
assert.True(t, remainingIDs[backupIDs[3]]) // 4th remains
assert.True(t, remainingIDs[backupIDs[4]]) // Newest remains
}
func Test_CleanExceededBackups_SkipsInProgressBackups(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
// Create backup interval
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 50, // 50 MB limit
BackupIntervalID: interval.ID,
BackupInterval: interval,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
now := time.Now().UTC()
// Create 3 completed backups of 30MB each
completedBackups := make([]*backups_core.Backup, 3)
for i := 0; i < 3; i++ {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 30,
CreatedAt: now.Add(-time.Duration(3-i) * time.Hour),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
completedBackups[i] = backup
}
// Create 1 in-progress backup (should be excluded from size calculation and deletion)
inProgressBackup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
BackupSizeMb: 10,
CreatedAt: now,
}
err = backupRepository.Save(inProgressBackup)
assert.NoError(t, err)
// Run cleanup
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
assert.NoError(t, err)
// Verify: only completed backups deleted, in-progress remains
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
// Should have in-progress + 1 completed (total 40MB completed + 10MB in-progress)
assert.GreaterOrEqual(t, len(remainingBackups), 2)
// Verify in-progress backup still exists
var inProgressFound bool
for _, backup := range remainingBackups {
if backup.ID == inProgressBackup.ID {
inProgressFound = true
assert.Equal(t, backups_core.BackupStatusInProgress, backup.Status)
}
}
assert.True(t, inProgressFound, "In-progress backup should not be deleted")
}
func Test_CleanExceededBackups_WithZeroLimit_SkipsDatabase(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
// Create backup interval
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodForever,
StorageID: &storage.ID,
MaxBackupsTotalSizeMB: 0, // No size limit
BackupIntervalID: interval.ID,
BackupInterval: interval,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Create large backups
for i := 0; i < 10; i++ {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 100,
CreatedAt: time.Now().UTC().Add(-time.Duration(i) * time.Hour),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
}
// Run cleanup
cleaner := GetBackupCleaner()
err = cleaner.cleanExceededBackups()
assert.NoError(t, err)
// Verify all backups remain
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Equal(t, 10, len(remainingBackups))
}
func Test_GetTotalSizeByDatabase_CalculatesCorrectly(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
// Create completed backups
completedBackup1 := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10.5,
CreatedAt: time.Now().UTC(),
}
completedBackup2 := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 20.3,
CreatedAt: time.Now().UTC(),
}
// Create failed backup (should be included)
failedBackup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusFailed,
BackupSizeMb: 5.2,
CreatedAt: time.Now().UTC(),
}
// Create in-progress backup (should be excluded)
inProgressBackup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
BackupSizeMb: 100,
CreatedAt: time.Now().UTC(),
}
err := backupRepository.Save(completedBackup1)
assert.NoError(t, err)
err = backupRepository.Save(completedBackup2)
assert.NoError(t, err)
err = backupRepository.Save(failedBackup)
assert.NoError(t, err)
err = backupRepository.Save(inProgressBackup)
assert.NoError(t, err)
// Calculate total size
totalSize, err := backupRepository.GetTotalSizeByDatabase(database.ID)
assert.NoError(t, err)
// Should be 10.5 + 20.3 + 5.2 = 36.0 (excluding in-progress 100)
assert.InDelta(t, 36.0, totalSize, 0.1)
}
// Mock listener for testing
type mockBackupRemoveListener struct {
onBeforeBackupRemove func(*backups_core.Backup) error
}
func (m *mockBackupRemoveListener) OnBeforeBackupRemove(backup *backups_core.Backup) error {
if m.onBeforeBackupRemove != nil {
return m.onBeforeBackupRemove(backup)
}
return nil
}
// Test_DeleteBackup_WhenStorageDeleteFails_BackupStillRemovedFromDatabase verifies resilience
// when storage becomes unavailable. Even if storage.DeleteFile fails (e.g., storage is offline,
// credentials changed, or storage was deleted), the backup record should still be removed from
// the database. This prevents orphaned backup records when storage is no longer accessible.
func Test_DeleteBackup_WhenStorageDeleteFails_BackupStillRemovedFromDatabase(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
testStorage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, testStorage, notifier)
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: testStorage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
CreatedAt: time.Now().UTC(),
}
err := backupRepository.Save(backup)
assert.NoError(t, err)
cleaner := GetBackupCleaner()
err = cleaner.DeleteBackup(backup)
assert.NoError(t, err, "DeleteBackup should succeed even when storage file doesn't exist")
deletedBackup, err := backupRepository.FindByID(backup.ID)
assert.Error(t, err, "Backup should not exist in database")
assert.Nil(t, deletedBackup)
}
func createTestInterval() *intervals.Interval {
timeOfDay := "04:00"
interval := &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
err := storage.GetDb().Create(interval).Error
if err != nil {
panic(err)
}
return interval
}

View File

@@ -1,71 +1,83 @@
package backuping
import (
"databasus-backend/internal/config"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/usecases"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
workspaces_services "databasus-backend/internal/features/workspaces/services"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
"time"
"github.com/google/uuid"
)
var backupRepository = &backups_core.BackupRepository{}
var backupCancelManager = backups_cancellation.GetBackupCancelManager()
var taskCancelManager = tasks_cancellation.GetTaskCancelManager()
var nodesRegistry = &BackupNodesRegistry{
cache_utils.GetValkeyClient(),
logger.GetLogger(),
cache_utils.DefaultCacheTimeout,
cache_utils.NewPubSubManager(),
cache_utils.NewPubSubManager(),
var backupCleaner = &BackupCleaner{
backupRepository: backupRepository,
storageService: storages.GetStorageService(),
backupConfigService: backups_config.GetBackupConfigService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
logger: logger.GetLogger(),
backupRemoveListeners: []backups_core.BackupRemoveListener{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
var backupNodesRegistry = &BackupNodesRegistry{
client: cache_utils.GetValkeyClient(),
logger: logger.GetLogger(),
timeout: cache_utils.DefaultCacheTimeout,
pubsubBackups: cache_utils.NewPubSubManager(),
pubsubCompletions: cache_utils.NewPubSubManager(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
func getNodeID() uuid.UUID {
nodeIDStr := config.GetEnv().NodeID
nodeID, err := uuid.Parse(nodeIDStr)
if err != nil {
logger.GetLogger().Error("Failed to parse node ID from config", "error", err)
panic(err)
}
return nodeID
return uuid.New()
}
var backuperNode = &BackuperNode{
databases.GetDatabaseService(),
encryption.GetFieldEncryptor(),
workspaces_services.GetWorkspaceService(),
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
notifiers.GetNotifierService(),
backupCancelManager,
nodesRegistry,
logger.GetLogger(),
usecases.GetCreateBackupUsecase(),
getNodeID(),
time.Time{},
databaseService: databases.GetDatabaseService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
workspaceService: workspaces_services.GetWorkspaceService(),
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
notificationSender: notifiers.GetNotifierService(),
backupCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
logger: logger.GetLogger(),
createBackupUseCase: usecases.GetCreateBackupUsecase(),
nodeID: getNodeID(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
var backupsScheduler = &BackupsScheduler{
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
backupCancelManager,
nodesRegistry,
time.Now().UTC(),
logger.GetLogger(),
make(map[uuid.UUID]BackupToNodeRelation),
backuperNode,
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
taskCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
lastBackupTime: time.Now().UTC(),
logger: logger.GetLogger(),
backupToNodeRelations: make(map[uuid.UUID]BackupToNodeRelation),
backuperNode: backuperNode,
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
func GetBackupsScheduler() *BackupsScheduler {
@@ -75,3 +87,11 @@ func GetBackupsScheduler() *BackupsScheduler {
func GetBackuperNode() *BackuperNode {
return backuperNode
}
func GetBackupNodesRegistry() *BackupNodesRegistry {
return backupNodesRegistry
}
func GetBackupCleaner() *BackupCleaner {
return backupCleaner
}

View File

@@ -6,6 +6,11 @@ import (
"github.com/google/uuid"
)
type BackupToNodeRelation struct {
NodeID uuid.UUID `json:"nodeId"`
BackupsIDs []uuid.UUID `json:"backupsIds"`
}
type BackupNode struct {
ID uuid.UUID `json:"id"`
ThroughputMBs int `json:"throughputMBs"`
@@ -18,17 +23,12 @@ type BackupNodeStats struct {
}
type BackupSubmitMessage struct {
NodeID string `json:"nodeId"`
BackupID string `json:"backupId"`
IsCallNotifier bool `json:"isCallNotifier"`
NodeID uuid.UUID `json:"nodeId"`
BackupID uuid.UUID `json:"backupId"`
IsCallNotifier bool `json:"isCallNotifier"`
}
type BackupCompletionMessage struct {
NodeID string `json:"nodeId"`
BackupID string `json:"backupId"`
}
type BackupToNodeRelation struct {
NodeID uuid.UUID `json:"nodeId"`
BackupsIDs []uuid.UUID `json:"backupsIds"`
NodeID uuid.UUID `json:"nodeId"`
BackupID uuid.UUID `json:"backupId"`
}
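
Switching these message fields from string to uuid.UUID removes the manual uuid.Parse step in the subscriber (see the registry diff further below) while keeping essentially the same wire format, since uuid.UUID marshals to its canonical string form. A small illustration, assuming only a local copy of the struct:

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/google/uuid"
)

// Local copy of the message shape from the diff above, for illustration only.
type BackupSubmitMessage struct {
	NodeID         uuid.UUID `json:"nodeId"`
	BackupID       uuid.UUID `json:"backupId"`
	IsCallNotifier bool      `json:"isCallNotifier"`
}

func main() {
	msg := BackupSubmitMessage{NodeID: uuid.New(), BackupID: uuid.New(), IsCallNotifier: true}

	data, err := json.Marshal(msg)
	if err != nil {
		panic(err)
	}
	// uuid.UUID implements encoding.TextMarshaler, so both IDs are serialized as
	// canonical UUID strings; the consumer no longer needs a manual uuid.Parse.
	fmt.Println(string(data))
}
```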

View File

@@ -1,8 +1,18 @@
package backuping
import (
"databasus-backend/internal/features/notifiers"
"context"
"errors"
"sync/atomic"
"time"
common "databasus-backend/internal/features/backups/backups/common"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
"github.com/google/uuid"
"github.com/stretchr/testify/mock"
)
@@ -17,3 +27,168 @@ func (m *MockNotificationSender) SendNotification(
) {
m.Called(notifier, title, message)
}
type CreateFailedBackupUsecase struct{}
func (uc *CreateFailedBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
backupProgressListener(10)
return nil, errors.New("backup failed")
}
type CreateSuccessBackupUsecase struct{}
func (uc *CreateSuccessBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
backupProgressListener(10)
return &common.BackupMetadata{
EncryptionSalt: nil,
EncryptionIV: nil,
Encryption: backups_config.BackupEncryptionNone,
}, nil
}
// CreateLargeBackupUsecase simulates a large backup (10000 MB)
type CreateLargeBackupUsecase struct{}
func (uc *CreateLargeBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
backupProgressListener(10000)
return &common.BackupMetadata{
EncryptionSalt: nil,
EncryptionIV: nil,
Encryption: backups_config.BackupEncryptionNone,
}, nil
}
// CreateProgressiveBackupUsecase simulates progressive size updates that exceed limit
type CreateProgressiveBackupUsecase struct{}
func (uc *CreateProgressiveBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
// Simulate progressive backup that grows beyond limit
backupProgressListener(1)
if ctx.Err() != nil {
return nil, ctx.Err()
}
backupProgressListener(3)
if ctx.Err() != nil {
return nil, ctx.Err()
}
backupProgressListener(5)
if ctx.Err() != nil {
return nil, ctx.Err()
}
backupProgressListener(10) // This exceeds the 5 MB limit
if ctx.Err() != nil {
return nil, ctx.Err()
}
// Should not reach here due to cancellation
return &common.BackupMetadata{
EncryptionSalt: nil,
EncryptionIV: nil,
Encryption: backups_config.BackupEncryptionNone,
}, nil
}
// CreateMediumBackupUsecase simulates a 50 MB backup
type CreateMediumBackupUsecase struct{}
func (uc *CreateMediumBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
backupProgressListener(50)
return &common.BackupMetadata{
EncryptionSalt: nil,
EncryptionIV: nil,
Encryption: backups_config.BackupEncryptionNone,
}, nil
}
// MockTrackingBackupUsecase tracks backup use case calls for testing parallel execution
type MockTrackingBackupUsecase struct {
callCount atomic.Int32
calledBackupIDs chan uuid.UUID
}
func NewMockTrackingBackupUsecase() *MockTrackingBackupUsecase {
return &MockTrackingBackupUsecase{
calledBackupIDs: make(chan uuid.UUID, 10),
}
}
func (m *MockTrackingBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
m.callCount.Add(1)
// Send backup ID to channel (non-blocking)
select {
case m.calledBackupIDs <- backupID:
default:
}
// Simulate backup work
time.Sleep(100 * time.Millisecond)
backupProgressListener(10)
return &common.BackupMetadata{
EncryptionSalt: nil,
EncryptionIV: nil,
Encryption: backups_config.BackupEncryptionNone,
}, nil
}
func (m *MockTrackingBackupUsecase) GetCallCount() int32 {
return m.callCount.Load()
}
func (m *MockTrackingBackupUsecase) GetCalledBackupIDs() []uuid.UUID {
ids := []uuid.UUID{}
for {
select {
case id := <-m.calledBackupIDs:
ids = append(ids, id)
default:
return ids
}
}
}

View File

@@ -6,6 +6,8 @@ import (
"fmt"
"log/slog"
"strings"
"sync"
"sync/atomic"
"time"
cache_utils "databasus-backend/internal/util/cache"
@@ -15,25 +17,70 @@ import (
)
const (
nodeInfoKeyPrefix = "node:"
nodeInfoKeyPrefix = "backup:node:"
nodeInfoKeySuffix = ":info"
nodeActiveBackupsPrefix = "node:"
nodeActiveBackupsPrefix = "backup:node:"
nodeActiveBackupsSuffix = ":active_backups"
backupSubmitChannel = "backup:submit"
backupCompletionChannel = "backup:completion"
deadNodeThreshold = 2 * time.Minute
cleanupTickerInterval = 1 * time.Second
)
// BackupNodesRegistry helps to sync the backups scheduler and backup nodes.
//
// Features:
// - Track node availability and load level
// - Assign backups that need to be processed from the scheduler to a node
// - Notify the scheduler from a node about backup completion
//
// Important things to remember:
// - Nodes without a heartbeat for more than 2 minutes are not included
// in the available nodes list and stats
//
// Cleanup of dead nodes is performed on 2 levels:
// - List and stats functions do not return dead nodes
// - Periodically, dead nodes are cleaned up in the cache (so we do not
// accumulate too many dead nodes in the cache)
type BackupNodesRegistry struct {
client valkey.Client
logger *slog.Logger
timeout time.Duration
pubsubBackups *cache_utils.PubSubManager
pubsubCompletions *cache_utils.PubSubManager
runOnce sync.Once
hasRun atomic.Bool
}
func (r *BackupNodesRegistry) Run(ctx context.Context) {
wasAlreadyRun := r.hasRun.Load()
r.runOnce.Do(func() {
r.hasRun.Store(true)
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
}
ticker := time.NewTicker(cleanupTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes", "error", err)
}
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", r))
}
}
func (r *BackupNodesRegistry) GetAvailableNodes() ([]BackupNode, error) {
@@ -76,13 +123,30 @@ func (r *BackupNodesRegistry) GetAvailableNodes() ([]BackupNode, error) {
return nil, fmt.Errorf("failed to pipeline get node keys: %w", err)
}
threshold := time.Now().UTC().Add(-deadNodeThreshold)
var nodes []BackupNode
for key, data := range keyDataMap {
// Skip if the key doesn't exist (data is empty)
if len(data) == 0 {
continue
}
var node BackupNode
if err := json.Unmarshal(data, &node); err != nil {
r.logger.Warn("Failed to unmarshal node data", "key", key, "error", err)
continue
}
// Skip nodes with zero/uninitialized heartbeat
if node.LastHeartbeat.IsZero() {
continue
}
if node.LastHeartbeat.Before(threshold) {
continue
}
nodes = append(nodes, node)
}
@@ -129,18 +193,54 @@ func (r *BackupNodesRegistry) GetBackupNodesStats() ([]BackupNodeStats, error) {
return nil, fmt.Errorf("failed to pipeline get active backups keys: %w", err)
}
var stats []BackupNodeStats
for key, data := range keyDataMap {
var nodeInfoKeys []string
nodeIDToStatsKey := make(map[string]string)
for key := range keyDataMap {
nodeID := r.extractNodeIDFromKey(key, nodeActiveBackupsPrefix, nodeActiveBackupsSuffix)
nodeIDStr := nodeID.String()
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeIDStr, nodeInfoKeySuffix)
nodeInfoKeys = append(nodeInfoKeys, infoKey)
nodeIDToStatsKey[infoKey] = key
}
count, err := r.parseIntFromBytes(data)
nodeInfoMap, err := r.pipelineGetKeys(nodeInfoKeys)
if err != nil {
return nil, fmt.Errorf("failed to pipeline get node info keys: %w", err)
}
threshold := time.Now().UTC().Add(-deadNodeThreshold)
var stats []BackupNodeStats
for infoKey, nodeData := range nodeInfoMap {
// Skip if the info key doesn't exist (nodeData is empty)
if len(nodeData) == 0 {
continue
}
var node BackupNode
if err := json.Unmarshal(nodeData, &node); err != nil {
r.logger.Warn("Failed to unmarshal node data", "key", infoKey, "error", err)
continue
}
// Skip nodes with zero/uninitialized heartbeat
if node.LastHeartbeat.IsZero() {
continue
}
if node.LastHeartbeat.Before(threshold) {
continue
}
statsKey := nodeIDToStatsKey[infoKey]
tasksData := keyDataMap[statsKey]
count, err := r.parseIntFromBytes(tasksData)
if err != nil {
r.logger.Warn("Failed to parse active backups count", "key", key, "error", err)
r.logger.Warn("Failed to parse active backups count", "key", statsKey, "error", err)
continue
}
stat := BackupNodeStats{
ID: nodeID,
ID: node.ID,
ActiveBackups: int(count),
}
stats = append(stats, stat)
@@ -149,11 +249,11 @@ func (r *BackupNodesRegistry) GetBackupNodesStats() ([]BackupNodeStats, error) {
return stats, nil
}
func (r *BackupNodesRegistry) IncrementBackupsInProgress(nodeID string) error {
func (r *BackupNodesRegistry) IncrementBackupsInProgress(nodeID uuid.UUID) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID, nodeActiveBackupsSuffix)
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID.String(), nodeActiveBackupsSuffix)
result := r.client.Do(ctx, r.client.B().Incr().Key(key).Build())
if result.Error() != nil {
@@ -167,11 +267,11 @@ func (r *BackupNodesRegistry) IncrementBackupsInProgress(nodeID string) error {
return nil
}
func (r *BackupNodesRegistry) DecrementBackupsInProgress(nodeID string) error {
func (r *BackupNodesRegistry) DecrementBackupsInProgress(nodeID uuid.UUID) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID, nodeActiveBackupsSuffix)
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID.String(), nodeActiveBackupsSuffix)
result := r.client.Do(ctx, r.client.B().Decr().Key(key).Build())
if result.Error() != nil {
@@ -198,6 +298,10 @@ func (r *BackupNodesRegistry) DecrementBackupsInProgress(nodeID string) error {
}
func (r *BackupNodesRegistry) HearthbeatNodeInRegistry(now time.Time, backupNode BackupNode) error {
if now.IsZero() {
return fmt.Errorf("cannot register node with zero heartbeat timestamp")
}
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
@@ -247,7 +351,7 @@ func (r *BackupNodesRegistry) UnregisterNodeFromRegistry(backupNode BackupNode)
}
func (r *BackupNodesRegistry) AssignBackupToNode(
targetNodeID string,
targetNodeID uuid.UUID,
backupID uuid.UUID,
isCallNotifier bool,
) error {
@@ -255,7 +359,7 @@ func (r *BackupNodesRegistry) AssignBackupToNode(
message := BackupSubmitMessage{
NodeID: targetNodeID,
BackupID: backupID.String(),
BackupID: backupID,
IsCallNotifier: isCallNotifier,
}
@@ -273,7 +377,7 @@ func (r *BackupNodesRegistry) AssignBackupToNode(
}
func (r *BackupNodesRegistry) SubscribeNodeForBackupsAssignment(
nodeID string,
nodeID uuid.UUID,
handler func(backupID uuid.UUID, isCallNotifier bool),
) error {
ctx := context.Background()
@@ -289,19 +393,7 @@ func (r *BackupNodesRegistry) SubscribeNodeForBackupsAssignment(
return
}
backupID, err := uuid.Parse(msg.BackupID)
if err != nil {
r.logger.Warn(
"Failed to parse backup ID from message",
"backupId",
msg.BackupID,
"error",
err,
)
return
}
handler(backupID, msg.IsCallNotifier)
handler(msg.BackupID, msg.IsCallNotifier)
}
err := r.pubsubBackups.Subscribe(ctx, backupSubmitChannel, wrappedHandler)
@@ -323,12 +415,12 @@ func (r *BackupNodesRegistry) UnsubscribeNodeForBackupsAssignments() error {
return nil
}
func (r *BackupNodesRegistry) PublishBackupCompletion(nodeID string, backupID uuid.UUID) error {
func (r *BackupNodesRegistry) PublishBackupCompletion(nodeID uuid.UUID, backupID uuid.UUID) error {
ctx := context.Background()
message := BackupCompletionMessage{
NodeID: nodeID,
BackupID: backupID.String(),
BackupID: backupID,
}
messageJSON, err := json.Marshal(message)
@@ -345,7 +437,7 @@ func (r *BackupNodesRegistry) PublishBackupCompletion(nodeID string, backupID uu
}
func (r *BackupNodesRegistry) SubscribeForBackupsCompletions(
handler func(nodeID string, backupID uuid.UUID),
handler func(nodeID uuid.UUID, backupID uuid.UUID),
) error {
ctx := context.Background()
@@ -356,19 +448,7 @@ func (r *BackupNodesRegistry) SubscribeForBackupsCompletions(
return
}
backupID, err := uuid.Parse(msg.BackupID)
if err != nil {
r.logger.Warn(
"Failed to parse backup ID from completion message",
"backupId",
msg.BackupID,
"error",
err,
)
return
}
handler(msg.NodeID, backupID)
handler(msg.NodeID, msg.BackupID)
}
err := r.pubsubCompletions.Subscribe(ctx, backupCompletionChannel, wrappedHandler)
@@ -446,3 +526,108 @@ func (r *BackupNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
}
return count, nil
}
func (r *BackupNodesRegistry) cleanupDeadNodes() error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
var allKeys []string
cursor := uint64(0)
pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix
for {
result := r.client.Do(
ctx,
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
)
if result.Error() != nil {
return fmt.Errorf("failed to scan node keys: %w", result.Error())
}
scanResult, err := result.AsScanEntry()
if err != nil {
return fmt.Errorf("failed to parse scan result: %w", err)
}
allKeys = append(allKeys, scanResult.Elements...)
cursor = scanResult.Cursor
if cursor == 0 {
break
}
}
if len(allKeys) == 0 {
return nil
}
keyDataMap, err := r.pipelineGetKeys(allKeys)
if err != nil {
return fmt.Errorf("failed to pipeline get node keys: %w", err)
}
threshold := time.Now().UTC().Add(-deadNodeThreshold)
var deadNodeKeys []string
for key, data := range keyDataMap {
// Skip if the key doesn't exist (data is empty)
if len(data) == 0 {
continue
}
var node BackupNode
if err := json.Unmarshal(data, &node); err != nil {
r.logger.Warn("Failed to unmarshal node data during cleanup", "key", key, "error", err)
continue
}
// Skip nodes with zero/uninitialized heartbeat
if node.LastHeartbeat.IsZero() {
continue
}
if node.LastHeartbeat.Before(threshold) {
nodeID := node.ID.String()
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeID, nodeInfoKeySuffix)
statsKey := fmt.Sprintf(
"%s%s%s",
nodeActiveBackupsPrefix,
nodeID,
nodeActiveBackupsSuffix,
)
deadNodeKeys = append(deadNodeKeys, infoKey, statsKey)
r.logger.Info(
"Marking node for cleanup",
"nodeID", nodeID,
"lastHeartbeat", node.LastHeartbeat,
"threshold", threshold,
)
}
}
if len(deadNodeKeys) == 0 {
return nil
}
delCtx, delCancel := context.WithTimeout(context.Background(), r.timeout)
defer delCancel()
result := r.client.Do(
delCtx,
r.client.B().Del().Key(deadNodeKeys...).Build(),
)
if result.Error() != nil {
return fmt.Errorf("failed to delete dead node keys: %w", result.Error())
}
deletedCount, err := result.AsInt64()
if err != nil {
return fmt.Errorf("failed to parse deleted count: %w", err)
}
r.logger.Info("Cleaned up dead nodes", "deletedKeysCount", deletedCount)
return nil
}

View File

@@ -2,6 +2,10 @@ package backuping
import (
"context"
"encoding/json"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
@@ -36,7 +40,7 @@ func Test_UnregisterNodeFromRegistry_RemovesNodeAndCounter(t *testing.T) {
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node.ID.String())
err = registry.IncrementBackupsInProgress(node.ID)
assert.NoError(t, err)
err = registry.UnregisterNodeFromRegistry(node)
@@ -100,7 +104,7 @@ func Test_IncrementBackupsInProgress_IncrementsCounter(t *testing.T) {
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node.ID.String())
err = registry.IncrementBackupsInProgress(node.ID)
assert.NoError(t, err)
stats, err := registry.GetBackupNodesStats()
@@ -109,7 +113,7 @@ func Test_IncrementBackupsInProgress_IncrementsCounter(t *testing.T) {
assert.Equal(t, node.ID, stats[0].ID)
assert.Equal(t, 1, stats[0].ActiveBackups)
err = registry.IncrementBackupsInProgress(node.ID.String())
err = registry.IncrementBackupsInProgress(node.ID)
assert.NoError(t, err)
stats, err = registry.GetBackupNodesStats()
@@ -127,25 +131,25 @@ func Test_DecrementBackupsInProgress_DecrementsCounter(t *testing.T) {
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node.ID.String())
err = registry.IncrementBackupsInProgress(node.ID)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node.ID.String())
err = registry.IncrementBackupsInProgress(node.ID)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node.ID.String())
err = registry.IncrementBackupsInProgress(node.ID)
assert.NoError(t, err)
stats, err := registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.Equal(t, 3, stats[0].ActiveBackups)
err = registry.DecrementBackupsInProgress(node.ID.String())
err = registry.DecrementBackupsInProgress(node.ID)
assert.NoError(t, err)
stats, err = registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.Equal(t, 2, stats[0].ActiveBackups)
err = registry.DecrementBackupsInProgress(node.ID.String())
err = registry.DecrementBackupsInProgress(node.ID)
assert.NoError(t, err)
stats, err = registry.GetBackupNodesStats()
@@ -162,7 +166,7 @@ func Test_DecrementBackupsInProgress_WhenNegative_ResetsToZero(t *testing.T) {
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
assert.NoError(t, err)
err = registry.DecrementBackupsInProgress(node.ID.String())
err = registry.DecrementBackupsInProgress(node.ID)
assert.NoError(t, err)
stats, err := registry.GetBackupNodesStats()
@@ -188,19 +192,19 @@ func Test_GetBackupNodesStats_ReturnsStatsForAllNodes(t *testing.T) {
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node1.ID.String())
err = registry.IncrementBackupsInProgress(node1.ID)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node2.ID.String())
err = registry.IncrementBackupsInProgress(node2.ID)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node2.ID.String())
err = registry.IncrementBackupsInProgress(node2.ID)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node3.ID.String())
err = registry.IncrementBackupsInProgress(node3.ID)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node3.ID.String())
err = registry.IncrementBackupsInProgress(node3.ID)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node3.ID.String())
err = registry.IncrementBackupsInProgress(node3.ID)
assert.NoError(t, err)
stats, err := registry.GetBackupNodesStats()
@@ -274,12 +278,12 @@ func Test_BackupCounters_TrackedSeparatelyPerNode(t *testing.T) {
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node1.ID.String())
err = registry.IncrementBackupsInProgress(node1.ID)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node1.ID.String())
err = registry.IncrementBackupsInProgress(node1.ID)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node2.ID.String())
err = registry.IncrementBackupsInProgress(node2.ID)
assert.NoError(t, err)
stats, err := registry.GetBackupNodesStats()
@@ -294,7 +298,7 @@ func Test_BackupCounters_TrackedSeparatelyPerNode(t *testing.T) {
assert.Equal(t, 2, statsMap[node1.ID])
assert.Equal(t, 1, statsMap[node2.ID])
err = registry.DecrementBackupsInProgress(node1.ID.String())
err = registry.DecrementBackupsInProgress(node1.ID)
assert.NoError(t, err)
stats, err = registry.GetBackupNodesStats()
@@ -366,13 +370,239 @@ func Test_HearthbeatNodeInRegistry_UpdatesLastHeartbeat(t *testing.T) {
assert.True(t, nodes[0].LastHeartbeat.After(originalHeartbeat))
}
func Test_HearthbeatNodeInRegistry_RejectsZeroTimestamp(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
err := registry.HearthbeatNodeInRegistry(time.Time{}, node)
assert.Error(t, err)
assert.Contains(t, err.Error(), "zero heartbeat timestamp")
nodes, err := registry.GetAvailableNodes()
assert.NoError(t, err)
assert.Len(t, nodes, 0)
}
func Test_GetAvailableNodes_ExcludesStaleNodesFromCache(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node1 := createTestBackupNode()
node2 := createTestBackupNode()
node3 := createTestBackupNode()
defer cleanupTestNode(registry, node1)
defer cleanupTestNode(registry, node2)
defer cleanupTestNode(registry, node3)
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
assert.NoError(t, err)
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
assert.NoError(t, err)
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
assert.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
defer cancel()
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
result := registry.client.Do(ctx, registry.client.B().Get().Key(key).Build())
assert.NoError(t, result.Error())
data, err := result.AsBytes()
assert.NoError(t, err)
var node BackupNode
err = json.Unmarshal(data, &node)
assert.NoError(t, err)
node.LastHeartbeat = time.Now().UTC().Add(-3 * time.Minute)
modifiedData, err := json.Marshal(node)
assert.NoError(t, err)
setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout)
defer setCancel()
setResult := registry.client.Do(
setCtx,
registry.client.B().Set().Key(key).Value(string(modifiedData)).Build(),
)
assert.NoError(t, setResult.Error())
nodes, err := registry.GetAvailableNodes()
assert.NoError(t, err)
assert.Len(t, nodes, 2)
nodeIDs := make(map[uuid.UUID]bool)
for _, n := range nodes {
nodeIDs[n.ID] = true
}
assert.True(t, nodeIDs[node1.ID])
assert.False(t, nodeIDs[node2.ID])
assert.True(t, nodeIDs[node3.ID])
}
func Test_GetBackupNodesStats_ExcludesStaleNodesFromCache(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node1 := createTestBackupNode()
node2 := createTestBackupNode()
node3 := createTestBackupNode()
defer cleanupTestNode(registry, node1)
defer cleanupTestNode(registry, node2)
defer cleanupTestNode(registry, node3)
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
assert.NoError(t, err)
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
assert.NoError(t, err)
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node1.ID)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node2.ID)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node3.ID)
assert.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
defer cancel()
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
result := registry.client.Do(ctx, registry.client.B().Get().Key(key).Build())
assert.NoError(t, result.Error())
data, err := result.AsBytes()
assert.NoError(t, err)
var node BackupNode
err = json.Unmarshal(data, &node)
assert.NoError(t, err)
node.LastHeartbeat = time.Now().UTC().Add(-3 * time.Minute)
modifiedData, err := json.Marshal(node)
assert.NoError(t, err)
setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout)
defer setCancel()
setResult := registry.client.Do(
setCtx,
registry.client.B().Set().Key(key).Value(string(modifiedData)).Build(),
)
assert.NoError(t, setResult.Error())
stats, err := registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.Len(t, stats, 2)
statsMap := make(map[uuid.UUID]int)
for _, stat := range stats {
statsMap[stat.ID] = stat.ActiveBackups
}
assert.Equal(t, 1, statsMap[node1.ID])
_, hasNode2 := statsMap[node2.ID]
assert.False(t, hasNode2)
assert.Equal(t, 1, statsMap[node3.ID])
}
func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node1 := createTestBackupNode()
node2 := createTestBackupNode()
defer cleanupTestNode(registry, node1)
defer cleanupTestNode(registry, node2)
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
assert.NoError(t, err)
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node1.ID)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node2.ID)
assert.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
defer cancel()
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
result := registry.client.Do(ctx, registry.client.B().Get().Key(key).Build())
assert.NoError(t, result.Error())
data, err := result.AsBytes()
assert.NoError(t, err)
var node BackupNode
err = json.Unmarshal(data, &node)
assert.NoError(t, err)
node.LastHeartbeat = time.Now().UTC().Add(-3 * time.Minute)
modifiedData, err := json.Marshal(node)
assert.NoError(t, err)
setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout)
defer setCancel()
setResult := registry.client.Do(
setCtx,
registry.client.B().Set().Key(key).Value(string(modifiedData)).Build(),
)
assert.NoError(t, setResult.Error())
err = registry.cleanupDeadNodes()
assert.NoError(t, err)
checkCtx, checkCancel := context.WithTimeout(context.Background(), registry.timeout)
defer checkCancel()
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
infoResult := registry.client.Do(checkCtx, registry.client.B().Get().Key(infoKey).Build())
assert.Error(t, infoResult.Error())
counterKey := fmt.Sprintf(
"%s%s%s",
nodeActiveBackupsPrefix,
node2.ID.String(),
nodeActiveBackupsSuffix,
)
counterCtx, counterCancel := context.WithTimeout(context.Background(), registry.timeout)
defer counterCancel()
counterResult := registry.client.Do(
counterCtx,
registry.client.B().Get().Key(counterKey).Build(),
)
assert.Error(t, counterResult.Error())
activeInfoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node1.ID.String(), nodeInfoKeySuffix)
activeCtx, activeCancel := context.WithTimeout(context.Background(), registry.timeout)
defer activeCancel()
activeResult := registry.client.Do(
activeCtx,
registry.client.B().Get().Key(activeInfoKey).Build(),
)
assert.NoError(t, activeResult.Error())
nodes, err := registry.GetAvailableNodes()
assert.NoError(t, err)
assert.Len(t, nodes, 1)
assert.Equal(t, node1.ID, nodes[0].ID)
stats, err := registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.Len(t, stats, 1)
assert.Equal(t, node1.ID, stats[0].ID)
}
func createTestRegistry() *BackupNodesRegistry {
return &BackupNodesRegistry{
cache_utils.GetValkeyClient(),
logger.GetLogger(),
cache_utils.DefaultCacheTimeout,
cache_utils.NewPubSubManager(),
cache_utils.NewPubSubManager(),
client: cache_utils.GetValkeyClient(),
logger: logger.GetLogger(),
timeout: cache_utils.DefaultCacheTimeout,
pubsubBackups: cache_utils.NewPubSubManager(),
pubsubCompletions: cache_utils.NewPubSubManager(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}
@@ -394,7 +624,7 @@ func Test_AssignBackupTonode_PublishesJsonMessageToChannel(t *testing.T) {
node := createTestBackupNode()
backupID := uuid.New()
err := registry.AssignBackupToNode(node.ID.String(), backupID, true)
err := registry.AssignBackupToNode(node.ID, backupID, true)
assert.NoError(t, err)
}
@@ -410,12 +640,12 @@ func Test_SubscribeNodeForBackupsAssignment_ReceivesSubmittedBackupsForMatchingN
receivedBackupID <- id
}
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
err := registry.SubscribeNodeForBackupsAssignment(node.ID, handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.AssignBackupToNode(node.ID.String(), backupID, true)
err = registry.AssignBackupToNode(node.ID, backupID, true)
assert.NoError(t, err)
select {
@@ -439,12 +669,12 @@ func Test_SubscribeNodeForBackupsAssignment_FiltersOutBackupsForDifferentNode(t
receivedBackupID <- id
}
err := registry.SubscribeNodeForBackupsAssignment(node1.ID.String(), handler)
err := registry.SubscribeNodeForBackupsAssignment(node1.ID, handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.AssignBackupToNode(node2.ID.String(), backupID, false)
err = registry.AssignBackupToNode(node2.ID, backupID, false)
assert.NoError(t, err)
select {
@@ -467,15 +697,15 @@ func Test_SubscribeNodeForBackupsAssignment_ParsesJsonAndBackupIdCorrectly(t *te
receivedBackups <- id
}
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
err := registry.SubscribeNodeForBackupsAssignment(node.ID, handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.AssignBackupToNode(node.ID.String(), backupID1, true)
err = registry.AssignBackupToNode(node.ID, backupID1, true)
assert.NoError(t, err)
err = registry.AssignBackupToNode(node.ID.String(), backupID2, false)
err = registry.AssignBackupToNode(node.ID, backupID2, false)
assert.NoError(t, err)
received1 := <-receivedBackups
@@ -497,7 +727,7 @@ func Test_SubscribeNodeForBackupsAssignment_HandlesInvalidJson(t *testing.T) {
receivedBackupID <- id
}
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
err := registry.SubscribeNodeForBackupsAssignment(node.ID, handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
@@ -525,12 +755,12 @@ func Test_UnsubscribeNodeForBackupsAssignments_StopsReceivingMessages(t *testing
receivedBackupID <- id
}
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
err := registry.SubscribeNodeForBackupsAssignment(node.ID, handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.AssignBackupToNode(node.ID.String(), backupID1, true)
err = registry.AssignBackupToNode(node.ID, backupID1, true)
assert.NoError(t, err)
received := <-receivedBackupID
@@ -541,7 +771,7 @@ func Test_UnsubscribeNodeForBackupsAssignments_StopsReceivingMessages(t *testing
time.Sleep(100 * time.Millisecond)
err = registry.AssignBackupToNode(node.ID.String(), backupID2, false)
err = registry.AssignBackupToNode(node.ID, backupID2, false)
assert.NoError(t, err)
select {
@@ -559,10 +789,10 @@ func Test_SubscribeNodeForBackupsAssignment_WhenAlreadySubscribed_ReturnsError(t
handler := func(id uuid.UUID, isCallNotifier bool) {}
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
err := registry.SubscribeNodeForBackupsAssignment(node.ID, handler)
assert.NoError(t, err)
err = registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
err = registry.SubscribeNodeForBackupsAssignment(node.ID, handler)
assert.Error(t, err)
assert.Contains(t, err.Error(), "already subscribed")
}
@@ -593,25 +823,25 @@ func Test_MultipleNodes_EachReceivesOnlyTheirBackups(t *testing.T) {
handler2 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups2 <- id }
handler3 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups3 <- id }
err := registry1.SubscribeNodeForBackupsAssignment(node1.ID.String(), handler1)
err := registry1.SubscribeNodeForBackupsAssignment(node1.ID, handler1)
assert.NoError(t, err)
err = registry2.SubscribeNodeForBackupsAssignment(node2.ID.String(), handler2)
err = registry2.SubscribeNodeForBackupsAssignment(node2.ID, handler2)
assert.NoError(t, err)
err = registry3.SubscribeNodeForBackupsAssignment(node3.ID.String(), handler3)
err = registry3.SubscribeNodeForBackupsAssignment(node3.ID, handler3)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
submitRegistry := createTestRegistry()
err = submitRegistry.AssignBackupToNode(node1.ID.String(), backupID1, true)
err = submitRegistry.AssignBackupToNode(node1.ID, backupID1, true)
assert.NoError(t, err)
err = submitRegistry.AssignBackupToNode(node2.ID.String(), backupID2, false)
err = submitRegistry.AssignBackupToNode(node2.ID, backupID2, false)
assert.NoError(t, err)
err = submitRegistry.AssignBackupToNode(node3.ID.String(), backupID3, true)
err = submitRegistry.AssignBackupToNode(node3.ID, backupID3, true)
assert.NoError(t, err)
select {
@@ -660,7 +890,7 @@ func Test_PublishBackupCompletion_PublishesMessageToChannel(t *testing.T) {
node := createTestBackupNode()
backupID := uuid.New()
err := registry.PublishBackupCompletion(node.ID.String(), backupID)
err := registry.PublishBackupCompletion(node.ID, backupID)
assert.NoError(t, err)
}
@@ -672,8 +902,8 @@ func Test_SubscribeForBackupsCompletions_ReceivesCompletedBackups(t *testing.T)
defer registry.UnsubscribeForBackupsCompletions()
receivedBackupID := make(chan uuid.UUID, 1)
receivedNodeID := make(chan string, 1)
handler := func(nodeID string, backupID uuid.UUID) {
receivedNodeID := make(chan uuid.UUID, 1)
handler := func(nodeID uuid.UUID, backupID uuid.UUID) {
receivedNodeID <- nodeID
receivedBackupID <- backupID
}
@@ -683,12 +913,12 @@ func Test_SubscribeForBackupsCompletions_ReceivesCompletedBackups(t *testing.T)
time.Sleep(100 * time.Millisecond)
err = registry.PublishBackupCompletion(node.ID.String(), backupID)
err = registry.PublishBackupCompletion(node.ID, backupID)
assert.NoError(t, err)
select {
case receivedNode := <-receivedNodeID:
assert.Equal(t, node.ID.String(), receivedNode)
assert.Equal(t, node.ID, receivedNode)
case <-time.After(2 * time.Second):
t.Fatal("Timeout waiting for node ID")
}
@@ -710,7 +940,7 @@ func Test_SubscribeForBackupsCompletions_ParsesJsonCorrectly(t *testing.T) {
defer registry.UnsubscribeForBackupsCompletions()
receivedBackups := make(chan uuid.UUID, 2)
handler := func(nodeID string, backupID uuid.UUID) {
handler := func(nodeID uuid.UUID, backupID uuid.UUID) {
receivedBackups <- backupID
}
@@ -719,10 +949,10 @@ func Test_SubscribeForBackupsCompletions_ParsesJsonCorrectly(t *testing.T) {
time.Sleep(100 * time.Millisecond)
err = registry.PublishBackupCompletion(node.ID.String(), backupID1)
err = registry.PublishBackupCompletion(node.ID, backupID1)
assert.NoError(t, err)
err = registry.PublishBackupCompletion(node.ID.String(), backupID2)
err = registry.PublishBackupCompletion(node.ID, backupID2)
assert.NoError(t, err)
received1 := <-receivedBackups
@@ -739,7 +969,7 @@ func Test_SubscribeForBackupsCompletions_HandlesInvalidJson(t *testing.T) {
defer registry.UnsubscribeForBackupsCompletions()
receivedBackupID := make(chan uuid.UUID, 1)
handler := func(nodeID string, backupID uuid.UUID) {
handler := func(nodeID uuid.UUID, backupID uuid.UUID) {
receivedBackupID <- backupID
}
@@ -767,7 +997,7 @@ func Test_UnsubscribeForBackupsCompletions_StopsReceivingMessages(t *testing.T)
backupID2 := uuid.New()
receivedBackupID := make(chan uuid.UUID, 2)
handler := func(nodeID string, backupID uuid.UUID) {
handler := func(nodeID uuid.UUID, backupID uuid.UUID) {
receivedBackupID <- backupID
}
@@ -776,7 +1006,7 @@ func Test_UnsubscribeForBackupsCompletions_StopsReceivingMessages(t *testing.T)
time.Sleep(100 * time.Millisecond)
err = registry.PublishBackupCompletion(node.ID.String(), backupID1)
err = registry.PublishBackupCompletion(node.ID, backupID1)
assert.NoError(t, err)
received := <-receivedBackupID
@@ -787,7 +1017,7 @@ func Test_UnsubscribeForBackupsCompletions_StopsReceivingMessages(t *testing.T)
time.Sleep(100 * time.Millisecond)
err = registry.PublishBackupCompletion(node.ID.String(), backupID2)
err = registry.PublishBackupCompletion(node.ID, backupID2)
assert.NoError(t, err)
select {
@@ -802,7 +1032,7 @@ func Test_SubscribeForBackupsCompletions_WhenAlreadySubscribed_ReturnsError(t *t
registry := createTestRegistry()
defer registry.UnsubscribeForBackupsCompletions()
handler := func(nodeID string, backupID uuid.UUID) {}
handler := func(nodeID uuid.UUID, backupID uuid.UUID) {}
err := registry.SubscribeForBackupsCompletions(handler)
assert.NoError(t, err)
@@ -834,9 +1064,9 @@ func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) {
receivedBackups2 := make(chan uuid.UUID, 3)
receivedBackups3 := make(chan uuid.UUID, 3)
handler1 := func(nodeID string, backupID uuid.UUID) { receivedBackups1 <- backupID }
handler2 := func(nodeID string, backupID uuid.UUID) { receivedBackups2 <- backupID }
handler3 := func(nodeID string, backupID uuid.UUID) { receivedBackups3 <- backupID }
handler1 := func(nodeID uuid.UUID, backupID uuid.UUID) { receivedBackups1 <- backupID }
handler2 := func(nodeID uuid.UUID, backupID uuid.UUID) { receivedBackups2 <- backupID }
handler3 := func(nodeID uuid.UUID, backupID uuid.UUID) { receivedBackups3 <- backupID }
err := registry1.SubscribeForBackupsCompletions(handler1)
assert.NoError(t, err)
@@ -850,13 +1080,13 @@ func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) {
time.Sleep(100 * time.Millisecond)
publishRegistry := createTestRegistry()
err = publishRegistry.PublishBackupCompletion(node1.ID.String(), backupID1)
err = publishRegistry.PublishBackupCompletion(node1.ID, backupID1)
assert.NoError(t, err)
err = publishRegistry.PublishBackupCompletion(node2.ID.String(), backupID2)
err = publishRegistry.PublishBackupCompletion(node2.ID, backupID2)
assert.NoError(t, err)
err = publishRegistry.PublishBackupCompletion(node3.ID.String(), backupID3)
err = publishRegistry.PublishBackupCompletion(node3.ID, backupID3)
assert.NoError(t, err)
receivedAll1 := []uuid.UUID{}

View File

@@ -2,89 +2,105 @@ package backuping
import (
"context"
"databasus-backend/internal/config"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/storages"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/period"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"databasus-backend/internal/config"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
)
const (
schedulerStartupDelay = 1 * time.Minute
schedulerTickerInterval = 1 * time.Minute
schedulerHealthcheckThreshold = 5 * time.Minute
)
type BackupsScheduler struct {
backupRepository *backups_core.BackupRepository
backupConfigService *backups_config.BackupConfigService
storageService *storages.StorageService
backupCancelManager *backups_cancellation.BackupCancelManager
nodesRegistry *BackupNodesRegistry
taskCancelManager *task_cancellation.TaskCancelManager
backupNodesRegistry *BackupNodesRegistry
lastBackupTime time.Time
logger *slog.Logger
backupToNodeRelations map[uuid.UUID]BackupToNodeRelation
backuperNode *BackuperNode
runOnce sync.Once
hasRun atomic.Bool
}
func (s *BackupsScheduler) Run(ctx context.Context) {
s.lastBackupTime = time.Now().UTC()
wasAlreadyRun := s.hasRun.Load()
if config.GetEnv().IsManyNodesMode {
// wait for other nodes to start
time.Sleep(1 * time.Minute)
}
s.runOnce.Do(func() {
s.hasRun.Store(true)
if err := s.failBackupsInProgress(); err != nil {
s.logger.Error("Failed to fail backups in progress", "error", err)
panic(err)
}
s.lastBackupTime = time.Now().UTC()
if err := s.nodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted); err != nil {
s.logger.Error("Failed to subscribe to backup completions", "error", err)
panic(err)
}
defer func() {
if err := s.nodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
if config.GetEnv().IsManyNodesMode {
// wait for other nodes to start
time.Sleep(schedulerStartupDelay)
}
}()
if ctx.Err() != nil {
return
}
if err := s.failBackupsInProgress(); err != nil {
s.logger.Error("Failed to fail backups in progress", "error", err)
panic(err)
}
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
err := s.backupNodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted)
if err != nil {
s.logger.Error("Failed to subscribe to backup completions", "error", err)
panic(err)
}
for {
select {
case <-ctx.Done():
defer func() {
if err := s.backupNodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
}
}()
if ctx.Err() != nil {
return
case <-ticker.C:
if err := s.cleanOldBackups(); err != nil {
s.logger.Error("Failed to clean old backups", "error", err)
}
if err := s.checkDeadNodesAndFailBackups(); err != nil {
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
}
if err := s.runPendingBackups(); err != nil {
s.logger.Error("Failed to run pending backups", "error", err)
}
s.lastBackupTime = time.Now().UTC()
}
ticker := time.NewTicker(schedulerTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.checkDeadNodesAndFailBackups(); err != nil {
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
}
if err := s.runPendingBackups(); err != nil {
s.logger.Error("Failed to run pending backups", "error", err)
}
s.lastBackupTime = time.Now().UTC()
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
}
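For orientation, a sketch of how Run is expected to be driven (illustration only; the shutdown wiring below is assumed, GetBackupsScheduler comes from this package): it is launched once in a goroutine and stopped via context cancellation, and a later second call on the same instance panics because of the hasRun guard above.
// Illustration only, not part of this change.
ctx, cancel := context.WithCancel(context.Background())
scheduler := GetBackupsScheduler()
go scheduler.Run(ctx)
// ...on shutdown, cancelling the context stops the ticker loop inside Run.
cancel()
// A later second call on the same instance panics with
// "*backuping.BackupsScheduler.Run() called multiple times".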
func (s *BackupsScheduler) IsSchedulerRunning() bool {
// Returns false if lastBackupTime is older than schedulerHealthcheckThreshold (5 minutes)
return s.lastBackupTime.After(time.Now().UTC().Add(-5 * time.Minute))
return s.lastBackupTime.After(time.Now().UTC().Add(-schedulerHealthcheckThreshold))
}
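A sketch of how this check might back a health endpoint (the route and handler below are hypothetical and assume net/http is imported; only GetBackupsScheduler and IsSchedulerRunning come from this package):
// Hypothetical healthcheck handler, illustration only.
http.HandleFunc("/health/scheduler", func(w http.ResponseWriter, r *http.Request) {
	if !GetBackupsScheduler().IsSchedulerRunning() {
		http.Error(w, "scheduler stalled", http.StatusServiceUnavailable)
		return
	}
	w.WriteHeader(http.StatusOK)
})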
func (s *BackupsScheduler) failBackupsInProgress() error {
@@ -93,12 +109,10 @@ func (s *BackupsScheduler) failBackupsInProgress() error {
return err
}
fmt.Println("Backups in progress", len(backupsInProgress))
for _, backup := range backupsInProgress {
if err := s.backupCancelManager.CancelBackup(backup.ID); err != nil {
if err := s.taskCancelManager.CancelTask(backup.ID); err != nil {
s.logger.Error(
"Failed to cancel backup via context manager",
"Failed to cancel backup via task cancel manager",
"backupId",
backup.ID,
"error",
@@ -144,6 +158,33 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
return
}
// Check for existing in-progress backups
inProgressBackups, err := s.backupRepository.FindByDatabaseIdAndStatus(
databaseID,
backups_core.BackupStatusInProgress,
)
if err != nil {
s.logger.Error(
"Failed to check for in-progress backups",
"databaseId",
databaseID,
"error",
err,
)
return
}
if len(inProgressBackups) > 0 {
s.logger.Warn(
"Backup already in progress for database, skipping new backup",
"databaseId",
databaseID,
"existingBackupId",
inProgressBackups[0].ID,
)
return
}
leastBusyNodeID, err := s.calculateLeastBusyNode()
if err != nil {
s.logger.Error(
@@ -175,7 +216,7 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
return
}
if err := s.nodesRegistry.IncrementBackupsInProgress(leastBusyNodeID.String()); err != nil {
if err := s.backupNodesRegistry.IncrementBackupsInProgress(*leastBusyNodeID); err != nil {
s.logger.Error(
"Failed to increment backups in progress",
"nodeId",
@@ -188,7 +229,7 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
return
}
if err := s.nodesRegistry.AssignBackupToNode(leastBusyNodeID.String(), backup.ID, isCallNotifier); err != nil {
if err := s.backupNodesRegistry.AssignBackupToNode(*leastBusyNodeID, backup.ID, isCallNotifier); err != nil {
s.logger.Error(
"Failed to submit backup",
"nodeId",
@@ -198,7 +239,7 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
"error",
err,
)
if decrementErr := s.nodesRegistry.DecrementBackupsInProgress(leastBusyNodeID.String()); decrementErr != nil {
if decrementErr := s.backupNodesRegistry.DecrementBackupsInProgress(*leastBusyNodeID); decrementErr != nil {
s.logger.Error(
"Failed to decrement backups in progress after submit failure",
"nodeId",
@@ -244,6 +285,10 @@ func (s *BackupsScheduler) GetRemainedBackupTryCount(lastBackup *backups_core.Ba
return 0
}
if lastBackup.IsSkipRetry {
return 0
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(lastBackup.DatabaseID)
if err != nil {
s.logger.Error("Failed to get backup config by database ID", "error", err)
@@ -276,74 +321,6 @@ func (s *BackupsScheduler) GetRemainedBackupTryCount(lastBackup *backups_core.Ba
return maxFailedTriesCount - len(lastFailedBackups)
}
func (s *BackupsScheduler) cleanOldBackups() error {
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
backupStorePeriod := backupConfig.StorePeriod
if backupStorePeriod == period.PeriodForever {
continue
}
storeDuration := backupStorePeriod.ToDuration()
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
oldBackups, err := s.backupRepository.FindBackupsBeforeDate(
backupConfig.DatabaseID,
dateBeforeBackupsShouldBeDeleted,
)
if err != nil {
s.logger.Error(
"Failed to find old backups for database",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
continue
}
for _, backup := range oldBackups {
storage, err := s.storageService.GetStorageByID(backup.StorageID)
if err != nil {
s.logger.Error(
"Failed to get storage by ID",
"storageId",
backup.StorageID,
"error",
err,
)
continue
}
encryptor := encryption.GetFieldEncryptor()
err = storage.DeleteFile(encryptor, backup.ID)
if err != nil {
s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
}
if err := s.backupRepository.DeleteByID(backup.ID); err != nil {
s.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
continue
}
s.logger.Info(
"Deleted old backup",
"backupId",
backup.ID,
"databaseId",
backupConfig.DatabaseID,
)
}
}
return nil
}
func (s *BackupsScheduler) runPendingBackups() error {
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
@@ -393,7 +370,7 @@ func (s *BackupsScheduler) runPendingBackups() error {
}
func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
nodes, err := s.nodesRegistry.GetAvailableNodes()
nodes, err := s.backupNodesRegistry.GetAvailableNodes()
if err != nil {
return nil, fmt.Errorf("failed to get available nodes: %w", err)
}
@@ -402,7 +379,7 @@ func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
return nil, fmt.Errorf("no nodes available")
}
stats, err := s.nodesRegistry.GetBackupNodesStats()
stats, err := s.backupNodesRegistry.GetBackupNodesStats()
if err != nil {
return nil, fmt.Errorf("failed to get backup nodes stats: %w", err)
}
@@ -415,14 +392,9 @@ func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
var bestNode *BackupNode
var bestScore float64 = -1
now := time.Now().UTC()
for i := range nodes {
node := &nodes[i]
if now.Sub(node.LastHeartbeat) > 2*time.Minute {
continue
}
activeBackups := statsMap[node.ID]
var score float64
@@ -445,16 +417,11 @@ func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
return &bestNode.ID, nil
}
func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUID) {
nodeID, err := uuid.Parse(nodeIDStr)
func (s *BackupsScheduler) onBackupCompleted(nodeID uuid.UUID, backupID uuid.UUID) {
// Verify this task is actually a backup (registry contains multiple task types)
_, err := s.backupRepository.FindByID(backupID)
if err != nil {
s.logger.Error(
"Failed to parse node ID from completion message",
"nodeId",
nodeIDStr,
"error",
err,
)
// Not a backup task, ignore it
return
}
@@ -498,7 +465,7 @@ func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUI
s.backupToNodeRelations[nodeID] = relation
}
if err := s.nodesRegistry.DecrementBackupsInProgress(nodeIDStr); err != nil {
if err := s.backupNodesRegistry.DecrementBackupsInProgress(nodeID); err != nil {
s.logger.Error(
"Failed to decrement backups in progress",
"nodeId",
@@ -512,18 +479,14 @@ func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUI
}
func (s *BackupsScheduler) checkDeadNodesAndFailBackups() error {
nodes, err := s.nodesRegistry.GetAvailableNodes()
nodes, err := s.backupNodesRegistry.GetAvailableNodes()
if err != nil {
return fmt.Errorf("failed to get available nodes: %w", err)
}
aliveNodeIDs := make(map[uuid.UUID]bool)
now := time.Now().UTC()
for _, node := range nodes {
if now.Sub(node.LastHeartbeat) <= 2*time.Minute {
aliveNodeIDs[node.ID] = true
}
aliveNodeIDs[node.ID] = true
}
for nodeID, relation := range s.backupToNodeRelations {
@@ -572,7 +535,7 @@ func (s *BackupsScheduler) checkDeadNodesAndFailBackups() error {
continue
}
if err := s.nodesRegistry.DecrementBackupsInProgress(nodeID.String()); err != nil {
if err := s.backupNodesRegistry.DecrementBackupsInProgress(nodeID); err != nil {
s.logger.Error(
"Failed to decrement backups in progress for dead node",
"nodeId",

View File

@@ -42,8 +42,8 @@ func Test_RunPendingBackups_WhenLastBackupWasYesterday_CreatesNewBackup(t *testi
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
@@ -111,8 +111,8 @@ func Test_RunPendingBackups_WhenLastBackupWasRecentlyCompleted_SkipsBackup(t *te
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
@@ -179,8 +179,8 @@ func Test_RunPendingBackups_WhenLastBackupFailedAndRetriesDisabled_SkipsBackup(t
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
@@ -251,8 +251,8 @@ func Test_RunPendingBackups_WhenLastBackupFailedAndRetriesEnabled_CreatesNewBack
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
@@ -324,8 +324,8 @@ func Test_RunPendingBackups_WhenFailedBackupsExceedMaxRetries_SkipsBackup(t *tes
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
@@ -396,8 +396,8 @@ func Test_RunPendingBackups_WhenBackupsDisabled_SkipsBackup(t *testing.T) {
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
@@ -449,6 +449,8 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
var mockNodeID uuid.UUID
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
@@ -457,9 +459,15 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
// Clean up mock node
if mockNodeID != uuid.Nil {
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: mockNodeID})
}
cache_utils.ClearAllCache()
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
@@ -479,7 +487,7 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
assert.NoError(t, err)
// Register mock node without subscribing to backups (simulates node crash after registration)
mockNodeID := uuid.New()
mockNodeID = uuid.New()
err = CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
assert.NoError(t, err)
@@ -493,7 +501,7 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status)
// Verify Valkey counter was incremented when backup was assigned
stats, err := nodesRegistry.GetBackupNodesStats()
stats, err := backupNodesRegistry.GetBackupNodesStats()
assert.NoError(t, err)
foundStat := false
for _, stat := range stats {
@@ -523,7 +531,7 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
assert.Contains(t, *backups[0].FailMessage, "node unavailability")
// Verify Valkey counter was decremented after backup failed
stats, err = nodesRegistry.GetBackupNodesStats()
stats, err = backupNodesRegistry.GetBackupNodesStats()
assert.NoError(t, err)
for _, stat := range stats {
if stat.ID == mockNodeID {
@@ -531,11 +539,109 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
}
}
// Node info should still exist in registry (not removed by checkDeadNodesAndFailBackups)
node, err := GetNodeFromRegistry(mockNodeID)
time.Sleep(200 * time.Millisecond)
}
func Test_OnBackupCompleted_WhenTaskIsNotBackup_SkipsProcessing(t *testing.T) {
cache_utils.ClearAllCache()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
var mockNodeID uuid.UUID
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
// Clean up mock node
if mockNodeID != uuid.Nil {
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: mockNodeID})
}
cache_utils.ClearAllCache()
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
assert.NotNil(t, node)
assert.Equal(t, mockNodeID, node.ID)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Register mock node
mockNodeID = uuid.New()
err = CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
assert.NoError(t, err)
// Start a backup and assign it to the node
GetBackupsScheduler().StartBackup(database.ID, false)
time.Sleep(100 * time.Millisecond)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status)
// Get initial state of the registry
initialStats, err := backupNodesRegistry.GetBackupNodesStats()
assert.NoError(t, err)
var initialActiveTasks int
for _, stat := range initialStats {
if stat.ID == mockNodeID {
initialActiveTasks = stat.ActiveBackups
break
}
}
assert.Equal(t, 1, initialActiveTasks, "Should have 1 active task")
// Call onBackupCompleted with a random UUID (not a backup ID)
nonBackupTaskID := uuid.New()
GetBackupsScheduler().onBackupCompleted(mockNodeID, nonBackupTaskID)
time.Sleep(100 * time.Millisecond)
// Verify: Active tasks counter should remain the same (not decremented)
stats, err := backupNodesRegistry.GetBackupNodesStats()
assert.NoError(t, err)
for _, stat := range stats {
if stat.ID == mockNodeID {
assert.Equal(t, initialActiveTasks, stat.ActiveBackups,
"Active tasks should not change for non-backup task")
}
}
// Verify: backup should still be in progress (not modified)
backups, err = backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status,
"Backup status should not change for non-backup task completion")
// Verify: backupToNodeRelations should still contain the node
scheduler := GetBackupsScheduler()
_, exists := scheduler.backupToNodeRelations[mockNodeID]
assert.True(t, exists, "Node should still be in backupToNodeRelations")
time.Sleep(200 * time.Millisecond)
}
@@ -549,6 +655,14 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
node3ID := uuid.New()
now := time.Now().UTC()
defer func() {
// Clean up all mock nodes
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node1ID})
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node2ID})
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node3ID})
cache_utils.ClearAllCache()
}()
err := CreateMockNodeInRegistry(node1ID, 100, now)
assert.NoError(t, err)
err = CreateMockNodeInRegistry(node2ID, 100, now)
@@ -557,17 +671,17 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
assert.NoError(t, err)
for range 5 {
err = nodesRegistry.IncrementBackupsInProgress(node1ID.String())
err = backupNodesRegistry.IncrementBackupsInProgress(node1ID)
assert.NoError(t, err)
}
for range 2 {
err = nodesRegistry.IncrementBackupsInProgress(node2ID.String())
err = backupNodesRegistry.IncrementBackupsInProgress(node2ID)
assert.NoError(t, err)
}
for range 8 {
err = nodesRegistry.IncrementBackupsInProgress(node3ID.String())
err = backupNodesRegistry.IncrementBackupsInProgress(node3ID)
assert.NoError(t, err)
}
@@ -584,17 +698,24 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
node50MBsID := uuid.New()
now := time.Now().UTC()
defer func() {
// Clean up all mock nodes
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node100MBsID})
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node50MBsID})
cache_utils.ClearAllCache()
}()
err := CreateMockNodeInRegistry(node100MBsID, 100, now)
assert.NoError(t, err)
err = CreateMockNodeInRegistry(node50MBsID, 50, now)
assert.NoError(t, err)
for range 10 {
err = nodesRegistry.IncrementBackupsInProgress(node100MBsID.String())
err = backupNodesRegistry.IncrementBackupsInProgress(node100MBsID)
assert.NoError(t, err)
}
err = nodesRegistry.IncrementBackupsInProgress(node50MBsID.String())
err = backupNodesRegistry.IncrementBackupsInProgress(node50MBsID)
assert.NoError(t, err)
leastBusyNodeID, err := GetBackupsScheduler().calculateLeastBusyNode()
@@ -622,9 +743,11 @@ func Test_FailBackupsInProgress_WhenSchedulerStarts_CancelsBackupsAndUpdatesStat
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
cache_utils.ClearAllCache()
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
@@ -707,3 +830,492 @@ func Test_FailBackupsInProgress_WhenSchedulerStarts_CancelsBackupsAndUpdatesStat
time.Sleep(200 * time.Millisecond)
}
func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T) {
cache_utils.ClearAllCache()
// Start scheduler so it can handle task completions
scheduler := CreateTestScheduler()
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Get initial active task count
stats, err := backupNodesRegistry.GetBackupNodesStats()
assert.NoError(t, err)
var initialActiveTasks int
for _, stat := range stats {
if stat.ID == backuperNode.nodeID {
initialActiveTasks = stat.ActiveBackups
break
}
}
t.Logf("Initial active tasks: %d", initialActiveTasks)
// Start backup
scheduler.StartBackup(database.ID, false)
// Wait for backup to complete
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
// Verify backup was created and completed
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
assert.Equal(t, backups_core.BackupStatusCompleted, backups[0].Status)
// Wait for active task count to decrease
decreased := WaitForActiveTasksDecrease(
t,
backuperNode.nodeID,
initialActiveTasks+1,
10*time.Second,
)
assert.True(t, decreased, "Active task count should have decreased after backup completion")
// Verify final active task count equals initial count
finalStats, err := backupNodesRegistry.GetBackupNodesStats()
assert.NoError(t, err)
for _, stat := range finalStats {
if stat.ID == backuperNode.nodeID {
t.Logf("Final active tasks: %d", stat.ActiveBackups)
assert.Equal(t, initialActiveTasks, stat.ActiveBackups,
"Active task count should return to initial value after backup completion")
break
}
}
time.Sleep(200 * time.Millisecond)
}
func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
cache_utils.ClearAllCache()
// Start scheduler so it can handle task completions
scheduler := CreateTestScheduler()
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Set wrong password to cause backup failure
// Save via the repository directly to bypass service-layer validation, which would fail the connection test
database.Postgresql.Password = "intentionally_wrong_password"
dbRepo := &databases.DatabaseRepository{}
_, err := dbRepo.Save(database)
assert.NoError(t, err)
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Get initial active task count
stats, err := backupNodesRegistry.GetBackupNodesStats()
assert.NoError(t, err)
var initialActiveTasks int
for _, stat := range stats {
if stat.ID == backuperNode.nodeID {
initialActiveTasks = stat.ActiveBackups
break
}
}
t.Logf("Initial active tasks: %d", initialActiveTasks)
// Start backup
scheduler.StartBackup(database.ID, false)
// Wait for backup to fail
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
// Verify backup was created and failed
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
assert.Equal(t, backups_core.BackupStatusFailed, backups[0].Status)
assert.NotNil(t, backups[0].FailMessage)
if backups[0].FailMessage != nil {
t.Logf("Backup failed with message: %s", *backups[0].FailMessage)
}
// Wait for active task count to decrease
decreased := WaitForActiveTasksDecrease(
t,
backuperNode.nodeID,
initialActiveTasks+1,
10*time.Second,
)
assert.True(t, decreased, "Active task count should have decreased after backup failure")
// Verify final active task count equals initial count
finalStats, err := backupNodesRegistry.GetBackupNodesStats()
assert.NoError(t, err)
for _, stat := range finalStats {
if stat.ID == backuperNode.nodeID {
t.Logf("Final active tasks: %d", stat.ActiveBackups)
assert.Equal(t, initialActiveTasks, stat.ActiveBackups,
"Active task count should return to initial value after backup failure")
break
}
}
time.Sleep(200 * time.Millisecond)
}
func Test_StartBackup_WhenBackupAlreadyInProgress_SkipsNewBackup(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Create an in-progress backup manually
inProgressBackup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
BackupSizeMb: 0,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(inProgressBackup)
assert.NoError(t, err)
// Try to start a new backup - should be skipped
GetBackupsScheduler().StartBackup(database.ID, false)
time.Sleep(200 * time.Millisecond)
// Verify only 1 backup exists (the original in-progress one)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status)
assert.Equal(t, inProgressBackup.ID, backups[0].ID)
time.Sleep(200 * time.Millisecond)
}
func Test_RunPendingBackups_WhenLastBackupFailedWithIsSkipRetry_SkipsBackupEvenWithRetriesEnabled(
t *testing.T,
) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups with retries enabled and high retry count
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
backupConfig.IsRetryIfFailed = true
backupConfig.MaxFailedTriesCount = 5
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Create a failed backup with IsSkipRetry set to true
failMessage := "backup failed due to size limit exceeded"
backupRepository.Save(&backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusFailed,
FailMessage: &failMessage,
IsSkipRetry: true,
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
})
// Verify GetRemainedBackupTryCount returns 0 even though retries are enabled
lastBackup, err := backupRepository.FindLastByDatabaseID(database.ID)
assert.NoError(t, err)
assert.NotNil(t, lastBackup)
remainedTries := GetBackupsScheduler().GetRemainedBackupTryCount(lastBackup)
assert.Equal(t, 0, remainedTries, "Should return 0 tries when IsSkipRetry is true")
// Run the scheduler
GetBackupsScheduler().runPendingBackups()
time.Sleep(100 * time.Millisecond)
// Verify no new backup was created (still only 1 backup exists)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1, "No retry should be attempted when IsSkipRetry is true")
time.Sleep(200 * time.Millisecond)
}
func Test_StartBackup_When2BackupsStartedForDifferentDatabases_BothUseCasesAreCalled(t *testing.T) {
cache_utils.ClearAllCache()
// Create mock tracking use case
mockUseCase := NewMockTrackingBackupUsecase()
// Create BackuperNode with mock use case
backuperNode := CreateTestBackuperNodeWithUseCase(mockUseCase)
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
// Create scheduler
scheduler := CreateTestScheduler()
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
// Setup test data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
// Create 2 separate databases
database1 := databases.CreateTestDatabase(workspace.ID, storage, notifier)
database2 := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// Cleanup backups for database1
backups1, _ := backupRepository.FindByDatabaseID(database1.ID)
for _, backup := range backups1 {
backupRepository.DeleteByID(backup.ID)
}
// Cleanup backups for database2
backups2, _ := backupRepository.FindByDatabaseID(database2.ID)
for _, backup := range backups2 {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database1)
databases.RemoveTestDatabase(database2)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups for database1
backupConfig1, err := backups_config.GetBackupConfigService().
GetBackupConfigByDbId(database1.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig1.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig1.IsBackupsEnabled = true
backupConfig1.StorePeriod = period.PeriodWeek
backupConfig1.Storage = storage
backupConfig1.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig1)
assert.NoError(t, err)
// Enable backups for database2
backupConfig2, err := backups_config.GetBackupConfigService().
GetBackupConfigByDbId(database2.ID)
assert.NoError(t, err)
backupConfig2.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig2.IsBackupsEnabled = true
backupConfig2.StorePeriod = period.PeriodWeek
backupConfig2.Storage = storage
backupConfig2.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig2)
assert.NoError(t, err)
// Start 2 backups simultaneously
t.Log("Starting backup for database1")
scheduler.StartBackup(database1.ID, false)
t.Log("Starting backup for database2")
scheduler.StartBackup(database2.ID, false)
// Wait up to 10 seconds for both backups to complete
t.Log("Waiting for both backups to complete...")
success := assert.Eventually(t, func() bool {
callCount := mockUseCase.GetCallCount()
t.Logf("Current call count: %d/2", callCount)
return callCount == 2
}, 10*time.Second, 200*time.Millisecond, "Both use cases should be called within 10 seconds")
if !success {
t.Logf("Test failed: Only %d out of 2 use cases were called", mockUseCase.GetCallCount())
}
// Verify both backup IDs were received
calledBackupIDs := mockUseCase.GetCalledBackupIDs()
t.Logf("Called backup IDs: %v", calledBackupIDs)
assert.Len(t, calledBackupIDs, 2, "Both backup IDs should be tracked")
// Verify both backups exist in repository and are completed
backups1, err := backupRepository.FindByDatabaseID(database1.ID)
assert.NoError(t, err)
assert.Len(t, backups1, 1, "Database1 should have 1 backup")
if len(backups1) > 0 {
t.Logf("Database1 backup status: %s", backups1[0].Status)
}
backups2, err := backupRepository.FindByDatabaseID(database2.ID)
assert.NoError(t, err)
assert.Len(t, backups2, 1, "Database2 should have 1 backup")
if len(backups2) > 0 {
t.Logf("Database2 backup status: %s", backups2[0].Status)
}
// Verify both backups completed successfully
if len(backups1) > 0 {
assert.Equal(t, backups_core.BackupStatusCompleted, backups1[0].Status,
"Database1 backup should be completed")
}
if len(backups2) > 0 {
assert.Equal(t, backups_core.BackupStatusCompleted, backups2[0].Status,
"Database2 backup should be completed")
}
time.Sleep(200 * time.Millisecond)
}

View File

@@ -3,6 +3,8 @@ package backuping
import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
@@ -35,19 +37,56 @@ func CreateTestRouter() *gin.Engine {
func CreateTestBackuperNode() *BackuperNode {
return &BackuperNode{
databases.GetDatabaseService(),
encryption.GetFieldEncryptor(),
workspaces_services.GetWorkspaceService(),
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
notifiers.GetNotifierService(),
backupCancelManager,
nodesRegistry,
logger.GetLogger(),
usecases.GetCreateBackupUsecase(),
uuid.New(),
time.Time{},
databaseService: databases.GetDatabaseService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
workspaceService: workspaces_services.GetWorkspaceService(),
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
notificationSender: notifiers.GetNotifierService(),
backupCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
logger: logger.GetLogger(),
createBackupUseCase: usecases.GetCreateBackupUsecase(),
nodeID: uuid.New(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}
func CreateTestBackuperNodeWithUseCase(useCase backups_core.CreateBackupUsecase) *BackuperNode {
return &BackuperNode{
databaseService: databases.GetDatabaseService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
workspaceService: workspaces_services.GetWorkspaceService(),
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
notificationSender: notifiers.GetNotifierService(),
backupCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
logger: logger.GetLogger(),
createBackupUseCase: useCase,
nodeID: uuid.New(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}
func CreateTestScheduler() *BackupsScheduler {
return &BackupsScheduler{
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
taskCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
lastBackupTime: time.Now().UTC(),
logger: logger.GetLogger(),
backupToNodeRelations: make(map[uuid.UUID]BackupToNodeRelation),
backuperNode: CreateTestBackuperNode(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}
@@ -113,7 +152,7 @@ func StartBackuperNodeForTest(t *testing.T, backuperNode *BackuperNode) context.
// Poll registry for node presence instead of fixed sleep
deadline := time.Now().UTC().Add(5 * time.Second)
for time.Now().UTC().Before(deadline) {
nodes, err := nodesRegistry.GetAvailableNodes()
nodes, err := backupNodesRegistry.GetAvailableNodes()
if err == nil {
for _, node := range nodes {
if node.ID == backuperNode.nodeID {
@@ -138,6 +177,34 @@ func StartBackuperNodeForTest(t *testing.T, backuperNode *BackuperNode) context.
return nil
}
// StartSchedulerForTest starts the BackupsScheduler in a goroutine for testing.
// The scheduler subscribes to task completions and manages backup lifecycle.
// Returns a context cancel function that should be deferred to stop the scheduler.
func StartSchedulerForTest(t *testing.T, scheduler *BackupsScheduler) context.CancelFunc {
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
scheduler.Run(ctx)
close(done)
}()
// Give scheduler time to subscribe to completions
time.Sleep(100 * time.Millisecond)
t.Log("BackupsScheduler started")
return func() {
cancel()
select {
case <-done:
t.Log("BackupsScheduler stopped gracefully")
case <-time.After(2 * time.Second):
t.Log("BackupsScheduler stop timeout")
}
}
}
// StopBackuperNodeForTest stops the BackuperNode by canceling its context.
// It waits for the node to unregister from the registry.
func StopBackuperNodeForTest(t *testing.T, cancel context.CancelFunc, backuperNode *BackuperNode) {
@@ -146,7 +213,7 @@ func StopBackuperNodeForTest(t *testing.T, cancel context.CancelFunc, backuperNo
// Wait for node to unregister from registry
deadline := time.Now().UTC().Add(2 * time.Second)
for time.Now().UTC().Before(deadline) {
nodes, err := nodesRegistry.GetAvailableNodes()
nodes, err := backupNodesRegistry.GetAvailableNodes()
if err == nil {
found := false
for _, node := range nodes {
@@ -173,7 +240,7 @@ func CreateMockNodeInRegistry(nodeID uuid.UUID, throughputMBs int, lastHeartbeat
LastHeartbeat: lastHeartbeat,
}
return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
return backupNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
}
func UpdateNodeHeartbeatDirectly(
@@ -187,11 +254,11 @@ func UpdateNodeHeartbeatDirectly(
LastHeartbeat: lastHeartbeat,
}
return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
return backupNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
}
func GetNodeFromRegistry(nodeID uuid.UUID) (*BackupNode, error) {
nodes, err := nodesRegistry.GetAvailableNodes()
nodes, err := backupNodesRegistry.GetAvailableNodes()
if err != nil {
return nil, err
}
@@ -204,3 +271,48 @@ func GetNodeFromRegistry(nodeID uuid.UUID) (*BackupNode, error) {
return nil, fmt.Errorf("node not found")
}
// WaitForActiveTasksDecrease waits for the active task count to decrease below the initial count.
// It polls the registry every 500ms until the count decreases or the timeout is reached.
// Returns true if the count decreased, false if timeout was reached.
func WaitForActiveTasksDecrease(
t *testing.T,
nodeID uuid.UUID,
initialCount int,
timeout time.Duration,
) bool {
deadline := time.Now().UTC().Add(timeout)
for time.Now().UTC().Before(deadline) {
stats, err := backupNodesRegistry.GetBackupNodesStats()
if err != nil {
t.Logf("WaitForActiveTasksDecrease: error getting node stats: %v", err)
time.Sleep(500 * time.Millisecond)
continue
}
for _, stat := range stats {
if stat.ID == nodeID {
t.Logf(
"WaitForActiveTasksDecrease: current active tasks = %d (initial = %d)",
stat.ActiveBackups,
initialCount,
)
if stat.ActiveBackups < initialCount {
t.Logf(
"WaitForActiveTasksDecrease: active tasks decreased from %d to %d",
initialCount,
stat.ActiveBackups,
)
return true
}
break
}
}
time.Sleep(500 * time.Millisecond)
}
t.Logf("WaitForActiveTasksDecrease: timeout waiting for active tasks to decrease")
return false
}

View File

@@ -1,75 +0,0 @@
package backups_cancellation
import (
"context"
cache_utils "databasus-backend/internal/util/cache"
"log/slog"
"sync"
"github.com/google/uuid"
)
const backupCancelChannel = "backup:cancel"
type BackupCancelManager struct {
mu sync.RWMutex
cancelFuncs map[uuid.UUID]context.CancelFunc
pubsub *cache_utils.PubSubManager
logger *slog.Logger
}
func (m *BackupCancelManager) StartSubscription() {
ctx := context.Background()
handler := func(message string) {
backupID, err := uuid.Parse(message)
if err != nil {
m.logger.Error("Invalid backup ID in cancel message", "message", message, "error", err)
return
}
m.mu.Lock()
defer m.mu.Unlock()
cancelFunc, exists := m.cancelFuncs[backupID]
if exists {
cancelFunc()
delete(m.cancelFuncs, backupID)
m.logger.Info("Cancelled backup via Pub/Sub", "backupID", backupID)
}
}
err := m.pubsub.Subscribe(ctx, backupCancelChannel, handler)
if err != nil {
m.logger.Error("Failed to subscribe to backup cancel channel", "error", err)
} else {
m.logger.Info("Successfully subscribed to backup cancel channel")
}
}
func (m *BackupCancelManager) RegisterBackup(backupID uuid.UUID, cancelFunc context.CancelFunc) {
m.mu.Lock()
defer m.mu.Unlock()
m.cancelFuncs[backupID] = cancelFunc
m.logger.Debug("Registered backup", "backupID", backupID)
}
func (m *BackupCancelManager) CancelBackup(backupID uuid.UUID) error {
ctx := context.Background()
err := m.pubsub.Publish(ctx, backupCancelChannel, backupID.String())
if err != nil {
m.logger.Error("Failed to publish cancel message", "backupID", backupID, "error", err)
return err
}
m.logger.Info("Published backup cancel message", "backupID", backupID)
return nil
}
func (m *BackupCancelManager) UnregisterBackup(backupID uuid.UUID) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.cancelFuncs, backupID)
m.logger.Debug("Unregistered backup", "backupID", backupID)
}

View File

@@ -1,25 +0,0 @@
package backups_cancellation
import (
"context"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/logger"
"sync"
"github.com/google/uuid"
)
var backupCancelManager = &BackupCancelManager{
sync.RWMutex{},
make(map[uuid.UUID]context.CancelFunc),
cache_utils.NewPubSubManager(),
logger.GetLogger(),
}
func GetBackupCancelManager() *BackupCancelManager {
return backupCancelManager
}
func SetupDependencies() {
backupCancelManager.StartSubscription()
}

View File

@@ -913,7 +913,7 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
assert.NoError(t, err)
// Register a cancellable context for the backup
GetBackupService().backupCancelManager.RegisterBackup(backup.ID, func() {})
GetBackupService().taskCancelManager.RegisterTask(backup.ID, func() {})
resp := test_utils.MakePostRequest(
t,
@@ -1156,7 +1156,7 @@ func createTestDatabase(
Type: databases.DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",

View File

@@ -15,6 +15,7 @@ type Backup struct {
Status BackupStatus `json:"status" gorm:"column:status;not null"`
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`
IsSkipRetry bool `json:"isSkipRetry" gorm:"column:is_skip_retry;type:boolean;not null"`
BackupSizeMb float64 `json:"backupSizeMb" gorm:"column:backup_size_mb;default:0"`

View File

@@ -212,3 +212,36 @@ func (r *BackupRepository) CountByDatabaseID(databaseID uuid.UUID) (int64, error
return count, nil
}
func (r *BackupRepository) GetTotalSizeByDatabase(databaseID uuid.UUID) (float64, error) {
var totalSize float64
if err := storage.
GetDb().
Model(&Backup{}).
Select("COALESCE(SUM(backup_size_mb), 0)").
Where("database_id = ? AND status != ?", databaseID, BackupStatusInProgress).
Scan(&totalSize).Error; err != nil {
return 0, err
}
return totalSize, nil
}
func (r *BackupRepository) FindOldestByDatabaseExcludingInProgress(
databaseID uuid.UUID,
limit int,
) ([]*Backup, error) {
var backups []*Backup
if err := storage.
GetDb().
Where("database_id = ? AND status != ?", databaseID, BackupStatusInProgress).
Order("created_at ASC").
Limit(limit).
Find(&backups).Error; err != nil {
return nil, err
}
return backups, nil
}
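These two helpers were added for the new backup size limits. Purely as an illustration (the actual enforcement code is not part of this diff), a hypothetical freeSpaceForBackup helper could combine them roughly like this, with small local stand-in types so the sketch is self-contained:
package sketch
import (
    "fmt"
    "github.com/google/uuid"
)
// backupInfo, sizeRepo and cleaner are minimal stand-ins for the real
// backups_core types, defined here only so the sketch compiles on its own.
type backupInfo struct {
    ID           uuid.UUID
    BackupSizeMb float64
}
type sizeRepo interface {
    GetTotalSizeByDatabase(databaseID uuid.UUID) (float64, error)
    FindOldestByDatabaseExcludingInProgress(databaseID uuid.UUID, limit int) ([]*backupInfo, error)
}
type cleaner interface {
    DeleteBackup(backup *backupInfo) error
}
// freeSpaceForBackup deletes the oldest finished backups until a new backup of
// newSizeMb would fit under maxTotalSizeMb (0 means unlimited).
func freeSpaceForBackup(repo sizeRepo, c cleaner, dbID uuid.UUID, newSizeMb, maxTotalSizeMb float64) error {
    if maxTotalSizeMb <= 0 {
        return nil
    }
    total, err := repo.GetTotalSizeByDatabase(dbID)
    if err != nil {
        return err
    }
    for total+newSizeMb > maxTotalSizeMb {
        oldest, err := repo.FindOldestByDatabaseExcludingInProgress(dbID, 1)
        if err != nil {
            return err
        }
        if len(oldest) == 0 {
            return fmt.Errorf("cannot free enough space: total size limit %.0f MB", maxTotalSizeMb)
        }
        if err := c.DeleteBackup(oldest[0]); err != nil {
            return err
        }
        total -= oldest[0].BackupSizeMb
    }
    return nil
}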

View File

@@ -1,9 +1,11 @@
package backups
import (
"sync"
"sync/atomic"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/backuping"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
"databasus-backend/internal/features/backups/backups/usecases"
@@ -12,6 +14,7 @@ import (
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
@@ -19,25 +22,26 @@ import (
var backupRepository = &backups_core.BackupRepository{}
var backupCancelManager = backups_cancellation.GetBackupCancelManager()
var taskCancelManager = task_cancellation.GetTaskCancelManager()
var backupService = &BackupService{
databaseService: databases.GetDatabaseService(),
storageService: storages.GetStorageService(),
backupRepository: backupRepository,
notifierService: notifiers.GetNotifierService(),
notificationSender: notifiers.GetNotifierService(),
backupConfigService: backups_config.GetBackupConfigService(),
secretKeyService: encryption_secrets.GetSecretKeyService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
createBackupUseCase: usecases.GetCreateBackupUsecase(),
logger: logger.GetLogger(),
backupRemoveListeners: []backups_core.BackupRemoveListener{},
workspaceService: workspaces_services.GetWorkspaceService(),
auditLogService: audit_logs.GetAuditLogService(),
backupCancelManager: backupCancelManager,
downloadTokenService: backups_download.GetDownloadTokenService(),
backupSchedulerService: backuping.GetBackupsScheduler(),
databases.GetDatabaseService(),
storages.GetStorageService(),
backupRepository,
notifiers.GetNotifierService(),
notifiers.GetNotifierService(),
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
usecases.GetCreateBackupUsecase(),
logger.GetLogger(),
[]backups_core.BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
taskCancelManager,
backups_download.GetDownloadTokenService(),
backuping.GetBackupsScheduler(),
backuping.GetBackupCleaner(),
}
var backupController = &BackupController{
@@ -52,11 +56,26 @@ func GetBackupController() *BackupController {
return backupController
}
func SetupDependencies() {
backups_config.
GetBackupConfigService().
SetDatabaseStorageChangeListener(backupService)
var (
setupOnce sync.Once
isSetup atomic.Bool
)
databases.GetDatabaseService().AddDbRemoveListener(backupService)
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
func SetupDependencies() {
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
backups_config.
GetBackupConfigService().
SetDatabaseStorageChangeListener(backupService)
databases.GetDatabaseService().AddDbRemoveListener(backupService)
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}
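Because removed and added lines are interleaved in this view, here is how the new SetupDependencies reads once the once-guard is applied (reconstructed from the added lines above; a best-effort reading, not authoritative):
var (
    setupOnce sync.Once
    isSetup   atomic.Bool
)
func SetupDependencies() {
    wasAlreadySetup := isSetup.Load()
    setupOnce.Do(func() {
        backups_config.
            GetBackupConfigService().
            SetDatabaseStorageChangeListener(backupService)
        databases.GetDatabaseService().AddDbRemoveListener(backupService)
        databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
        isSetup.Store(true)
    })
    if wasAlreadySetup {
        logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
    }
}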

View File

@@ -2,33 +2,49 @@ package backups_download
import (
"context"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
)
type DownloadTokenBackgroundService struct {
downloadTokenService *DownloadTokenService
logger *slog.Logger
runOnce sync.Once
hasRun atomic.Bool
}
func (s *DownloadTokenBackgroundService) Run(ctx context.Context) {
s.logger.Info("Starting download token cleanup background service")
wasAlreadyRun := s.hasRun.Load()
if ctx.Err() != nil {
return
}
s.runOnce.Do(func() {
s.hasRun.Store(true)
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
s.logger.Info("Starting download token cleanup background service")
for {
select {
case <-ctx.Done():
if ctx.Err() != nil {
return
case <-ticker.C:
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
s.logger.Error("Failed to clean expired download tokens", "error", err)
}
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
s.logger.Error("Failed to clean expired download tokens", "error", err)
}
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
}
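The download-token cleanup service gets the same run-once guard, but it panics instead of logging a warning. Reading only the added lines, the new Run method comes out roughly as follows (a reconstruction of the diff above; the exact ordering of a line or two may differ):
func (s *DownloadTokenBackgroundService) Run(ctx context.Context) {
    wasAlreadyRun := s.hasRun.Load()
    if ctx.Err() != nil {
        return
    }
    s.runOnce.Do(func() {
        s.hasRun.Store(true)
        s.logger.Info("Starting download token cleanup background service")
        ticker := time.NewTicker(1 * time.Minute)
        defer ticker.Stop()
        for {
            select {
            case <-ctx.Done():
                return
            case <-ticker.C:
                if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
                    s.logger.Error("Failed to clean expired download tokens", "error", err)
                }
            }
        }
    })
    if wasAlreadyRun {
        panic(fmt.Sprintf("%T.Run() called multiple times", s))
    }
}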

View File

@@ -1,6 +1,9 @@
package backups_download
import (
"sync"
"sync/atomic"
"databasus-backend/internal/config"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/logger"
@@ -30,8 +33,10 @@ func init() {
}
downloadTokenBackgroundService = &DownloadTokenBackgroundService{
downloadTokenService,
logger.GetLogger(),
downloadTokenService: downloadTokenService,
logger: logger.GetLogger(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}

View File

@@ -9,7 +9,6 @@ import (
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/backuping"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
"databasus-backend/internal/features/backups/backups/encryption"
@@ -18,6 +17,7 @@ import (
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
users_models "databasus-backend/internal/features/users/models"
workspaces_services "databasus-backend/internal/features/workspaces/services"
util_encryption "databasus-backend/internal/util/encryption"
@@ -43,9 +43,10 @@ type BackupService struct {
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
backupCancelManager *backups_cancellation.BackupCancelManager
taskCancelManager *task_cancellation.TaskCancelManager
downloadTokenService *backups_download.DownloadTokenService
backupSchedulerService *backuping.BackupsScheduler
backupCleaner *backuping.BackupCleaner
}
func (s *BackupService) AddBackupRemoveListener(listener backups_core.BackupRemoveListener) {
@@ -189,7 +190,7 @@ func (s *BackupService) DeleteBackup(
database.WorkspaceID,
)
return s.deleteBackup(backup)
return s.backupCleaner.DeleteBackup(backup)
}
func (s *BackupService) GetBackup(backupID uuid.UUID) (*backups_core.Backup, error) {
@@ -226,7 +227,7 @@ func (s *BackupService) CancelBackup(
return errors.New("backup is not in progress")
}
if err := s.backupCancelManager.CancelBackup(backupID); err != nil {
if err := s.taskCancelManager.CancelTask(backupID); err != nil {
return err
}
@@ -292,29 +293,6 @@ func (s *BackupService) GetBackupFile(
return reader, backup, database, nil
}
func (s *BackupService) deleteBackup(backup *backups_core.Backup) error {
for _, listener := range s.backupRemoveListeners {
if err := listener.OnBeforeBackupRemove(backup); err != nil {
return err
}
}
storage, err := s.storageService.GetStorageByID(backup.StorageID)
if err != nil {
return err
}
err = storage.DeleteFile(s.fieldEncryptor, backup.ID)
if err != nil {
// we do not return an error here because cleanup is sometimes performed
// before an unavailable storage is removed or changed, so we should
// proceed even if the file deletion fails
s.logger.Error("Failed to delete backup file", "error", err)
}
return s.backupRepository.DeleteByID(backup.ID)
}
func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
dbBackupsInProgress, err := s.backupRepository.FindByDatabaseIdAndStatus(
databaseID,
@@ -336,7 +314,7 @@ func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
}
for _, dbBackup := range dbBackups {
err := s.deleteBackup(dbBackup)
err := s.backupCleaner.DeleteBackup(dbBackup)
if err != nil {
return err
}

View File

@@ -75,3 +75,23 @@ func WaitForBackupCompletion(
t.Logf("WaitForBackupCompletion: timeout waiting for backup to complete")
}
// CreateTestBackup creates a simple test backup record for testing purposes
func CreateTestBackup(databaseID, storageID uuid.UUID) *backups_core.Backup {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: databaseID,
StorageID: storageID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10.5,
BackupDurationMs: 1000,
CreatedAt: time.Now().UTC(),
}
repo := &backups_core.BackupRepository{}
if err := repo.Save(backup); err != nil {
panic(err)
}
return backup
}

View File

@@ -1462,7 +1462,7 @@ func createTestDatabaseViaAPI(
Type: databases.DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",

View File

@@ -1,10 +1,14 @@
package backups_config
import (
"sync"
"sync/atomic"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/logger"
)
var backupConfigRepository = &BackupConfigRepository{}
@@ -28,6 +32,21 @@ func GetBackupConfigService() *BackupConfigService {
return backupConfigService
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService)
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}

View File

@@ -31,6 +31,11 @@ type BackupConfig struct {
MaxFailedTriesCount int `json:"maxFailedTriesCount" gorm:"column:max_failed_tries_count;type:int;not null"`
Encryption BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
// MaxBackupSizeMB limits individual backup size. 0 = unlimited.
MaxBackupSizeMB int64 `json:"maxBackupSizeMb" gorm:"column:max_backup_size_mb;type:int;not null"`
// MaxBackupsTotalSizeMB limits total size of all backups. 0 = unlimited.
MaxBackupsTotalSizeMB int64 `json:"maxBackupsTotalSizeMb" gorm:"column:max_backups_total_size_mb;type:int;not null"`
}
func (h *BackupConfig) TableName() string {
@@ -89,20 +94,30 @@ func (b *BackupConfig) Validate() error {
return errors.New("encryption must be NONE or ENCRYPTED")
}
if b.MaxBackupSizeMB < 0 {
return errors.New("max backup size must be non-negative")
}
if b.MaxBackupsTotalSizeMB < 0 {
return errors.New("max backups total size must be non-negative")
}
return nil
}
func (b *BackupConfig) Copy(newDatabaseID uuid.UUID) *BackupConfig {
return &BackupConfig{
DatabaseID: newDatabaseID,
IsBackupsEnabled: b.IsBackupsEnabled,
StorePeriod: b.StorePeriod,
BackupIntervalID: uuid.Nil,
BackupInterval: b.BackupInterval.Copy(),
StorageID: b.StorageID,
SendNotificationsOn: b.SendNotificationsOn,
IsRetryIfFailed: b.IsRetryIfFailed,
MaxFailedTriesCount: b.MaxFailedTriesCount,
Encryption: b.Encryption,
DatabaseID: newDatabaseID,
IsBackupsEnabled: b.IsBackupsEnabled,
StorePeriod: b.StorePeriod,
BackupIntervalID: uuid.Nil,
BackupInterval: b.BackupInterval.Copy(),
StorageID: b.StorageID,
SendNotificationsOn: b.SendNotificationsOn,
IsRetryIfFailed: b.IsRetryIfFailed,
MaxFailedTriesCount: b.MaxFailedTriesCount,
Encryption: b.Encryption,
MaxBackupSizeMB: b.MaxBackupSizeMB,
MaxBackupsTotalSizeMB: b.MaxBackupsTotalSizeMB,
}
}

View File

@@ -44,7 +44,7 @@ func getTestPostgresConfig() *postgresql.PostgresqlDatabase {
testDbName := "testdb"
return &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",
@@ -67,7 +67,7 @@ func getTestMariadbConfig() *mariadb.MariadbDatabase {
testDbName := "testdb"
return &mariadb.MariadbDatabase{
Version: tools.MariadbVersion1011,
Host: "localhost",
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",
@@ -88,7 +88,7 @@ func getTestMongodbConfig() *mongodb.MongodbDatabase {
return &mongodb.MongodbDatabase{
Version: tools.MongodbVersion7,
Host: "localhost",
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "root",
Password: "rootpassword",
@@ -829,7 +829,7 @@ func createTestDatabaseViaAPI(
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",

View File

@@ -515,9 +515,13 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
hasProcess := false
hasAllPrivileges := false
// Escape underscores to match MariaDB's grant output format
// MariaDB escapes _ as \_ in SHOW GRANTS output
// Pattern matches either literal _ or escaped \_
escapedDbName := strings.ReplaceAll(regexp.QuoteMeta(database), "_", `(_|\\_)`)
dbPatternStr := fmt.Sprintf(
`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
regexp.QuoteMeta(database),
escapedDbName,
)
dbPattern := regexp.MustCompile(dbPatternStr)
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)
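To make the escaping concrete, here is a small standalone Go sketch showing that the resulting pattern matches both forms that can appear in SHOW GRANTS output (the grant strings are invented samples, not real server output):
package main
import (
    "fmt"
    "regexp"
    "strings"
)
func main() {
    database := "test_all_db"
    // Same construction as above: each "_" in the DB name may appear in
    // SHOW GRANTS output either literally or escaped as "\_".
    escapedDbName := strings.ReplaceAll(regexp.QuoteMeta(database), "_", `(_|\\_)`)
    dbPattern := regexp.MustCompile(fmt.Sprintf(
        `(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
        escapedDbName,
    ))
    grants := []string{
        "GRANT SELECT ON `test_all_db`.* TO 'reader'@'%'",              // literal underscores
        "GRANT ALL PRIVILEGES ON `test\\_all\\_db`.* TO 'allpriv'@'%'", // escaped underscores
    }
    for _, grant := range grants {
        fmt.Println(dbPattern.MatchString(grant)) // both print true
    }
}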

View File

@@ -694,6 +694,115 @@ func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
assert.NoError(t, err)
}
func Test_TestConnection_DatabaseWithUnderscoresAndAllPrivileges_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MariadbVersion
port string
}{
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMariadbContainer(t, tc.port, tc.version)
defer container.DB.Close()
underscoreDbName := "test_all_db"
_, err := container.DB.Exec(
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
)
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
)
}()
underscoreDSN := fmt.Sprintf(
"%s:%s@tcp(%s:%d)/%s?parseTime=true",
container.Username,
container.Password,
container.Host,
container.Port,
underscoreDbName,
)
underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
assert.NoError(t, err)
defer underscoreDB.Close()
_, err = underscoreDB.Exec(`
CREATE TABLE all_priv_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)
`)
assert.NoError(t, err)
_, err = underscoreDB.Exec(`INSERT INTO all_priv_test (data) VALUES ('test1')`)
assert.NoError(t, err)
allPrivUsername := fmt.Sprintf("allpriv%s", uuid.New().String()[:8])
allPrivPassword := "allprivpass123"
_, err = underscoreDB.Exec(fmt.Sprintf(
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
allPrivUsername,
allPrivPassword,
))
assert.NoError(t, err)
_, err = underscoreDB.Exec(fmt.Sprintf(
"GRANT ALL PRIVILEGES ON `%s`.* TO '%s'@'%%'",
underscoreDbName,
allPrivUsername,
))
assert.NoError(t, err)
_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
assert.NoError(t, err)
defer dropUserSafe(underscoreDB, allPrivUsername)
mariadbModel := &MariadbDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Username: allPrivUsername,
Password: allPrivPassword,
Database: &underscoreDbName,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mariadbModel.TestConnection(logger, nil, uuid.New())
assert.NoError(t, err)
assert.NotEmpty(t, mariadbModel.Privileges)
assert.Contains(t, mariadbModel.Privileges, "SELECT")
assert.Contains(t, mariadbModel.Privileges, "SHOW VIEW")
})
}
}
type MariadbContainer struct {
Host string
Port int
@@ -714,7 +823,7 @@ func connectToMariadbContainer(
}
dbName := "testdb"
host := "127.0.0.1"
host := config.GetEnv().TestLocalhost
username := "root"
password := "rootpassword"

View File

@@ -9,6 +9,7 @@ import (
"strconv"
"strings"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
@@ -397,7 +398,7 @@ func connectToMongodbContainer(
}
dbName := "testdb"
host := "127.0.0.1"
host := config.GetEnv().TestLocalhost
username := "root"
password := "rootpassword"
authDatabase := "admin"
@@ -406,11 +407,18 @@ func connectToMongodbContainer(
assert.NoError(t, err)
uri := fmt.Sprintf(
"mongodb://%s:%s@%s:%d/%s?authSource=%s",
username, password, host, portInt, dbName, authDatabase,
"mongodb://%s:%s@%s:%d/%s?authSource=%s&serverSelectionTimeoutMS=5000&connectTimeoutMS=5000",
username,
password,
host,
portInt,
dbName,
authDatabase,
)
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
clientOptions := options.Client().ApplyURI(uri)
client, err := mongo.Connect(ctx, clientOptions)
if err != nil {

View File

@@ -400,6 +400,7 @@ func HasPrivilege(privileges, priv string) bool {
func (m *MysqlDatabase) buildDSN(password string, database string) string {
tlsConfig := "false"
allowCleartext := ""
if m.IsHttps {
err := mysql.RegisterTLSConfig("mysql-skip-verify", &tls.Config{
@@ -411,16 +412,18 @@ func (m *MysqlDatabase) buildDSN(password string, database string) string {
}
tlsConfig = "mysql-skip-verify"
allowCleartext = "&allowCleartextPasswords=1"
}
return fmt.Sprintf(
"%s:%s@tcp(%s:%d)/%s?parseTime=true&timeout=15s&tls=%s&charset=utf8mb4",
"%s:%s@tcp(%s:%d)/%s?parseTime=true&timeout=15s&tls=%s&charset=utf8mb4%s",
m.Username,
password,
m.Host,
m.Port,
database,
tlsConfig,
allowCleartext,
)
}
@@ -486,9 +489,13 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
hasProcess := false
hasAllPrivileges := false
// Escape underscores to match MySQL's grant output format
// MySQL escapes _ as \_ in SHOW GRANTS output
// Pattern matches either literal _ or escaped \_
escapedDbName := strings.ReplaceAll(regexp.QuoteMeta(database), "_", `(_|\\_)`)
dbPatternStr := fmt.Sprintf(
`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
regexp.QuoteMeta(database),
escapedDbName,
)
dbPattern := regexp.MustCompile(dbPatternStr)
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)

View File

@@ -674,6 +674,112 @@ func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
assert.NoError(t, err)
}
func Test_TestConnection_DatabaseWithUnderscoresAndAllPrivileges_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MysqlVersion
port string
}{
{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
{"MySQL 9", tools.MysqlVersion9, env.TestMysql90Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMysqlContainer(t, tc.port, tc.version)
defer container.DB.Close()
underscoreDbName := "test_all_db"
_, err := container.DB.Exec(
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
)
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
)
}()
underscoreDSN := fmt.Sprintf(
"%s:%s@tcp(%s:%d)/%s?parseTime=true",
container.Username,
container.Password,
container.Host,
container.Port,
underscoreDbName,
)
underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
assert.NoError(t, err)
defer underscoreDB.Close()
_, err = underscoreDB.Exec(`
CREATE TABLE all_priv_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)
`)
assert.NoError(t, err)
_, err = underscoreDB.Exec(`INSERT INTO all_priv_test (data) VALUES ('test1')`)
assert.NoError(t, err)
allPrivUsername := fmt.Sprintf("allpriv_%s", uuid.New().String()[:8])
allPrivPassword := "allprivpass123"
_, err = underscoreDB.Exec(fmt.Sprintf(
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
allPrivUsername,
allPrivPassword,
))
assert.NoError(t, err)
_, err = underscoreDB.Exec(fmt.Sprintf(
"GRANT ALL PRIVILEGES ON `%s`.* TO '%s'@'%%'",
underscoreDbName,
allPrivUsername,
))
assert.NoError(t, err)
_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(
fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", allPrivUsername),
)
}()
mysqlModel := &MysqlDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Username: allPrivUsername,
Password: allPrivPassword,
Database: &underscoreDbName,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mysqlModel.TestConnection(logger, nil, uuid.New())
assert.NoError(t, err)
assert.NotEmpty(t, mysqlModel.Privileges)
assert.Contains(t, mysqlModel.Privileges, "SELECT")
assert.Contains(t, mysqlModel.Privileges, "SHOW VIEW")
})
}
}
type MysqlContainer struct {
Host string
Port int
@@ -694,7 +800,7 @@ func connectToMysqlContainer(
}
dbName := "testdb"
host := "127.0.0.1"
host := config.GetEnv().TestLocalhost
username := "root"
password := "rootpassword"

View File

@@ -13,6 +13,7 @@ import (
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"gorm.io/gorm"
)
@@ -85,6 +86,42 @@ func (p *PostgresqlDatabase) Validate() error {
return errors.New("cpu count must be greater than 0")
}
// Prevent Databasus from backing up itself
// Databasus runs an internal PostgreSQL instance that should not be backed up through the UI
// because it would expose internal metadata to non-system administrators.
// To properly back up Databasus, see: https://databasus.com/faq#backup-databasus
if p.Database != nil && *p.Database != "" {
localhostHosts := []string{
"localhost",
"127.0.0.1",
"172.17.0.1",
"host.docker.internal",
"::1", // IPv6 loopback (equivalent to 127.0.0.1)
"::", // IPv6 all interfaces (equivalent to 0.0.0.0)
"0.0.0.0", // IPv4 all interfaces
}
isLocalhost := false
for _, host := range localhostHosts {
if strings.EqualFold(p.Host, host) {
isLocalhost = true
break
}
}
// Also check if the host is in the entire 127.0.0.0/8 loopback range
if strings.HasPrefix(p.Host, "127.") {
isLocalhost = true
}
if isLocalhost && strings.EqualFold(*p.Database, "databasus") {
return errors.New(
"backing up Databasus internal database is not allowed. To backup Databasus itself, see https://databasus.com/faq#backup-databasus",
)
}
}
return nil
}
@@ -358,10 +395,13 @@ func (p *PostgresqlDatabase) IsUserReadOnly(
//
// This method performs the following operations atomically in a single transaction:
// 1. Creates a PostgreSQL user with a UUID-based password
// 2. Grants CONNECT privilege on the database
// 3. Grants USAGE on all non-system schemas
// 4. Grants SELECT on all existing tables and sequences
// 5. Sets default privileges for future tables and sequences
// 2. Revokes CREATE privilege on public schema from PUBLIC role
// 3. Grants CONNECT privilege on the database
// 4. Discovers all user-created schemas
// 5. Grants USAGE on all non-system schemas
// 6. Grants SELECT on all existing tables and sequences
// 7. Sets default privileges for future tables and sequences
// 8. Verifies user creation before committing
//
// Security features:
// - Username format: "databasus-{8-char-uuid}" for uniqueness
@@ -451,33 +491,56 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
return "", "", fmt.Errorf("failed to create user: %w", err)
}
// Step 1.5: Revoke CREATE privilege from PUBLIC role on public schema
// Step 2: Check if public schema exists and revoke CREATE privilege if it does
// This is necessary because all PostgreSQL users inherit CREATE privilege on the
// public schema through the PUBLIC role. This is a one-time operation that affects
// the entire database, making it more secure by default.
// Note: This only affects the public schema; other schemas are unaffected.
_, err = tx.Exec(ctx, `REVOKE CREATE ON SCHEMA public FROM PUBLIC`)
if err != nil {
logger.Error("Failed to revoke CREATE on public from PUBLIC", "error", err)
if !strings.Contains(err.Error(), "schema \"public\" does not exist") &&
!strings.Contains(err.Error(), "permission denied") {
return "", "", fmt.Errorf("failed to revoke CREATE from PUBLIC: %w", err)
}
}
// Now revoke from the specific user as well (belt and suspenders)
_, err = tx.Exec(ctx, fmt.Sprintf(`REVOKE CREATE ON SCHEMA public FROM "%s"`, baseUsername))
if err != nil {
logger.Error(
"Failed to revoke CREATE on public schema from user",
"error",
err,
"username",
baseUsername,
var publicSchemaExists bool
err = tx.QueryRow(ctx, `
SELECT EXISTS(
SELECT 1 FROM information_schema.schemata
WHERE schema_name = 'public'
)
`).Scan(&publicSchemaExists)
if err != nil {
return "", "", fmt.Errorf("failed to check if public schema exists: %w", err)
}
// Step 2: Grant database connection privilege and revoke TEMP
if publicSchemaExists {
// Revoke CREATE from PUBLIC role (affects all users)
_, err = tx.Exec(ctx, `REVOKE CREATE ON SCHEMA public FROM PUBLIC`)
if err != nil {
if strings.Contains(err.Error(), "permission denied") {
logger.Warn(
"Failed to revoke CREATE on public from PUBLIC (permission denied)",
"error",
err,
)
} else {
return "", "", fmt.Errorf("failed to revoke CREATE from PUBLIC on existing public schema: %w", err)
}
}
// Now revoke from the specific user as well (belt and suspenders)
_, err = tx.Exec(
ctx,
fmt.Sprintf(`REVOKE CREATE ON SCHEMA public FROM "%s"`, baseUsername),
)
if err != nil {
logger.Warn(
"Failed to revoke CREATE on public schema from user",
"error",
err,
"username",
baseUsername,
)
}
} else {
logger.Info("Public schema does not exist, skipping CREATE privilege revocation")
}
// Step 3: Grant database connection privilege and revoke TEMP
_, err = tx.Exec(
ctx,
fmt.Sprintf(`GRANT CONNECT ON DATABASE "%s" TO "%s"`, *p.Database, baseUsername),
@@ -501,7 +564,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
logger.Warn("Failed to revoke TEMP privilege", "error", err, "username", baseUsername)
}
// Step 3: Discover all user-created schemas
// Step 4: Discover all user-created schemas
rows, err := tx.Query(ctx, `
SELECT schema_name
FROM information_schema.schemata
@@ -526,7 +589,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
return "", "", fmt.Errorf("error iterating schemas: %w", err)
}
// Step 4: Grant USAGE on each schema and explicitly prevent CREATE
// Step 5: Grant USAGE on each schema and explicitly prevent CREATE
for _, schema := range schemas {
// Revoke CREATE specifically (handles inheritance from PUBLIC role)
_, err = tx.Exec(
@@ -555,7 +618,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
}
}
// Step 5: Grant SELECT on ALL existing tables and sequences
// Step 6: Grant SELECT on ALL existing tables and sequences
grantSelectSQL := fmt.Sprintf(`
DO $$
DECLARE
@@ -577,7 +640,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
return "", "", fmt.Errorf("failed to grant select on tables: %w", err)
}
// Step 6: Set default privileges for FUTURE tables and sequences
// Step 7: Set default privileges for FUTURE tables and sequences
defaultPrivilegesSQL := fmt.Sprintf(`
DO $$
DECLARE
@@ -599,7 +662,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
return "", "", fmt.Errorf("failed to set default privileges: %w", err)
}
// Step 7: Verify user creation before committing
// Step 8: Verify user creation before committing
var verifyUsername string
err = tx.QueryRow(ctx, fmt.Sprintf(`SELECT rolname FROM pg_roles WHERE rolname = '%s'`, baseUsername)).
Scan(&verifyUsername)
@@ -815,7 +878,15 @@ func checkBackupPermissions(
}
if err != nil {
return fmt.Errorf("cannot check SELECT privileges: %w", err)
// If the user doesn't have USAGE on the schema, has_table_privilege will fail
// with "permission denied for schema". This means they definitely don't have
// SELECT privileges, so treat this as missing permissions rather than an error.
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) && pgErr.Code == "42501" { // insufficient_privilege
selectableTableCount = 0
} else {
return fmt.Errorf("cannot check SELECT privileges: %w", err)
}
}
if selectableTableCount == 0 {
missingPrivileges = append(missingPrivileges, "SELECT on tables")

View File

@@ -599,6 +599,10 @@ func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) {
}
func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
if config.GetEnv().IsSkipExternalResourcesTests {
t.Skip("Skipping Supabase test: IS_SKIP_EXTERNAL_RESOURCES_TESTS is true")
}
env := config.GetEnv()
if env.TestSupabaseHost == "" {
@@ -705,6 +709,607 @@ func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
assert.Contains(t, err.Error(), "permission denied")
}
func Test_CreateReadOnlyUser_WithPublicSchema_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version string
port string
}{
{"PostgreSQL 12", "12", env.TestPostgres12Port},
{"PostgreSQL 13", "13", env.TestPostgres13Port},
{"PostgreSQL 14", "14", env.TestPostgres14Port},
{"PostgreSQL 15", "15", env.TestPostgres15Port},
{"PostgreSQL 16", "16", env.TestPostgres16Port},
{"PostgreSQL 17", "17", env.TestPostgres17Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToPostgresContainer(t, tc.port)
defer container.DB.Close()
_, err := container.DB.Exec(`
DROP TABLE IF EXISTS public_schema_test CASCADE;
CREATE TABLE public_schema_test (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO public_schema_test (data) VALUES ('test1'), ('test2');
`)
assert.NoError(t, err)
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.NotEmpty(t, username)
assert.NotEmpty(t, password)
assert.True(t, strings.HasPrefix(username, "databasus-"))
readOnlyModel := &PostgresqlDatabase{
Version: pgModel.Version,
Host: pgModel.Host,
Port: pgModel.Port,
Username: username,
Password: password,
Database: pgModel.Database,
IsHttps: false,
}
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
ctx,
logger,
nil,
uuid.New(),
)
assert.NoError(t, err)
assert.True(t, isReadOnly, "User should be read-only")
assert.Empty(t, privileges, "Read-only user should have no write privileges")
readOnlyDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
username,
password,
container.Database,
)
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var count int
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM public_schema_test")
assert.NoError(t, err)
assert.Equal(t, 2, count)
_, err = readOnlyConn.Exec(
"INSERT INTO public_schema_test (data) VALUES ('should-fail')",
)
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec("CREATE TABLE public.hack_table (id INT)")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
}
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
assert.NoError(t, err)
})
}
}
func Test_CreateReadOnlyUser_WithoutPublicSchema_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version string
port string
}{
{"PostgreSQL 12", "12", env.TestPostgres12Port},
{"PostgreSQL 13", "13", env.TestPostgres13Port},
{"PostgreSQL 14", "14", env.TestPostgres14Port},
{"PostgreSQL 15", "15", env.TestPostgres15Port},
{"PostgreSQL 16", "16", env.TestPostgres16Port},
{"PostgreSQL 17", "17", env.TestPostgres17Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToPostgresContainer(t, tc.port)
defer container.DB.Close()
_, err := container.DB.Exec(`
DROP SCHEMA IF EXISTS public CASCADE;
DROP SCHEMA IF EXISTS app_schema CASCADE;
DROP SCHEMA IF EXISTS data_schema CASCADE;
CREATE SCHEMA app_schema;
CREATE SCHEMA data_schema;
CREATE TABLE app_schema.users (
id SERIAL PRIMARY KEY,
username TEXT NOT NULL
);
CREATE TABLE data_schema.records (
id SERIAL PRIMARY KEY,
info TEXT NOT NULL
);
INSERT INTO app_schema.users (username) VALUES ('user1'), ('user2');
INSERT INTO data_schema.records (info) VALUES ('record1'), ('record2');
`)
assert.NoError(t, err)
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err, "CreateReadOnlyUser should succeed without public schema")
assert.NotEmpty(t, username)
assert.NotEmpty(t, password)
assert.True(t, strings.HasPrefix(username, "databasus-"))
readOnlyModel := &PostgresqlDatabase{
Version: pgModel.Version,
Host: pgModel.Host,
Port: pgModel.Port,
Username: username,
Password: password,
Database: pgModel.Database,
IsHttps: false,
}
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
ctx,
logger,
nil,
uuid.New(),
)
assert.NoError(t, err)
assert.True(t, isReadOnly, "User should be read-only")
assert.Empty(t, privileges, "Read-only user should have no write privileges")
readOnlyDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
username,
password,
container.Database,
)
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var userCount int
err = readOnlyConn.Get(&userCount, "SELECT COUNT(*) FROM app_schema.users")
assert.NoError(t, err)
assert.Equal(t, 2, userCount)
var recordCount int
err = readOnlyConn.Get(&recordCount, "SELECT COUNT(*) FROM data_schema.records")
assert.NoError(t, err)
assert.Equal(t, 2, recordCount)
_, err = readOnlyConn.Exec(
"INSERT INTO app_schema.users (username) VALUES ('should-fail')",
)
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec("CREATE TABLE app_schema.hack_table (id INT)")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec("CREATE TABLE data_schema.hack_table (id INT)")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
}
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
assert.NoError(t, err)
_, err = container.DB.Exec(`
DROP SCHEMA IF EXISTS app_schema CASCADE;
DROP SCHEMA IF EXISTS data_schema CASCADE;
CREATE SCHEMA IF NOT EXISTS public;
`)
assert.NoError(t, err)
})
}
}
func Test_CreateReadOnlyUser_PublicSchemaExistsButNoPermissions_ReturnsError(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version string
port string
}{
{"PostgreSQL 12", "12", env.TestPostgres12Port},
{"PostgreSQL 13", "13", env.TestPostgres13Port},
{"PostgreSQL 14", "14", env.TestPostgres14Port},
{"PostgreSQL 15", "15", env.TestPostgres15Port},
{"PostgreSQL 16", "16", env.TestPostgres16Port},
{"PostgreSQL 17", "17", env.TestPostgres17Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToPostgresContainer(t, tc.port)
defer container.DB.Close()
limitedAdminUsername := fmt.Sprintf("limited_admin_%s", uuid.New().String()[:8])
limitedAdminPassword := "limited_password_123"
_, err := container.DB.Exec(`
CREATE SCHEMA IF NOT EXISTS public;
DROP TABLE IF EXISTS public.permission_test_table CASCADE;
CREATE TABLE public.permission_test_table (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO public.permission_test_table (data) VALUES ('test1');
`)
assert.NoError(t, err)
_, err = container.DB.Exec(`GRANT CREATE ON SCHEMA public TO PUBLIC`)
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN CREATEROLE`,
limitedAdminUsername,
limitedAdminPassword,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
container.Database,
limitedAdminUsername,
))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(
fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, limitedAdminUsername),
)
_, _ = container.DB.Exec(
fmt.Sprintf(`DROP USER IF EXISTS "%s"`, limitedAdminUsername),
)
_, _ = container.DB.Exec(`REVOKE CREATE ON SCHEMA public FROM PUBLIC`)
}()
limitedAdminDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
limitedAdminUsername,
limitedAdminPassword,
container.Database,
)
limitedAdminConn, err := sqlx.Connect("postgres", limitedAdminDSN)
assert.NoError(t, err)
defer limitedAdminConn.Close()
pgModel := &PostgresqlDatabase{
Version: tools.GetPostgresqlVersionEnum(tc.version),
Host: container.Host,
Port: container.Port,
Username: limitedAdminUsername,
Password: limitedAdminPassword,
Database: &container.Database,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.Error(
t,
err,
"CreateReadOnlyUser should fail when admin lacks permissions to secure public schema",
)
if err != nil {
errorMsg := err.Error()
hasExpectedError := strings.Contains(
errorMsg,
"failed to revoke CREATE from PUBLIC on existing public schema",
) ||
strings.Contains(errorMsg, "permission denied for schema public") ||
strings.Contains(errorMsg, "failed to grant")
assert.True(
t,
hasExpectedError,
"Error should indicate permission issues with public schema, got: %s",
errorMsg,
)
}
assert.Empty(t, username)
assert.Empty(t, password)
})
}
}
func Test_Validate_WhenLocalhostAndDatabasus_ReturnsError(t *testing.T) {
testCases := []struct {
name string
host string
username string
database string
}{
{
name: "localhost with databasus db",
host: "localhost",
username: "postgres",
database: "databasus",
},
{
name: "127.0.0.1 with databasus db",
host: "127.0.0.1",
username: "postgres",
database: "databasus",
},
{
name: "172.17.0.1 with databasus db",
host: "172.17.0.1",
username: "postgres",
database: "databasus",
},
{
name: "host.docker.internal with databasus db",
host: "host.docker.internal",
username: "postgres",
database: "databasus",
},
{
name: "LOCALHOST (uppercase) with DATABASUS db",
host: "LOCALHOST",
username: "POSTGRES",
database: "DATABASUS",
},
{
name: "LocalHost (mixed case) with DataBasus db",
host: "LocalHost",
username: "anyuser",
database: "DataBasus",
},
{
name: "localhost with databasus and any username",
host: "localhost",
username: "myuser",
database: "databasus",
},
{
name: "::1 (IPv6 loopback) with databasus db",
host: "::1",
username: "postgres",
database: "databasus",
},
{
name: ":: (IPv6 all interfaces) with databasus db",
host: "::",
username: "postgres",
database: "databasus",
},
{
name: "::1 (uppercase) with DATABASUS db",
host: "::1",
username: "POSTGRES",
database: "DATABASUS",
},
{
name: "0.0.0.0 (all IPv4 interfaces) with databasus db",
host: "0.0.0.0",
username: "postgres",
database: "databasus",
},
{
name: "127.0.0.2 (loopback range) with databasus db",
host: "127.0.0.2",
username: "postgres",
database: "databasus",
},
{
name: "127.255.255.255 (end of loopback range) with databasus db",
host: "127.255.255.255",
username: "postgres",
database: "databasus",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
pgModel := &PostgresqlDatabase{
Host: tc.host,
Port: 5437,
Username: tc.username,
Password: "somepassword",
Database: &tc.database,
CpuCount: 1,
}
err := pgModel.Validate()
assert.Error(t, err)
assert.Contains(t, err.Error(), "backing up Databasus internal database is not allowed")
assert.Contains(t, err.Error(), "https://databasus.com/faq#backup-databasus")
})
}
}
func Test_Validate_WhenNotLocalhostOrNotDatabasus_ValidatesSuccessfully(t *testing.T) {
testCases := []struct {
name string
host string
username string
database string
}{
{
name: "different host (remote server) with databasus db",
host: "192.168.1.100",
username: "postgres",
database: "databasus",
},
{
name: "different database name on localhost",
host: "localhost",
username: "postgres",
database: "myapp",
},
{
name: "all different",
host: "db.example.com",
username: "appuser",
database: "production",
},
{
name: "localhost with postgres database",
host: "localhost",
username: "postgres",
database: "postgres",
},
{
name: "remote host with databasus db name (allowed)",
host: "db.example.com",
username: "postgres",
database: "databasus",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
pgModel := &PostgresqlDatabase{
Host: tc.host,
Port: 5432,
Username: tc.username,
Password: "somepassword",
Database: &tc.database,
CpuCount: 1,
}
err := pgModel.Validate()
assert.NoError(t, err)
})
}
}
func Test_Validate_WhenDatabaseIsNil_ValidatesSuccessfully(t *testing.T) {
pgModel := &PostgresqlDatabase{
Host: "localhost",
Port: 5437,
Username: "postgres",
Password: "somepassword",
Database: nil,
CpuCount: 1,
}
err := pgModel.Validate()
assert.NoError(t, err)
}
func Test_Validate_WhenDatabaseIsEmpty_ValidatesSuccessfully(t *testing.T) {
emptyDb := ""
pgModel := &PostgresqlDatabase{
Host: "localhost",
Port: 5437,
Username: "postgres",
Password: "somepassword",
Database: &emptyDb,
CpuCount: 1,
}
err := pgModel.Validate()
assert.NoError(t, err)
}
func Test_Validate_WhenRequiredFieldsMissing_ReturnsError(t *testing.T) {
testCases := []struct {
name string
model *PostgresqlDatabase
expectedError string
}{
{
name: "missing host",
model: &PostgresqlDatabase{
Host: "",
Port: 5432,
Username: "user",
Password: "pass",
CpuCount: 1,
},
expectedError: "host is required",
},
{
name: "missing port",
model: &PostgresqlDatabase{
Host: "localhost",
Port: 0,
Username: "user",
Password: "pass",
CpuCount: 1,
},
expectedError: "port is required",
},
{
name: "missing username",
model: &PostgresqlDatabase{
Host: "localhost",
Port: 5432,
Username: "",
Password: "pass",
CpuCount: 1,
},
expectedError: "username is required",
},
{
name: "missing password",
model: &PostgresqlDatabase{
Host: "localhost",
Port: 5432,
Username: "user",
Password: "",
CpuCount: 1,
},
expectedError: "password is required",
},
{
name: "invalid cpu count",
model: &PostgresqlDatabase{
Host: "localhost",
Port: 5432,
Username: "user",
Password: "pass",
CpuCount: 0,
},
expectedError: "cpu count must be greater than 0",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.model.Validate()
assert.Error(t, err)
assert.Contains(t, err.Error(), tc.expectedError)
})
}
}
type PostgresContainer struct {
Host string
Port int
@@ -718,7 +1323,7 @@ func connectToPostgresContainer(t *testing.T, port string) *PostgresContainer {
dbName := "testdb"
password := "testpassword"
username := "testuser"
host := "localhost"
host := config.GetEnv().TestLocalhost
portInt, err := strconv.Atoi(port)
assert.NoError(t, err)


@@ -1,6 +1,9 @@
package databases
import (
"sync"
"sync/atomic"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/notifiers"
users_services "databasus-backend/internal/features/users/services"
@@ -37,7 +40,22 @@ func GetDatabaseController() *DatabaseController {
return databaseController
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService)
notifiers.GetNotifierService().SetNotifierDatabaseCounter(databaseService)
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService)
notifiers.GetNotifierService().SetNotifierDatabaseCounter(databaseService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}


@@ -25,7 +25,7 @@ func GetTestPostgresConfig() *postgresql.PostgresqlDatabase {
testDbName := "testdb"
return &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",
@@ -48,7 +48,7 @@ func GetTestMariadbConfig() *mariadb.MariadbDatabase {
testDbName := "testdb"
return &mariadb.MariadbDatabase{
Version: tools.MariadbVersion1011,
Host: "localhost",
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",
@@ -69,7 +69,7 @@ func GetTestMongodbConfig() *mongodb.MongodbDatabase {
return &mongodb.MongodbDatabase{
Version: tools.MongodbVersion7,
Host: "localhost",
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "root",
Password: "rootpassword",


@@ -2,30 +2,47 @@ package healthcheck_attempt
import (
"context"
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
)
type HealthcheckAttemptBackgroundService struct {
healthcheckConfigService *healthcheck_config.HealthcheckConfigService
checkDatabaseHealthUseCase *CheckDatabaseHealthUseCase
logger *slog.Logger
runOnce sync.Once
hasRun atomic.Bool
}
func (s *HealthcheckAttemptBackgroundService) Run(ctx context.Context) {
// first healthcheck immediately
s.checkDatabases()
wasAlreadyRun := s.hasRun.Load()
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.checkDatabases()
s.runOnce.Do(func() {
s.hasRun.Store(true)
// first healthcheck immediately
s.checkDatabases()
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.checkDatabases()
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
}


@@ -1,6 +1,9 @@
package healthcheck_attempt
import (
"sync"
"sync/atomic"
"databasus-backend/internal/features/databases"
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
"databasus-backend/internal/features/notifiers"
@@ -22,9 +25,11 @@ var checkDatabaseHealthUseCase = &CheckDatabaseHealthUseCase{
}
var healthcheckAttemptBackgroundService = &HealthcheckAttemptBackgroundService{
healthcheck_config.GetHealthcheckConfigService(),
checkDatabaseHealthUseCase,
logger.GetLogger(),
healthcheckConfigService: healthcheck_config.GetHealthcheckConfigService(),
checkDatabaseHealthUseCase: checkDatabaseHealthUseCase,
logger: logger.GetLogger(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
var healthcheckAttemptController = &HealthcheckAttemptController{
healthcheckAttemptService,


@@ -1,6 +1,9 @@
package healthcheck_config
import (
"sync"
"sync/atomic"
"databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/databases"
workspaces_services "databasus-backend/internal/features/workspaces/services"
@@ -27,8 +30,23 @@ func GetHealthcheckConfigController() *HealthcheckConfigController {
return healthcheckConfigController
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
databases.
GetDatabaseService().
AddDbCreationListener(healthcheckConfigService)
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
databases.
GetDatabaseService().
AddDbCreationListener(healthcheckConfigService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}


@@ -1,6 +1,9 @@
package notifiers
import (
"sync"
"sync/atomic"
audit_logs "databasus-backend/internal/features/audit_logs"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/encryption"
@@ -32,6 +35,22 @@ func GetNotifierService() *NotifierService {
func GetNotifierRepository() *NotifierRepository {
return notifierRepository
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService)
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}


@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"log/slog"
"mime"
"net"
"net/smtp"
"time"
@@ -115,16 +116,34 @@ func (e *EmailNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor
return nil
}
// encodeRFC2047 encodes a string using RFC 2047 MIME encoding for email headers
// This ensures compatibility with SMTP servers that don't support SMTPUTF8
func encodeRFC2047(s string) string {
// mime.QEncoding handles UTF-8 → =?UTF-8?Q?...?= encoding
// This allows non-ASCII characters (emojis, accents, etc.) in email headers
// while maintaining compatibility with all SMTP servers
return mime.QEncoding.Encode("UTF-8", s)
}
func (e *EmailNotifier) buildEmailContent(heading, message, from string) []byte {
subject := fmt.Sprintf("Subject: %s\r\n", heading)
mime := fmt.Sprintf(
// Encode Subject header using RFC 2047 to avoid SMTPUTF8 requirement
// This ensures compatibility with SMTP servers that don't support SMTPUTF8
encodedSubject := encodeRFC2047(heading)
subject := fmt.Sprintf("Subject: %s\r\n", encodedSubject)
mimeHeaders := fmt.Sprintf(
"MIME-version: 1.0;\nContent-Type: %s; charset=\"%s\";\n\n",
MIMETypeHTML,
MIMECharsetUTF8,
)
fromHeader := fmt.Sprintf("From: %s\r\n", from)
// Encode From header display name if it contains non-ASCII
encodedFrom := encodeRFC2047(from)
fromHeader := fmt.Sprintf("From: %s\r\n", encodedFrom)
toHeader := fmt.Sprintf("To: %s\r\n", e.TargetEmail)
return []byte(fromHeader + toHeader + subject + mime + message)
return []byte(fromHeader + toHeader + subject + mimeHeaders + message)
}
func (e *EmailNotifier) sendImplicitTLS(

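For context, the RFC 2047 encoding relied on above comes straight from Go's standard mime package. A minimal standalone sketch (the sample strings are illustrative, not from the repository) of what mime.QEncoding.Encode produces for ASCII and non-ASCII header values, and how a MIME-aware client decodes it:

package main

import (
	"fmt"
	"mime"
)

func main() {
	// Plain ASCII is returned unchanged, so ordinary subjects are unaffected.
	fmt.Println(mime.QEncoding.Encode("UTF-8", "Backup finished"))

	// Non-ASCII text becomes an RFC 2047 encoded-word (=?UTF-8?q?...?=),
	// which relays through SMTP servers that lack SMTPUTF8 support.
	encoded := mime.QEncoding.Encode("UTF-8", "Résumé du backup ✔")
	fmt.Println(encoded)

	// A MIME-aware mail client decodes the header back to the original text.
	decoded, err := (&mime.WordDecoder{}).DecodeHeader(encoded)
	if err != nil {
		panic(err)
	}
	fmt.Println(decoded)
}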

@@ -1,38 +0,0 @@
package restores
import (
"context"
"databasus-backend/internal/features/restores/enums"
"log/slog"
)
type RestoreBackgroundService struct {
restoreRepository *RestoreRepository
logger *slog.Logger
}
func (s *RestoreBackgroundService) Run(ctx context.Context) {
if err := s.failRestoresInProgress(); err != nil {
s.logger.Error("Failed to fail restores in progress", "error", err)
panic(err)
}
}
func (s *RestoreBackgroundService) failRestoresInProgress() error {
restoresInProgress, err := s.restoreRepository.FindByStatus(enums.RestoreStatusInProgress)
if err != nil {
return err
}
for _, restore := range restoresInProgress {
failMessage := "Restore failed due to application restart"
restore.Status = enums.RestoreStatusFailed
restore.FailMessage = &failMessage
if err := s.restoreRepository.Save(restore); err != nil {
return err
}
}
return nil
}


@@ -1,6 +1,7 @@
package restores
import (
restores_core "databasus-backend/internal/features/restores/core"
users_middleware "databasus-backend/internal/features/users/middleware"
"net/http"
@@ -15,6 +16,7 @@ type RestoreController struct {
func (c *RestoreController) RegisterRoutes(router *gin.RouterGroup) {
router.GET("/restores/:backupId", c.GetRestores)
router.POST("/restores/:backupId/restore", c.RestoreBackup)
router.POST("/restores/cancel/:restoreId", c.CancelRestore)
}
// GetRestores
@@ -23,7 +25,7 @@ func (c *RestoreController) RegisterRoutes(router *gin.RouterGroup) {
// @Tags restores
// @Produce json
// @Param backupId path string true "Backup ID"
// @Success 200 {array} models.Restore
// @Success 200 {array} restores_core.Restore
// @Failure 400
// @Failure 401
// @Router /restores/{backupId} [get]
@@ -71,7 +73,7 @@ func (c *RestoreController) RestoreBackup(ctx *gin.Context) {
return
}
var requestDTO RestoreBackupRequest
var requestDTO restores_core.RestoreBackupRequest
if err := ctx.ShouldBindJSON(&requestDTO); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -84,3 +86,33 @@ func (c *RestoreController) RestoreBackup(ctx *gin.Context) {
ctx.JSON(http.StatusOK, gin.H{"message": "restore started successfully"})
}
// CancelRestore
// @Summary Cancel an in-progress restore
// @Description Cancel a restore that is currently in progress
// @Tags restores
// @Param restoreId path string true "Restore ID"
// @Success 204
// @Failure 400
// @Failure 401
// @Router /restores/cancel/{restoreId} [post]
func (c *RestoreController) CancelRestore(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
restoreID, err := uuid.Parse(ctx.Param("restoreId"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid restore ID"})
return
}
if err := c.restoreService.CancelRestore(user, restoreID); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.Status(http.StatusNoContent)
}
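To illustrate the new route, a client could cancel a restore roughly as sketched below. This snippet is not taken from the repository; the base URL and token are placeholders, and the /api/v1 prefix and Bearer header match the paths used by the tests further down in this diff.

package example

import (
	"fmt"
	"net/http"
)

// cancelRestore posts to the cancel endpoint and expects 204 No Content.
func cancelRestore(baseURL, token, restoreID string) error {
	url := fmt.Sprintf("%s/api/v1/restores/cancel/%s", baseURL, restoreID)
	req, err := http.NewRequest(http.MethodPost, url, nil)
	if err != nil {
		return err
	}
	req.Header.Set("Authorization", "Bearer "+token)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusNoContent {
		return fmt.Errorf("cancel failed with status %s", resp.Status)
	}
	return nil
}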


@@ -16,7 +16,7 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
env_config "databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
@@ -24,16 +24,19 @@ import (
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/mysql"
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/restores/models"
"databasus-backend/internal/features/notifiers"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/restores/restoring"
"databasus-backend/internal/features/storages"
local_storage "databasus-backend/internal/features/storages/models/local"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
users_dto "databasus-backend/internal/features/users/dto"
users_enums "databasus-backend/internal/features/users/enums"
users_services "databasus-backend/internal/features/users/services"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_models "databasus-backend/internal/features/workspaces/models"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
cache_utils "databasus-backend/internal/util/cache"
util_encryption "databasus-backend/internal/util/encryption"
test_utils "databasus-backend/internal/util/testing"
"databasus-backend/internal/util/tools"
@@ -46,7 +49,7 @@ func Test_GetRestores_WhenUserIsWorkspaceMember_RestoresReturned(t *testing.T) {
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
var restores []*models.Restore
var restores []*restores_core.Restore
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
@@ -90,7 +93,7 @@ func Test_GetRestores_WhenUserIsGlobalAdmin_RestoresReturned(t *testing.T) {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
var restores []*models.Restore
var restores []*restores_core.Restore
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
@@ -105,15 +108,19 @@ func Test_GetRestores_WhenUserIsGlobalAdmin_RestoresReturned(t *testing.T) {
func Test_RestoreBackup_WhenUserIsWorkspaceMember_RestoreInitiated(t *testing.T) {
router := createTestRouter()
_, cleanup := SetupMockRestoreNode(t)
defer cleanup()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
request := RestoreBackupRequest{
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Host: env_config.GetEnv().TestLocalhost,
Port: 5432,
Username: "postgres",
Password: "postgres",
@@ -141,10 +148,10 @@ func Test_RestoreBackup_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
request := RestoreBackupRequest{
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Host: env_config.GetEnv().TestLocalhost,
Port: 5432,
Username: "postgres",
Password: "postgres",
@@ -165,15 +172,19 @@ func Test_RestoreBackup_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing
func Test_RestoreBackup_WithIsExcludeExtensions_FlagPassedCorrectly(t *testing.T) {
router := createTestRouter()
_, cleanup := SetupMockRestoreNode(t)
defer cleanup()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
request := RestoreBackupRequest{
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Host: env_config.GetEnv().TestLocalhost,
Port: 5432,
Username: "postgres",
Password: "postgres",
@@ -195,15 +206,19 @@ func Test_RestoreBackup_WithIsExcludeExtensions_FlagPassedCorrectly(t *testing.T
func Test_RestoreBackup_AuditLogWritten(t *testing.T) {
router := createTestRouter()
_, cleanup := SetupMockRestoreNode(t)
defer cleanup()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
request := RestoreBackupRequest{
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Host: env_config.GetEnv().TestLocalhost,
Port: 5432,
Username: "postgres",
Password: "postgres",
@@ -272,18 +287,25 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
router := createTestRouter()
// Setup mock node for tests that skip disk validation and reach scheduler
if !tc.expectDiskValidated {
_, cleanup := SetupMockRestoreNode(t)
defer cleanup()
}
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
var backup *backups_core.Backup
var request RestoreBackupRequest
var request restores_core.RestoreBackupRequest
if tc.dbType == databases.DatabaseTypePostgres {
_, backup = createTestDatabaseWithBackupForRestore(workspace, owner, router)
request = RestoreBackupRequest{
request = restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Host: env_config.GetEnv().TestLocalhost,
Port: 5432,
Username: "postgres",
Password: "postgres",
@@ -310,10 +332,10 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
assert.NoError(t, err)
backup = createTestBackup(mysqlDB, owner)
request = RestoreBackupRequest{
request = restores_core.RestoreBackupRequest{
MysqlDatabase: &mysql.MysqlDatabase{
Version: tools.MysqlVersion80,
Host: "localhost",
Host: env_config.GetEnv().TestLocalhost,
Port: 3306,
Username: "root",
Password: "password",
@@ -353,16 +375,187 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
}
}
func createTestRouter() *gin.Engine {
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
databases.GetDatabaseController(),
backups_config.GetBackupConfigController(),
backups.GetBackupController(),
GetRestoreController(),
func Test_CancelRestore_InProgressRestore_SuccessfullyCancelled(t *testing.T) {
cache_utils.ClearAllCache()
tasks_cancellation.SetupDependencies()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := createTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backupRepo := backups_core.BackupRepository{}
backups, _ := backupRepo.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepo.DeleteByID(backup.ID)
}
restoreRepo := restores_core.RestoreRepository{}
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
for _, restore := range restores {
restoreRepo.DeleteByID(restore.ID)
}
restores, _ = restoreRepo.FindByStatus(restores_core.RestoreStatusCanceled)
for _, restore := range restores {
restoreRepo.DeleteByID(restore.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
cache_utils.ClearAllCache()
}()
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backup := backups.CreateTestBackup(database.ID, storage.ID)
mockUsecase := &restoring.MockBlockingRestoreUsecase{
StartedChan: make(chan bool, 1),
}
restorerNode := restoring.CreateTestRestorerNodeWithUsecase(mockUsecase)
cancelNode := restoring.StartRestorerNodeForTest(t, restorerNode)
defer cancelNode()
time.Sleep(200 * time.Millisecond)
restoreRequest := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: env_config.GetEnv().TestLocalhost,
Port: 5432,
Username: "postgres",
Password: "postgres",
},
}
var restoreResponse map[string]interface{}
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
"Bearer "+user.Token,
restoreRequest,
http.StatusOK,
&restoreResponse,
)
return router
select {
case <-mockUsecase.StartedChan:
t.Log("Restore started and is blocking")
case <-time.After(2 * time.Second):
t.Fatal("Restore did not start within timeout")
}
restoreRepo := &restores_core.RestoreRepository{}
restores, err := restoreRepo.FindByBackupID(backup.ID)
assert.NoError(t, err)
assert.Greater(t, len(restores), 0, "At least one restore should exist")
var restoreID uuid.UUID
for _, r := range restores {
if r.Status == restores_core.RestoreStatusInProgress {
restoreID = r.ID
break
}
}
assert.NotEqual(t, uuid.Nil, restoreID, "Should find an in-progress restore")
resp := test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/restores/cancel/%s", restoreID.String()),
"Bearer "+user.Token,
nil,
http.StatusNoContent,
)
assert.Equal(t, http.StatusNoContent, resp.StatusCode)
deadline := time.Now().UTC().Add(3 * time.Second)
var restore *restores_core.Restore
for time.Now().UTC().Before(deadline) {
restore, err = restoreRepo.FindByID(restoreID)
assert.NoError(t, err)
if restore.Status == restores_core.RestoreStatusCanceled {
break
}
time.Sleep(100 * time.Millisecond)
}
assert.Equal(t, restores_core.RestoreStatusCanceled, restore.Status)
auditLogService := audit_logs.GetAuditLogService()
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
workspace.ID,
&audit_logs.GetAuditLogsRequest{Limit: 100, Offset: 0},
)
assert.NoError(t, err)
foundCancelLog := false
for _, log := range auditLogs.AuditLogs {
if strings.Contains(log.Message, "Restore cancelled") &&
strings.Contains(log.Message, database.Name) {
foundCancelLog = true
break
}
}
assert.True(t, foundCancelLog, "Cancel audit log should be created")
time.Sleep(200 * time.Millisecond)
}
func Test_RestoreBackup_WithParallelRestoreInProgress_ReturnsError(t *testing.T) {
router := createTestRouter()
_, cleanup := SetupMockRestoreNode(t)
defer cleanup()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: env_config.GetEnv().TestLocalhost,
Port: 5432,
Username: "postgres",
Password: "postgres",
},
}
testResp := test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
"Bearer "+owner.Token,
request,
http.StatusOK,
)
assert.Contains(t, string(testResp.Body), "restore started successfully")
testResp2 := test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
"Bearer "+owner.Token,
request,
http.StatusBadRequest,
)
assert.Contains(t, string(testResp2.Body), "another restore is already in progress")
}
func createTestRouter() *gin.Engine {
return CreateTestRouter()
}
func createTestDatabaseWithBackupForRestore(
@@ -433,7 +626,7 @@ func createTestMySQLDatabase(
token string,
router *gin.Engine,
) *databases.Database {
env := config.GetEnv()
env := env_config.GetEnv()
portStr := env.TestMysql80Port
if portStr == "" {
portStr = "33080"
@@ -451,7 +644,7 @@ func createTestMySQLDatabase(
Type: databases.DatabaseTypeMysql,
Mysql: &mysql.MysqlDatabase{
Version: tools.MysqlVersion80,
Host: "localhost",
Host: env_config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",


@@ -1,4 +1,4 @@
package restores
package restores_core
import (
"databasus-backend/internal/features/databases/databases/mariadb"


@@ -1,4 +1,4 @@
package enums
package restores_core
type RestoreStatus string
@@ -6,4 +6,5 @@ const (
RestoreStatusInProgress RestoreStatus = "IN_PROGRESS"
RestoreStatusCompleted RestoreStatus = "COMPLETED"
RestoreStatusFailed RestoreStatus = "FAILED"
RestoreStatusCanceled RestoreStatus = "CANCELED"
)


@@ -0,0 +1,23 @@
package restores_core
import (
"context"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/storages"
)
type RestoreBackupUsecase interface {
Execute(
ctx context.Context,
backupConfig *backups_config.BackupConfig,
restore Restore,
originalDB *databases.Database,
restoringToDB *databases.Database,
backup *backups_core.Backup,
storage *storages.Storage,
isExcludeExtensions bool,
) error
}


@@ -0,0 +1,30 @@
package restores_core
import (
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/databases/databases/mariadb"
"databasus-backend/internal/features/databases/databases/mongodb"
"databasus-backend/internal/features/databases/databases/mysql"
"databasus-backend/internal/features/databases/databases/postgresql"
"time"
"github.com/google/uuid"
)
type Restore struct {
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
Status RestoreStatus `json:"status" gorm:"column:status;type:text;not null"`
BackupID uuid.UUID `json:"backupId" gorm:"column:backup_id;type:uuid;not null"`
Backup *backups_core.Backup
PostgresqlDatabase *postgresql.PostgresqlDatabase `json:"postgresqlDatabase" gorm:"-"`
MysqlDatabase *mysql.MysqlDatabase `json:"mysqlDatabase" gorm:"-"`
MariadbDatabase *mariadb.MariadbDatabase `json:"mariadbDatabase" gorm:"-"`
MongodbDatabase *mongodb.MongodbDatabase `json:"mongodbDatabase" gorm:"-"`
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`
RestoreDurationMs int64 `json:"restoreDurationMs" gorm:"column:restore_duration_ms;default:0"`
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at;default:now()"`
}


@@ -0,0 +1,91 @@
package restores_core
import (
"databasus-backend/internal/storage"
"github.com/google/uuid"
)
type RestoreRepository struct{}
func (r *RestoreRepository) Save(restore *Restore) error {
db := storage.GetDb()
isNew := restore.ID == uuid.Nil
if isNew {
restore.ID = uuid.New()
return db.Create(restore).
Omit("Backup", "PostgresqlDatabase", "MysqlDatabase", "MariadbDatabase", "MongodbDatabase").
Error
}
return db.Save(restore).
Omit("Backup", "PostgresqlDatabase", "MysqlDatabase", "MariadbDatabase", "MongodbDatabase").
Error
}
func (r *RestoreRepository) FindByBackupID(backupID uuid.UUID) ([]*Restore, error) {
var restores []*Restore
if err := storage.
GetDb().
Preload("Backup").
Where("backup_id = ?", backupID).
Order("created_at DESC").
Find(&restores).Error; err != nil {
return nil, err
}
return restores, nil
}
func (r *RestoreRepository) FindByID(id uuid.UUID) (*Restore, error) {
var restore Restore
if err := storage.
GetDb().
Preload("Backup").
Where("id = ?", id).
First(&restore).Error; err != nil {
return nil, err
}
return &restore, nil
}
func (r *RestoreRepository) FindByStatus(status RestoreStatus) ([]*Restore, error) {
var restores []*Restore
if err := storage.
GetDb().
Preload("Backup").
Where("status = ?", status).
Order("created_at DESC").
Find(&restores).Error; err != nil {
return nil, err
}
return restores, nil
}
func (r *RestoreRepository) FindInProgressRestoresByDatabaseID(
databaseID uuid.UUID,
) ([]*Restore, error) {
var restores []*Restore
if err := storage.
GetDb().
Preload("Backup").
Joins("JOIN backups ON backups.id = restores.backup_id").
Where("backups.database_id = ? AND restores.status = ?", databaseID, RestoreStatusInProgress).
Order("restores.created_at DESC").
Find(&restores).Error; err != nil {
return nil, err
}
return restores, nil
}
func (r *RestoreRepository) DeleteByID(id uuid.UUID) error {
return storage.GetDb().Delete(&Restore{}, "id = ?", id).Error
}
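A quick sketch of the Save semantics above (illustrative only, not part of the diff; backupID stands in for an existing backup's ID): a nil ID triggers an insert with a freshly generated UUID, while a populated ID updates the same row.

package example

import (
	"github.com/google/uuid"

	restores_core "databasus-backend/internal/features/restores/core"
)

func saveAndComplete(backupID uuid.UUID) error {
	repo := &restores_core.RestoreRepository{}

	// ID is uuid.Nil here, so Save generates an ID and inserts a new row.
	restore := &restores_core.Restore{
		Status:   restores_core.RestoreStatusInProgress,
		BackupID: backupID,
	}
	if err := repo.Save(restore); err != nil {
		return err
	}

	// The ID is now set, so the second Save updates the existing row.
	restore.Status = restores_core.RestoreStatusCompleted
	return repo.Save(restore)
}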


@@ -1,19 +1,25 @@
package restores
import (
"sync"
"sync/atomic"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/features/backups/backups/backuping"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/disk"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/restores/usecases"
"databasus-backend/internal/features/storages"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
)
var restoreRepository = &RestoreRepository{}
var restoreRepository = &restores_core.RestoreRepository{}
var restoreService = &RestoreService{
backups.GetBackupService(),
restoreRepository,
@@ -26,24 +32,32 @@ var restoreService = &RestoreService{
audit_logs.GetAuditLogService(),
encryption.GetFieldEncryptor(),
disk.GetDiskService(),
tasks_cancellation.GetTaskCancelManager(),
}
var restoreController = &RestoreController{
restoreService,
}
var restoreBackgroundService = &RestoreBackgroundService{
restoreRepository,
logger.GetLogger(),
}
func GetRestoreController() *RestoreController {
return restoreController
}
func GetRestoreBackgroundService() *RestoreBackgroundService {
return restoreBackgroundService
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
backups.GetBackupService().AddBackupRemoveListener(restoreService)
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
backups.GetBackupService().AddBackupRemoveListener(restoreService)
backuping.GetBackupCleaner().AddBackupRemoveListener(restoreService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}


@@ -1,22 +0,0 @@
package models
import (
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/restores/enums"
"time"
"github.com/google/uuid"
)
type Restore struct {
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
Status enums.RestoreStatus `json:"status" gorm:"column:status;type:text;not null"`
BackupID uuid.UUID `json:"backupId" gorm:"column:backup_id;type:uuid;not null"`
Backup *backups_core.Backup
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`
RestoreDurationMs int64 `json:"restoreDurationMs" gorm:"column:restore_duration_ms;default:0"`
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at;default:now()"`
}


@@ -1,75 +0,0 @@
package restores
import (
"databasus-backend/internal/features/restores/enums"
"databasus-backend/internal/features/restores/models"
"databasus-backend/internal/storage"
"github.com/google/uuid"
)
type RestoreRepository struct{}
func (r *RestoreRepository) Save(restore *models.Restore) error {
db := storage.GetDb()
isNew := restore.ID == uuid.Nil
if isNew {
restore.ID = uuid.New()
return db.Create(restore).
Omit("Backup").
Error
}
return db.Save(restore).
Omit("Backup").
Error
}
func (r *RestoreRepository) FindByBackupID(backupID uuid.UUID) ([]*models.Restore, error) {
var restores []*models.Restore
if err := storage.
GetDb().
Preload("Backup").
Where("backup_id = ?", backupID).
Order("created_at DESC").
Find(&restores).Error; err != nil {
return nil, err
}
return restores, nil
}
func (r *RestoreRepository) FindByID(id uuid.UUID) (*models.Restore, error) {
var restore models.Restore
if err := storage.
GetDb().
Preload("Backup").
Where("id = ?", id).
First(&restore).Error; err != nil {
return nil, err
}
return &restore, nil
}
func (r *RestoreRepository) FindByStatus(status enums.RestoreStatus) ([]*models.Restore, error) {
var restores []*models.Restore
if err := storage.
GetDb().
Preload("Backup").
Where("status = ?", status).
Order("created_at DESC").
Find(&restores).Error; err != nil {
return nil, err
}
return restores, nil
}
func (r *RestoreRepository) DeleteByID(id uuid.UUID) error {
return storage.GetDb().Delete(&models.Restore{}, "id = ?", id).Error
}


@@ -0,0 +1,85 @@
package restoring
import (
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"databasus-backend/internal/features/backups/backups"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/restores/usecases"
"databasus-backend/internal/features/storages"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
)
var restoreRepository = &restores_core.RestoreRepository{}
var restoreNodesRegistry = &RestoreNodesRegistry{
client: cache_utils.GetValkeyClient(),
logger: logger.GetLogger(),
timeout: cache_utils.DefaultCacheTimeout,
pubsubRestores: cache_utils.NewPubSubManager(),
pubsubCompletions: cache_utils.NewPubSubManager(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
var restoreDatabaseCache = cache_utils.NewCacheUtil[RestoreDatabaseCache](
cache_utils.GetValkeyClient(),
"restore_db:",
)
var restoreCancelManager = tasks_cancellation.GetTaskCancelManager()
var restorerNode = &RestorerNode{
nodeID: uuid.New(),
databaseService: databases.GetDatabaseService(),
backupService: backups.GetBackupService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
restoreRepository: restoreRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
restoreNodesRegistry: restoreNodesRegistry,
logger: logger.GetLogger(),
restoreBackupUsecase: usecases.GetRestoreBackupUsecase(),
cacheUtil: restoreDatabaseCache,
restoreCancelManager: restoreCancelManager,
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
var restoresScheduler = &RestoresScheduler{
restoreRepository: restoreRepository,
backupService: backups.GetBackupService(),
storageService: storages.GetStorageService(),
backupConfigService: backups_config.GetBackupConfigService(),
restoreNodesRegistry: restoreNodesRegistry,
lastCheckTime: time.Now().UTC(),
logger: logger.GetLogger(),
restoreToNodeRelations: make(map[uuid.UUID]RestoreToNodeRelation),
restorerNode: restorerNode,
cacheUtil: restoreDatabaseCache,
completionSubscriptionID: uuid.Nil,
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
func GetRestoresScheduler() *RestoresScheduler {
return restoresScheduler
}
func GetRestorerNode() *RestorerNode {
return restorerNode
}
func GetRestoreNodesRegistry() *RestoreNodesRegistry {
return restoreNodesRegistry
}


@@ -0,0 +1,45 @@
package restoring
import (
"databasus-backend/internal/features/databases/databases/mariadb"
"databasus-backend/internal/features/databases/databases/mongodb"
"databasus-backend/internal/features/databases/databases/mysql"
"databasus-backend/internal/features/databases/databases/postgresql"
"time"
"github.com/google/uuid"
)
type RestoreDatabaseCache struct {
PostgresqlDatabase *postgresql.PostgresqlDatabase `json:"postgresqlDatabase,omitempty"`
MysqlDatabase *mysql.MysqlDatabase `json:"mysqlDatabase,omitempty"`
MariadbDatabase *mariadb.MariadbDatabase `json:"mariadbDatabase,omitempty"`
MongodbDatabase *mongodb.MongodbDatabase `json:"mongodbDatabase,omitempty"`
}
type RestoreToNodeRelation struct {
NodeID uuid.UUID `json:"nodeId"`
RestoreIDs []uuid.UUID `json:"restoreIds"`
}
type RestoreNode struct {
ID uuid.UUID `json:"id"`
ThroughputMBs int `json:"throughputMBs"`
LastHeartbeat time.Time `json:"lastHeartbeat"`
}
type RestoreNodeStats struct {
ID uuid.UUID `json:"id"`
ActiveRestores int `json:"activeRestores"`
}
type RestoreSubmitMessage struct {
NodeID uuid.UUID `json:"nodeId"`
RestoreID uuid.UUID `json:"restoreId"`
IsCallNotifier bool `json:"isCallNotifier"`
}
type RestoreCompletionMessage struct {
NodeID uuid.UUID `json:"nodeId"`
RestoreID uuid.UUID `json:"restoreId"`
}


@@ -0,0 +1,88 @@
package restoring
import (
"context"
"errors"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
)
type MockSuccessRestoreUsecase struct{}
func (uc *MockSuccessRestoreUsecase) Execute(
ctx context.Context,
backupConfig *backups_config.BackupConfig,
restore restores_core.Restore,
originalDB *databases.Database,
restoringToDB *databases.Database,
backup *backups_core.Backup,
storage *storages.Storage,
isExcludeExtensions bool,
) error {
return nil
}
type MockFailedRestoreUsecase struct{}
func (uc *MockFailedRestoreUsecase) Execute(
ctx context.Context,
backupConfig *backups_config.BackupConfig,
restore restores_core.Restore,
originalDB *databases.Database,
restoringToDB *databases.Database,
backup *backups_core.Backup,
storage *storages.Storage,
isExcludeExtensions bool,
) error {
return errors.New("restore failed")
}
type MockCaptureCredentialsRestoreUsecase struct {
CalledChan chan *databases.Database
ShouldSucceed bool
}
func (uc *MockCaptureCredentialsRestoreUsecase) Execute(
ctx context.Context,
backupConfig *backups_config.BackupConfig,
restore restores_core.Restore,
originalDB *databases.Database,
restoringToDB *databases.Database,
backup *backups_core.Backup,
storage *storages.Storage,
isExcludeExtensions bool,
) error {
uc.CalledChan <- restoringToDB
if uc.ShouldSucceed {
return nil
}
return errors.New("mock restore failed")
}
type MockBlockingRestoreUsecase struct {
StartedChan chan bool
}
func (uc *MockBlockingRestoreUsecase) Execute(
ctx context.Context,
backupConfig *backups_config.BackupConfig,
restore restores_core.Restore,
originalDB *databases.Database,
restoringToDB *databases.Database,
backup *backups_core.Backup,
storage *storages.Storage,
isExcludeExtensions bool,
) error {
if uc.StartedChan != nil {
uc.StartedChan <- true
}
<-ctx.Done()
return ctx.Err()
}


@@ -0,0 +1,649 @@
package restoring
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"sync"
"sync/atomic"
"time"
cache_utils "databasus-backend/internal/util/cache"
"github.com/google/uuid"
"github.com/valkey-io/valkey-go"
)
const (
nodeInfoKeyPrefix = "restore:node:"
nodeInfoKeySuffix = ":info"
nodeActiveRestoresPrefix = "restore:node:"
nodeActiveRestoresSuffix = ":active_restores"
restoreSubmitChannel = "restore:submit"
restoreCompletionChannel = "restore:completion"
deadNodeThreshold = 2 * time.Minute
cleanupTickerInterval = 1 * time.Second
)
// RestoreNodesRegistry keeps the restores scheduler and the restore nodes in sync.
//
// Features:
// - Tracks node availability and load level
// - Assigns restores from the scheduler to the node that should process them
// - Notifies the scheduler when a node completes a restore
//
// Important things to remember:
// - Nodes without a heartbeat for more than 2 minutes are not included
// in the available nodes list or in the stats
//
// Dead nodes are cleaned up on 2 levels:
// - The list and stats functions do not return dead nodes
// - Dead nodes are periodically removed from the cache (so stale
// entries do not accumulate there)
type RestoreNodesRegistry struct {
client valkey.Client
logger *slog.Logger
timeout time.Duration
pubsubRestores *cache_utils.PubSubManager
pubsubCompletions *cache_utils.PubSubManager
runOnce sync.Once
hasRun atomic.Bool
}
func (r *RestoreNodesRegistry) Run(ctx context.Context) {
wasAlreadyRun := r.hasRun.Load()
r.runOnce.Do(func() {
r.hasRun.Store(true)
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
}
ticker := time.NewTicker(cleanupTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes", "error", err)
}
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", r))
}
}
func (r *RestoreNodesRegistry) GetAvailableNodes() ([]RestoreNode, error) {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
var allKeys []string
cursor := uint64(0)
pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix
for {
result := r.client.Do(
ctx,
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
)
if result.Error() != nil {
return nil, fmt.Errorf("failed to scan node keys: %w", result.Error())
}
scanResult, err := result.AsScanEntry()
if err != nil {
return nil, fmt.Errorf("failed to parse scan result: %w", err)
}
allKeys = append(allKeys, scanResult.Elements...)
cursor = scanResult.Cursor
if cursor == 0 {
break
}
}
if len(allKeys) == 0 {
return []RestoreNode{}, nil
}
keyDataMap, err := r.pipelineGetKeys(allKeys)
if err != nil {
return nil, fmt.Errorf("failed to pipeline get node keys: %w", err)
}
threshold := time.Now().UTC().Add(-deadNodeThreshold)
var nodes []RestoreNode
for key, data := range keyDataMap {
// Skip if the key doesn't exist (data is empty)
if len(data) == 0 {
continue
}
var node RestoreNode
if err := json.Unmarshal(data, &node); err != nil {
r.logger.Warn("Failed to unmarshal node data", "key", key, "error", err)
continue
}
// Skip nodes with zero/uninitialized heartbeat
if node.LastHeartbeat.IsZero() {
continue
}
if node.LastHeartbeat.Before(threshold) {
continue
}
nodes = append(nodes, node)
}
return nodes, nil
}
func (r *RestoreNodesRegistry) GetRestoreNodesStats() ([]RestoreNodeStats, error) {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
var allKeys []string
cursor := uint64(0)
pattern := nodeActiveRestoresPrefix + "*" + nodeActiveRestoresSuffix
for {
result := r.client.Do(
ctx,
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(100).Build(),
)
if result.Error() != nil {
return nil, fmt.Errorf("failed to scan active restores keys: %w", result.Error())
}
scanResult, err := result.AsScanEntry()
if err != nil {
return nil, fmt.Errorf("failed to parse scan result: %w", err)
}
allKeys = append(allKeys, scanResult.Elements...)
cursor = scanResult.Cursor
if cursor == 0 {
break
}
}
if len(allKeys) == 0 {
return []RestoreNodeStats{}, nil
}
keyDataMap, err := r.pipelineGetKeys(allKeys)
if err != nil {
return nil, fmt.Errorf("failed to pipeline get active restores keys: %w", err)
}
var nodeInfoKeys []string
nodeIDToStatsKey := make(map[string]string)
for key := range keyDataMap {
nodeID := r.extractNodeIDFromKey(key, nodeActiveRestoresPrefix, nodeActiveRestoresSuffix)
nodeIDStr := nodeID.String()
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeIDStr, nodeInfoKeySuffix)
nodeInfoKeys = append(nodeInfoKeys, infoKey)
nodeIDToStatsKey[infoKey] = key
}
nodeInfoMap, err := r.pipelineGetKeys(nodeInfoKeys)
if err != nil {
return nil, fmt.Errorf("failed to pipeline get node info keys: %w", err)
}
threshold := time.Now().UTC().Add(-deadNodeThreshold)
var stats []RestoreNodeStats
for infoKey, nodeData := range nodeInfoMap {
// Skip if the info key doesn't exist (nodeData is empty)
if len(nodeData) == 0 {
continue
}
var node RestoreNode
if err := json.Unmarshal(nodeData, &node); err != nil {
r.logger.Warn("Failed to unmarshal node data", "key", infoKey, "error", err)
continue
}
// Skip nodes with zero/uninitialized heartbeat
if node.LastHeartbeat.IsZero() {
continue
}
if node.LastHeartbeat.Before(threshold) {
continue
}
statsKey := nodeIDToStatsKey[infoKey]
tasksData := keyDataMap[statsKey]
count, err := r.parseIntFromBytes(tasksData)
if err != nil {
r.logger.Warn("Failed to parse active restores count", "key", statsKey, "error", err)
continue
}
stat := RestoreNodeStats{
ID: node.ID,
ActiveRestores: int(count),
}
stats = append(stats, stat)
}
return stats, nil
}
func (r *RestoreNodesRegistry) IncrementRestoresInProgress(nodeID uuid.UUID) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
key := fmt.Sprintf(
"%s%s%s",
nodeActiveRestoresPrefix,
nodeID.String(),
nodeActiveRestoresSuffix,
)
result := r.client.Do(ctx, r.client.B().Incr().Key(key).Build())
if result.Error() != nil {
return fmt.Errorf(
"failed to increment restores in progress for node %s: %w",
nodeID,
result.Error(),
)
}
return nil
}
func (r *RestoreNodesRegistry) DecrementRestoresInProgress(nodeID uuid.UUID) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
key := fmt.Sprintf(
"%s%s%s",
nodeActiveRestoresPrefix,
nodeID.String(),
nodeActiveRestoresSuffix,
)
result := r.client.Do(ctx, r.client.B().Decr().Key(key).Build())
if result.Error() != nil {
return fmt.Errorf(
"failed to decrement restores in progress for node %s: %w",
nodeID,
result.Error(),
)
}
newValue, err := result.AsInt64()
if err != nil {
return fmt.Errorf("failed to parse decremented value for node %s: %w", nodeID, err)
}
if newValue < 0 {
setCtx, setCancel := context.WithTimeout(context.Background(), r.timeout)
r.client.Do(setCtx, r.client.B().Set().Key(key).Value("0").Build())
setCancel()
r.logger.Warn("Active restores counter went below 0, reset to 0", "nodeID", nodeID)
}
return nil
}
func (r *RestoreNodesRegistry) HearthbeatNodeInRegistry(
now time.Time,
restoreNode RestoreNode,
) error {
if now.IsZero() {
return fmt.Errorf("cannot register node with zero heartbeat timestamp")
}
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
restoreNode.LastHeartbeat = now
data, err := json.Marshal(restoreNode)
if err != nil {
return fmt.Errorf("failed to marshal restore node: %w", err)
}
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, restoreNode.ID.String(), nodeInfoKeySuffix)
result := r.client.Do(
ctx,
r.client.B().Set().Key(key).Value(string(data)).Build(),
)
if result.Error() != nil {
return fmt.Errorf("failed to register node %s: %w", restoreNode.ID, result.Error())
}
return nil
}
func (r *RestoreNodesRegistry) UnregisterNodeFromRegistry(restoreNode RestoreNode) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, restoreNode.ID.String(), nodeInfoKeySuffix)
counterKey := fmt.Sprintf(
"%s%s%s",
nodeActiveRestoresPrefix,
restoreNode.ID.String(),
nodeActiveRestoresSuffix,
)
result := r.client.Do(
ctx,
r.client.B().Del().Key(infoKey, counterKey).Build(),
)
if result.Error() != nil {
return fmt.Errorf("failed to unregister node %s: %w", restoreNode.ID, result.Error())
}
r.logger.Info("Unregistered node from registry", "nodeID", restoreNode.ID)
return nil
}
func (r *RestoreNodesRegistry) AssignRestoreToNode(
targetNodeID uuid.UUID,
restoreID uuid.UUID,
isCallNotifier bool,
) error {
ctx := context.Background()
message := RestoreSubmitMessage{
NodeID: targetNodeID,
RestoreID: restoreID,
IsCallNotifier: isCallNotifier,
}
messageJSON, err := json.Marshal(message)
if err != nil {
return fmt.Errorf("failed to marshal restore submit message: %w", err)
}
err = r.pubsubRestores.Publish(ctx, restoreSubmitChannel, string(messageJSON))
if err != nil {
return fmt.Errorf("failed to publish restore submit message: %w", err)
}
return nil
}
func (r *RestoreNodesRegistry) SubscribeNodeForRestoresAssignment(
nodeID uuid.UUID,
handler func(restoreID uuid.UUID, isCallNotifier bool),
) error {
ctx := context.Background()
wrappedHandler := func(message string) {
var msg RestoreSubmitMessage
if err := json.Unmarshal([]byte(message), &msg); err != nil {
r.logger.Warn("Failed to unmarshal restore submit message", "error", err)
return
}
if msg.NodeID != nodeID {
return
}
handler(msg.RestoreID, msg.IsCallNotifier)
}
err := r.pubsubRestores.Subscribe(ctx, restoreSubmitChannel, wrappedHandler)
if err != nil {
return fmt.Errorf("failed to subscribe to restore submit channel: %w", err)
}
r.logger.Info("Subscribed to restore submit channel", "nodeID", nodeID)
return nil
}
func (r *RestoreNodesRegistry) UnsubscribeNodeForRestoresAssignments() error {
err := r.pubsubRestores.Close()
if err != nil {
return fmt.Errorf("failed to unsubscribe from restore submit channel: %w", err)
}
r.logger.Info("Unsubscribed from restore submit channel")
return nil
}
func (r *RestoreNodesRegistry) PublishRestoreCompletion(
nodeID uuid.UUID,
restoreID uuid.UUID,
) error {
ctx := context.Background()
message := RestoreCompletionMessage{
NodeID: nodeID,
RestoreID: restoreID,
}
messageJSON, err := json.Marshal(message)
if err != nil {
return fmt.Errorf("failed to marshal restore completion message: %w", err)
}
err = r.pubsubCompletions.Publish(ctx, restoreCompletionChannel, string(messageJSON))
if err != nil {
return fmt.Errorf("failed to publish restore completion message: %w", err)
}
return nil
}
func (r *RestoreNodesRegistry) SubscribeForRestoresCompletions(
handler func(nodeID uuid.UUID, restoreID uuid.UUID),
) error {
ctx := context.Background()
wrappedHandler := func(message string) {
var msg RestoreCompletionMessage
if err := json.Unmarshal([]byte(message), &msg); err != nil {
r.logger.Warn("Failed to unmarshal restore completion message", "error", err)
return
}
handler(msg.NodeID, msg.RestoreID)
}
err := r.pubsubCompletions.Subscribe(ctx, restoreCompletionChannel, wrappedHandler)
if err != nil {
return fmt.Errorf("failed to subscribe to restore completion channel: %w", err)
}
r.logger.Info("Subscribed to restore completion channel")
return nil
}
func (r *RestoreNodesRegistry) UnsubscribeForRestoresCompletions() error {
err := r.pubsubCompletions.Close()
if err != nil {
return fmt.Errorf("failed to unsubscribe from restore completion channel: %w", err)
}
r.logger.Info("Unsubscribed from restore completion channel")
return nil
}
func (r *RestoreNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uuid.UUID {
nodeIDStr := strings.TrimPrefix(key, prefix)
nodeIDStr = strings.TrimSuffix(nodeIDStr, suffix)
nodeID, err := uuid.Parse(nodeIDStr)
if err != nil {
r.logger.Warn("Failed to parse node ID from key", "key", key, "error", err)
return uuid.Nil
}
return nodeID
}
func (r *RestoreNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, error) {
if len(keys) == 0 {
return make(map[string][]byte), nil
}
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
commands := make([]valkey.Completed, 0, len(keys))
for _, key := range keys {
commands = append(commands, r.client.B().Get().Key(key).Build())
}
results := r.client.DoMulti(ctx, commands...)
keyDataMap := make(map[string][]byte, len(keys))
for i, result := range results {
if result.Error() != nil {
r.logger.Warn("Failed to get key in pipeline", "key", keys[i], "error", result.Error())
continue
}
data, err := result.AsBytes()
if err != nil {
r.logger.Warn("Failed to parse key data in pipeline", "key", keys[i], "error", err)
continue
}
keyDataMap[keys[i]] = data
}
return keyDataMap, nil
}
func (r *RestoreNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
str := string(data)
var count int64
_, err := fmt.Sscanf(str, "%d", &count)
if err != nil {
return 0, fmt.Errorf("failed to parse integer from bytes: %w", err)
}
return count, nil
}
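// cleanupDeadNodes scans all node info keys and deletes the info and
// active-restores keys of every node whose last heartbeat is older than
// deadNodeThreshold.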
func (r *RestoreNodesRegistry) cleanupDeadNodes() error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
var allKeys []string
cursor := uint64(0)
pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix
for {
result := r.client.Do(
ctx,
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
)
if result.Error() != nil {
return fmt.Errorf("failed to scan node keys: %w", result.Error())
}
scanResult, err := result.AsScanEntry()
if err != nil {
return fmt.Errorf("failed to parse scan result: %w", err)
}
allKeys = append(allKeys, scanResult.Elements...)
cursor = scanResult.Cursor
if cursor == 0 {
break
}
}
if len(allKeys) == 0 {
return nil
}
keyDataMap, err := r.pipelineGetKeys(allKeys)
if err != nil {
return fmt.Errorf("failed to pipeline get node keys: %w", err)
}
threshold := time.Now().UTC().Add(-deadNodeThreshold)
var deadNodeKeys []string
for key, data := range keyDataMap {
// Skip if the key doesn't exist (data is empty)
if len(data) == 0 {
continue
}
var node RestoreNode
if err := json.Unmarshal(data, &node); err != nil {
r.logger.Warn("Failed to unmarshal node data during cleanup", "key", key, "error", err)
continue
}
// Skip nodes with zero/uninitialized heartbeat
if node.LastHeartbeat.IsZero() {
continue
}
if node.LastHeartbeat.Before(threshold) {
nodeID := node.ID.String()
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeID, nodeInfoKeySuffix)
statsKey := fmt.Sprintf(
"%s%s%s",
nodeActiveRestoresPrefix,
nodeID,
nodeActiveRestoresSuffix,
)
deadNodeKeys = append(deadNodeKeys, infoKey, statsKey)
r.logger.Info(
"Marking node for cleanup",
"nodeID", nodeID,
"lastHeartbeat", node.LastHeartbeat,
"threshold", threshold,
)
}
}
if len(deadNodeKeys) == 0 {
return nil
}
delCtx, delCancel := context.WithTimeout(context.Background(), r.timeout)
defer delCancel()
result := r.client.Do(
delCtx,
r.client.B().Del().Key(deadNodeKeys...).Build(),
)
if result.Error() != nil {
return fmt.Errorf("failed to delete dead node keys: %w", result.Error())
}
deletedCount, err := result.AsInt64()
if err != nil {
return fmt.Errorf("failed to parse deleted count: %w", err)
}
r.logger.Info("Cleaned up dead nodes", "deletedKeysCount", deletedCount)
return nil
}

File diff suppressed because it is too large


@@ -0,0 +1,310 @@
package restoring
import (
"context"
"errors"
"fmt"
"log/slog"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
cache_utils "databasus-backend/internal/util/cache"
util_encryption "databasus-backend/internal/util/encryption"
)
const (
heartbeatTickerInterval = 15 * time.Second
restorerHealthcheckThreshold = 5 * time.Minute
)
type RestorerNode struct {
nodeID uuid.UUID
databaseService *databases.DatabaseService
backupService *backups.BackupService
fieldEncryptor util_encryption.FieldEncryptor
restoreRepository *restores_core.RestoreRepository
backupConfigService *backups_config.BackupConfigService
storageService *storages.StorageService
restoreNodesRegistry *RestoreNodesRegistry
logger *slog.Logger
restoreBackupUsecase restores_core.RestoreBackupUsecase
cacheUtil *cache_utils.CacheUtil[RestoreDatabaseCache]
restoreCancelManager *tasks_cancellation.TaskCancelManager
lastHeartbeat time.Time
runOnce sync.Once
hasRun atomic.Bool
}
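// Run registers the node in the registry, subscribes it to restore
// assignments, and then keeps sending heartbeats every heartbeatTickerInterval
// until ctx is cancelled, at which point the node is unregistered. Calling Run
// a second time panics.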
func (n *RestorerNode) Run(ctx context.Context) {
wasAlreadyRun := n.hasRun.Load()
n.runOnce.Do(func() {
n.hasRun.Store(true)
n.lastHeartbeat = time.Now().UTC()
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
restoreNode := RestoreNode{
ID: n.nodeID,
ThroughputMBs: throughputMBs,
}
if err := n.restoreNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), restoreNode); err != nil {
n.logger.Error("Failed to register node in registry", "error", err)
panic(err)
}
restoreHandler := func(restoreID uuid.UUID, isCallNotifier bool) {
n.MakeRestore(restoreID)
if err := n.restoreNodesRegistry.PublishRestoreCompletion(n.nodeID, restoreID); err != nil {
n.logger.Error(
"Failed to publish restore completion",
"error",
err,
"restoreID",
restoreID,
)
}
}
err := n.restoreNodesRegistry.SubscribeNodeForRestoresAssignment(
n.nodeID,
restoreHandler,
)
if err != nil {
n.logger.Error("Failed to subscribe to restore assignments", "error", err)
panic(err)
}
defer func() {
if err := n.restoreNodesRegistry.UnsubscribeNodeForRestoresAssignments(); err != nil {
n.logger.Error("Failed to unsubscribe from restore assignments", "error", err)
}
}()
ticker := time.NewTicker(heartbeatTickerInterval)
defer ticker.Stop()
n.logger.Info("Restore node started", "nodeID", n.nodeID, "throughput", throughputMBs)
for {
select {
case <-ctx.Done():
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
if err := n.restoreNodesRegistry.UnregisterNodeFromRegistry(restoreNode); err != nil {
n.logger.Error("Failed to unregister node from registry", "error", err)
}
return
case <-ticker.C:
n.sendHeartbeat(&restoreNode)
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", n))
}
}
func (n *RestorerNode) IsRestorerRunning() bool {
return n.lastHeartbeat.After(time.Now().UTC().Add(-restorerHealthcheckThreshold))
}
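// MakeRestore executes a single restore: it atomically consumes the cached
// database credentials (failing the restore on a cache miss), loads the
// restore, backup, database, backup config and storage, runs the restore
// usecase under a cancellable context, and persists the final status
// (completed, cancelled or failed) together with the duration.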
func (n *RestorerNode) MakeRestore(restoreID uuid.UUID) {
// Get and delete cached DB credentials atomically
dbCache := n.cacheUtil.GetAndDelete(restoreID.String())
if dbCache == nil {
// Cache miss - fail immediately
restore, err := n.restoreRepository.FindByID(restoreID)
if err != nil {
n.logger.Error(
"Failed to get restore by ID after cache miss",
"restoreId",
restoreID,
"error",
err,
)
return
}
errMsg := "Database credentials expired or missing from cache (most likely due to instance restart)"
restore.FailMessage = &errMsg
restore.Status = restores_core.RestoreStatusFailed
if err := n.restoreRepository.Save(restore); err != nil {
n.logger.Error("Failed to save restore after cache miss", "error", err)
}
n.logger.Error("Restore failed: cache miss", "restoreId", restoreID)
return
}
restore, err := n.restoreRepository.FindByID(restoreID)
if err != nil {
n.logger.Error("Failed to get restore by ID", "restoreId", restoreID, "error", err)
return
}
backup, err := n.backupService.GetBackup(restore.BackupID)
if err != nil {
n.logger.Error("Failed to get backup by ID", "backupId", restore.BackupID, "error", err)
return
}
databaseID := backup.DatabaseID
database, err := n.databaseService.GetDatabaseByID(databaseID)
if err != nil {
n.logger.Error("Failed to get database by ID", "databaseId", databaseID, "error", err)
return
}
backupConfig, err := n.backupConfigService.GetBackupConfigByDbId(databaseID)
if err != nil {
n.logger.Error("Failed to get backup config by database ID", "error", err)
return
}
if backupConfig.StorageID == nil {
n.logger.Error("Backup config storage ID is not defined")
return
}
storage, err := n.storageService.GetStorageByID(*backupConfig.StorageID)
if err != nil {
n.logger.Error("Failed to get storage by ID", "error", err)
return
}
start := time.Now().UTC()
// Create cancellable context
ctx, cancel := context.WithCancel(context.Background())
n.restoreCancelManager.RegisterTask(restore.ID, cancel)
defer n.restoreCancelManager.UnregisterTask(restore.ID)
// Create restoring database from cached credentials
restoringToDB := &databases.Database{
Type: database.Type,
Postgresql: dbCache.PostgresqlDatabase,
Mysql: dbCache.MysqlDatabase,
Mariadb: dbCache.MariadbDatabase,
Mongodb: dbCache.MongodbDatabase,
}
if err := restoringToDB.PopulateDbData(n.logger, n.fieldEncryptor); err != nil {
errMsg := fmt.Sprintf("failed to auto-detect database data: %v", err)
restore.FailMessage = &errMsg
restore.Status = restores_core.RestoreStatusFailed
restore.RestoreDurationMs = time.Since(start).Milliseconds()
if err := n.restoreRepository.Save(restore); err != nil {
n.logger.Error("Failed to save restore", "error", err)
}
return
}
isExcludeExtensions := false
if dbCache.PostgresqlDatabase != nil {
isExcludeExtensions = dbCache.PostgresqlDatabase.IsExcludeExtensions
}
err = n.restoreBackupUsecase.Execute(
ctx,
backupConfig,
*restore,
database,
restoringToDB,
backup,
storage,
isExcludeExtensions,
)
if err != nil {
errMsg := err.Error()
// Check if restore was cancelled
isCancelled := strings.Contains(errMsg, "restore cancelled") ||
strings.Contains(errMsg, "context canceled") ||
errors.Is(err, context.Canceled)
isShutdown := strings.Contains(errMsg, "shutdown")
if isCancelled && !isShutdown {
n.logger.Warn("Restore was cancelled by user or system",
"restoreId", restore.ID,
"isCancelled", isCancelled,
"isShutdown", isShutdown,
)
restore.Status = restores_core.RestoreStatusCanceled
restore.RestoreDurationMs = time.Since(start).Milliseconds()
if err := n.restoreRepository.Save(restore); err != nil {
n.logger.Error("Failed to save cancelled restore", "error", err)
}
return
}
n.logger.Error("Restore execution failed",
"restoreId", restore.ID,
"backupId", backup.ID,
"databaseId", databaseID,
"databaseType", database.Type,
"storageId", storage.ID,
"storageType", storage.Type,
"error", err,
"errorMessage", errMsg,
)
restore.FailMessage = &errMsg
restore.Status = restores_core.RestoreStatusFailed
restore.RestoreDurationMs = time.Since(start).Milliseconds()
if err := n.restoreRepository.Save(restore); err != nil {
n.logger.Error("Failed to save restore", "error", err)
}
return
}
restore.Status = restores_core.RestoreStatusCompleted
restore.RestoreDurationMs = time.Since(start).Milliseconds()
if err := n.restoreRepository.Save(restore); err != nil {
n.logger.Error("Failed to save restore", "error", err)
return
}
n.logger.Info(
"Restore completed successfully",
"restoreId", restore.ID,
"backupId", backup.ID,
"durationMs", restore.RestoreDurationMs,
)
}
func (n *RestorerNode) sendHeartbeat(restoreNode *RestoreNode) {
n.lastHeartbeat = time.Now().UTC()
if err := n.restoreNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *restoreNode); err != nil {
n.logger.Error("Failed to send heartbeat", "error", err)
}
}


@@ -0,0 +1,164 @@
package restoring
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/notifiers"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
cache_utils "databasus-backend/internal/util/cache"
)
func Test_MakeRestore_WhenCacheMissed_RestoreFails(t *testing.T) {
cache_utils.ClearAllCache()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
defer func() {
backupRepo := backups_core.BackupRepository{}
backupsList, _ := backupRepo.FindByDatabaseID(database.ID)
for _, backup := range backupsList {
backupRepo.DeleteByID(backup.ID)
}
restoreRepo := restores_core.RestoreRepository{}
restoresInProgress, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
for _, restore := range restoresInProgress {
restoreRepo.DeleteByID(restore.ID)
}
restoresFailed, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusFailed)
for _, restore := range restoresFailed {
restoreRepo.DeleteByID(restore.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
cache_utils.ClearAllCache()
}()
backup := backups.CreateTestBackup(database.ID, storage.ID)
// Create restore but DON'T cache DB credentials
// Also don't set embedded DB fields to avoid schema issues
restore := &restores_core.Restore{
BackupID: backup.ID,
Status: restores_core.RestoreStatusInProgress,
}
err := restoreRepository.Save(restore)
assert.NoError(t, err)
// Create restorer and execute restore (should fail due to cache miss)
restorerNode := CreateTestRestorerNode()
restorerNode.MakeRestore(restore.ID)
// Verify restore failed with appropriate error message
updatedRestore, err := restoreRepository.FindByID(restore.ID)
assert.NoError(t, err)
assert.Equal(t, restores_core.RestoreStatusFailed, updatedRestore.Status)
assert.NotNil(t, updatedRestore.FailMessage)
assert.Contains(
t,
*updatedRestore.FailMessage,
"Database credentials expired or missing from cache",
)
}
func Test_MakeRestore_WhenTaskStarts_CacheDeletedImmediately(t *testing.T) {
cache_utils.ClearAllCache()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
defer func() {
backupRepo := backups_core.BackupRepository{}
backupsList, _ := backupRepo.FindByDatabaseID(database.ID)
for _, backup := range backupsList {
backupRepo.DeleteByID(backup.ID)
}
restoreRepo := restores_core.RestoreRepository{}
restoresInProgress, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
for _, restore := range restoresInProgress {
restoreRepo.DeleteByID(restore.ID)
}
restoresFailed, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusFailed)
for _, restore := range restoresFailed {
restoreRepo.DeleteByID(restore.ID)
}
restoresCompleted, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusCompleted)
for _, restore := range restoresCompleted {
restoreRepo.DeleteByID(restore.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
cache_utils.ClearAllCache()
}()
backup := backups.CreateTestBackup(database.ID, storage.ID)
// Create restore with cached DB credentials
// Don't set embedded DB fields in the restore model itself
restore := &restores_core.Restore{
BackupID: backup.ID,
Status: restores_core.RestoreStatusInProgress,
}
err := restoreRepository.Save(restore)
assert.NoError(t, err)
// Cache DB credentials separately
dbCache := &RestoreDatabaseCache{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Host: config.GetEnv().TestLocalhost,
Port: 5432,
Username: "test",
Password: "test",
Database: stringPtr("testdb"),
Version: "16",
},
}
restoreDatabaseCache.SetWithExpiration(restore.ID.String(), dbCache, 1*time.Hour)
// Verify cache exists before restore starts
cachedDB := restoreDatabaseCache.Get(restore.ID.String())
assert.NotNil(t, cachedDB, "Cache should exist before restore starts")
// Start restore (this will call GetAndDelete)
restorerNode := CreateTestRestorerNode()
restorerNode.MakeRestore(restore.ID)
// Verify cache was deleted immediately
cachedDBAfter := restoreDatabaseCache.Get(restore.ID.String())
assert.Nil(t, cachedDBAfter, "Cache should be deleted immediately when task starts")
}


@@ -0,0 +1,410 @@
package restoring
import (
"context"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_config "databasus-backend/internal/features/backups/config"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
cache_utils "databasus-backend/internal/util/cache"
)
const (
schedulerStartupDelay = 1 * time.Minute
schedulerTickerInterval = 1 * time.Minute
schedulerHealthcheckThreshold = 5 * time.Minute
)
type RestoresScheduler struct {
restoreRepository *restores_core.RestoreRepository
backupService *backups.BackupService
storageService *storages.StorageService
backupConfigService *backups_config.BackupConfigService
restoreNodesRegistry *RestoreNodesRegistry
lastCheckTime time.Time
logger *slog.Logger
restoreToNodeRelations map[uuid.UUID]RestoreToNodeRelation
restorerNode *RestorerNode
cacheUtil *cache_utils.CacheUtil[RestoreDatabaseCache]
completionSubscriptionID uuid.UUID
runOnce sync.Once
hasRun atomic.Bool
}
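// Run fails any restores left in progress from a previous run, subscribes to
// restore completion events, and then periodically checks for dead nodes and
// fails their restores until ctx is cancelled. In many-nodes mode it first
// waits schedulerStartupDelay so other nodes can register. Calling Run a
// second time panics.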
func (s *RestoresScheduler) Run(ctx context.Context) {
wasAlreadyRun := s.hasRun.Load()
s.runOnce.Do(func() {
s.hasRun.Store(true)
s.lastCheckTime = time.Now().UTC()
if config.GetEnv().IsManyNodesMode {
// wait for other nodes to start
time.Sleep(schedulerStartupDelay)
}
if err := s.failRestoresInProgress(); err != nil {
s.logger.Error("Failed to fail restores in progress", "error", err)
panic(err)
}
err := s.restoreNodesRegistry.SubscribeForRestoresCompletions(s.onRestoreCompleted)
if err != nil {
s.logger.Error("Failed to subscribe to restore completions", "error", err)
panic(err)
}
defer func() {
if err := s.restoreNodesRegistry.UnsubscribeForRestoresCompletions(); err != nil {
s.logger.Error("Failed to unsubscribe from restore completions", "error", err)
}
}()
if ctx.Err() != nil {
return
}
ticker := time.NewTicker(schedulerTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.checkDeadNodesAndFailRestores(); err != nil {
s.logger.Error("Failed to check dead nodes and fail restores", "error", err)
}
s.lastCheckTime = time.Now().UTC()
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
}
func (s *RestoresScheduler) IsSchedulerRunning() bool {
return s.lastCheckTime.After(time.Now().UTC().Add(-schedulerHealthcheckThreshold))
}
func (s *RestoresScheduler) failRestoresInProgress() error {
restoresInProgress, err := s.restoreRepository.FindByStatus(
restores_core.RestoreStatusInProgress,
)
if err != nil {
return err
}
for _, restore := range restoresInProgress {
failMessage := "Restore failed due to application restart"
restore.FailMessage = &failMessage
restore.Status = restores_core.RestoreStatusFailed
if err := s.restoreRepository.Save(restore); err != nil {
return err
}
}
return nil
}
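// StartRestore caches the target database credentials for one hour, picks the
// least busy node, increments that node's active-restores counter, publishes
// the assignment, and records the restore-to-node relation. If publishing
// fails, the counter is decremented again.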
func (s *RestoresScheduler) StartRestore(restoreID uuid.UUID, dbCache *RestoreDatabaseCache) error {
// If dbCache not provided, try to fetch from DB (for backward compatibility/testing)
if dbCache == nil {
restore, err := s.restoreRepository.FindByID(restoreID)
if err != nil {
s.logger.Error(
"Failed to find restore by ID",
"restoreId",
restoreID,
"error",
err,
)
return err
}
// Build the cache DTO from the restore record (embedded DB fields may be nil if they were never stored)
dbCache = &RestoreDatabaseCache{
PostgresqlDatabase: restore.PostgresqlDatabase,
MysqlDatabase: restore.MysqlDatabase,
MariadbDatabase: restore.MariadbDatabase,
MongodbDatabase: restore.MongodbDatabase,
}
}
// Cache database credentials with 1-hour expiration
s.cacheUtil.SetWithExpiration(restoreID.String(), dbCache, 1*time.Hour)
leastBusyNodeID, err := s.calculateLeastBusyNode()
if err != nil {
s.logger.Error(
"Failed to calculate least busy node",
"restoreId",
restoreID,
"error",
err,
)
return err
}
if err := s.restoreNodesRegistry.IncrementRestoresInProgress(*leastBusyNodeID); err != nil {
s.logger.Error(
"Failed to increment restores in progress",
"nodeId",
leastBusyNodeID,
"restoreId",
restoreID,
"error",
err,
)
return err
}
if err := s.restoreNodesRegistry.AssignRestoreToNode(*leastBusyNodeID, restoreID, false); err != nil {
s.logger.Error(
"Failed to submit restore",
"nodeId",
leastBusyNodeID,
"restoreId",
restoreID,
"error",
err,
)
if decrementErr := s.restoreNodesRegistry.DecrementRestoresInProgress(*leastBusyNodeID); decrementErr != nil {
s.logger.Error(
"Failed to decrement restores in progress after submit failure",
"nodeId",
leastBusyNodeID,
"error",
decrementErr,
)
}
return err
}
if relation, exists := s.restoreToNodeRelations[*leastBusyNodeID]; exists {
relation.RestoreIDs = append(relation.RestoreIDs, restoreID)
s.restoreToNodeRelations[*leastBusyNodeID] = relation
} else {
s.restoreToNodeRelations[*leastBusyNodeID] = RestoreToNodeRelation{
NodeID: *leastBusyNodeID,
RestoreIDs: []uuid.UUID{restoreID},
}
}
s.logger.Info(
"Successfully triggered restore",
"restoreId",
restoreID,
"nodeId",
leastBusyNodeID,
)
return nil
}
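// calculateLeastBusyNode picks the node with the lowest load score, where the
// score is activeRestores divided by the node's throughput in MB/s (nodes
// reporting zero throughput are heavily penalized). Illustrative numbers: a
// node with 5 active restores at 100 MB/s scores 0.05, while a node with 1
// active restore at 50 MB/s scores 0.02, so the latter is selected.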
func (s *RestoresScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
nodes, err := s.restoreNodesRegistry.GetAvailableNodes()
if err != nil {
return nil, fmt.Errorf("failed to get available nodes: %w", err)
}
if len(nodes) == 0 {
return nil, fmt.Errorf("no nodes available")
}
stats, err := s.restoreNodesRegistry.GetRestoreNodesStats()
if err != nil {
return nil, fmt.Errorf("failed to get restore nodes stats: %w", err)
}
statsMap := make(map[uuid.UUID]int)
for _, stat := range stats {
statsMap[stat.ID] = stat.ActiveRestores
}
var bestNode *RestoreNode
var bestScore float64 = -1
for i := range nodes {
node := &nodes[i]
activeRestores := statsMap[node.ID]
var score float64
if node.ThroughputMBs > 0 {
score = float64(activeRestores) / float64(node.ThroughputMBs)
} else {
score = float64(activeRestores) * 1000
}
if bestNode == nil || score < bestScore {
bestNode = node
bestScore = score
}
}
if bestNode == nil {
return nil, fmt.Errorf("no suitable nodes available")
}
return &bestNode.ID, nil
}
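// onRestoreCompleted handles a completion event: non-restore task IDs are
// ignored, the restore is removed from the node's tracked list, and the node's
// active-restores counter is decremented.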
func (s *RestoresScheduler) onRestoreCompleted(nodeID uuid.UUID, restoreID uuid.UUID) {
// Verify this task is actually a restore (registry contains multiple task types)
_, err := s.restoreRepository.FindByID(restoreID)
if err != nil {
// Not a restore task, ignore it
return
}
relation, exists := s.restoreToNodeRelations[nodeID]
if !exists {
s.logger.Warn(
"Received completion for unknown node",
"nodeId",
nodeID,
"restoreId",
restoreID,
)
return
}
newRestoreIDs := make([]uuid.UUID, 0)
found := false
for _, id := range relation.RestoreIDs {
if id == restoreID {
found = true
continue
}
newRestoreIDs = append(newRestoreIDs, id)
}
if !found {
s.logger.Warn(
"Restore not found in node's restore list",
"nodeId",
nodeID,
"restoreId",
restoreID,
)
return
}
if len(newRestoreIDs) == 0 {
delete(s.restoreToNodeRelations, nodeID)
} else {
relation.RestoreIDs = newRestoreIDs
s.restoreToNodeRelations[nodeID] = relation
}
if err := s.restoreNodesRegistry.DecrementRestoresInProgress(nodeID); err != nil {
s.logger.Error(
"Failed to decrement restores in progress",
"nodeId",
nodeID,
"restoreId",
restoreID,
"error",
err,
)
}
}
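// checkDeadNodesAndFailRestores compares the tracked restore-to-node relations
// against the currently available nodes, fails every restore assigned to a
// node that is no longer available, decrements that node's counter, and drops
// the relation.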
func (s *RestoresScheduler) checkDeadNodesAndFailRestores() error {
nodes, err := s.restoreNodesRegistry.GetAvailableNodes()
if err != nil {
return fmt.Errorf("failed to get available nodes: %w", err)
}
aliveNodeIDs := make(map[uuid.UUID]bool)
for _, node := range nodes {
aliveNodeIDs[node.ID] = true
}
for nodeID, relation := range s.restoreToNodeRelations {
if aliveNodeIDs[nodeID] {
continue
}
s.logger.Warn(
"Node is dead, failing its restores",
"nodeId",
nodeID,
"restoreCount",
len(relation.RestoreIDs),
)
for _, restoreID := range relation.RestoreIDs {
restore, err := s.restoreRepository.FindByID(restoreID)
if err != nil {
s.logger.Error(
"Failed to find restore for dead node",
"nodeId",
nodeID,
"restoreId",
restoreID,
"error",
err,
)
continue
}
failMessage := "Restore failed due to node unavailability"
restore.FailMessage = &failMessage
restore.Status = restores_core.RestoreStatusFailed
if err := s.restoreRepository.Save(restore); err != nil {
s.logger.Error(
"Failed to save failed restore for dead node",
"nodeId",
nodeID,
"restoreId",
restoreID,
"error",
err,
)
continue
}
if err := s.restoreNodesRegistry.DecrementRestoresInProgress(nodeID); err != nil {
s.logger.Error(
"Failed to decrement restores in progress for dead node",
"nodeId",
nodeID,
"restoreId",
restoreID,
"error",
err,
)
}
s.logger.Info(
"Failed restore due to dead node",
"nodeId",
nodeID,
"restoreId",
restoreID,
)
}
delete(s.restoreToNodeRelations, nodeID)
}
return nil
}


@@ -0,0 +1,856 @@
package restoring
import (
"testing"
"time"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/notifiers"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/encryption"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_CheckDeadNodesAndFailRestores_NodeDies_FailsRestoreAndCleansUpRegistry(t *testing.T) {
cache_utils.ClearAllCache()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
var mockNodeID uuid.UUID
defer func() {
backupRepo := backups_core.BackupRepository{}
backups, _ := backupRepo.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepo.DeleteByID(backup.ID)
}
restoreRepo := restores_core.RestoreRepository{}
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
for _, restore := range restores {
restoreRepo.DeleteByID(restore.ID)
}
restores, _ = restoreRepo.FindByStatus(restores_core.RestoreStatusFailed)
for _, restore := range restores {
restoreRepo.DeleteByID(restore.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
// Clean up mock node
if mockNodeID != uuid.Nil {
restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: mockNodeID})
}
cache_utils.ClearAllCache()
}()
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
// Create a test backup
backup := backups.CreateTestBackup(database.ID, storage.ID)
var err error
// Register mock node without subscribing to restores (simulates node crash after registration)
mockNodeID = uuid.New()
err = CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
assert.NoError(t, err)
// Create restore and assign to mock node
restore := &restores_core.Restore{
BackupID: backup.ID,
Status: restores_core.RestoreStatusInProgress,
}
err = restoreRepository.Save(restore)
assert.NoError(t, err)
// Scheduler assigns restore to mock node
err = GetRestoresScheduler().StartRestore(restore.ID, nil)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
// Verify Valkey counter was incremented when restore was assigned
stats, err := restoreNodesRegistry.GetRestoreNodesStats()
assert.NoError(t, err)
foundStat := false
for _, stat := range stats {
if stat.ID == mockNodeID {
assert.Equal(t, 1, stat.ActiveRestores)
foundStat = true
break
}
}
assert.True(t, foundStat, "Node stats should be present")
// Simulate node death by setting heartbeat older than 2-minute threshold
oldHeartbeat := time.Now().UTC().Add(-3 * time.Minute)
err = UpdateNodeHeartbeatDirectly(mockNodeID, 100, oldHeartbeat)
assert.NoError(t, err)
// Trigger dead node detection
err = GetRestoresScheduler().checkDeadNodesAndFailRestores()
assert.NoError(t, err)
// Verify restore was failed with appropriate error message
failedRestore, err := restoreRepository.FindByID(restore.ID)
assert.NoError(t, err)
assert.Equal(t, restores_core.RestoreStatusFailed, failedRestore.Status)
assert.NotNil(t, failedRestore.FailMessage)
assert.Contains(t, *failedRestore.FailMessage, "node unavailability")
// Verify Valkey counter was decremented after restore failed
stats, err = restoreNodesRegistry.GetRestoreNodesStats()
assert.NoError(t, err)
for _, stat := range stats {
if stat.ID == mockNodeID {
assert.Equal(t, 0, stat.ActiveRestores)
}
}
time.Sleep(200 * time.Millisecond)
}
func Test_OnRestoreCompleted_TaskIsNotRestore_SkipsProcessing(t *testing.T) {
cache_utils.ClearAllCache()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
var mockNodeID uuid.UUID
defer func() {
backupRepo := backups_core.BackupRepository{}
backups, _ := backupRepo.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepo.DeleteByID(backup.ID)
}
restoreRepo := restores_core.RestoreRepository{}
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
for _, restore := range restores {
restoreRepo.DeleteByID(restore.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
// Clean up mock node
if mockNodeID != uuid.Nil {
restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: mockNodeID})
}
cache_utils.ClearAllCache()
}()
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
// Create a test backup
backup := backups.CreateTestBackup(database.ID, storage.ID)
// Register mock node
mockNodeID = uuid.New()
err := CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
assert.NoError(t, err)
// Create restore and assign to the node
restore := &restores_core.Restore{
BackupID: backup.ID,
Status: restores_core.RestoreStatusInProgress,
}
err = restoreRepository.Save(restore)
assert.NoError(t, err)
err = GetRestoresScheduler().StartRestore(restore.ID, nil)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
// Get initial state of the registry
initialStats, err := restoreNodesRegistry.GetRestoreNodesStats()
assert.NoError(t, err)
var initialActiveTasks int
for _, stat := range initialStats {
if stat.ID == mockNodeID {
initialActiveTasks = stat.ActiveRestores
break
}
}
assert.Equal(t, 1, initialActiveTasks, "Should have 1 active task")
// Call onRestoreCompleted with a random UUID (not a restore ID)
nonRestoreTaskID := uuid.New()
GetRestoresScheduler().onRestoreCompleted(mockNodeID, nonRestoreTaskID)
time.Sleep(100 * time.Millisecond)
// Verify: Active tasks counter should remain the same (not decremented)
stats, err := restoreNodesRegistry.GetRestoreNodesStats()
assert.NoError(t, err)
for _, stat := range stats {
if stat.ID == mockNodeID {
assert.Equal(t, initialActiveTasks, stat.ActiveRestores,
"Active tasks should not change for non-restore task")
}
}
// Verify: restore should still be in progress (not modified)
unchangedRestore, err := restoreRepository.FindByID(restore.ID)
assert.NoError(t, err)
assert.Equal(t, restores_core.RestoreStatusInProgress, unchangedRestore.Status,
"Restore status should not change for non-restore task completion")
// Verify: restoreToNodeRelations should still contain the node
scheduler := GetRestoresScheduler()
_, exists := scheduler.restoreToNodeRelations[mockNodeID]
assert.True(t, exists, "Node should still be in restoreToNodeRelations")
time.Sleep(200 * time.Millisecond)
}
func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
t.Run("Nodes with same throughput", func(t *testing.T) {
cache_utils.ClearAllCache()
node1ID := uuid.New()
node2ID := uuid.New()
node3ID := uuid.New()
now := time.Now().UTC()
defer func() {
// Clean up all mock nodes
restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node1ID})
restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node2ID})
restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node3ID})
cache_utils.ClearAllCache()
}()
err := CreateMockNodeInRegistry(node1ID, 100, now)
assert.NoError(t, err)
err = CreateMockNodeInRegistry(node2ID, 100, now)
assert.NoError(t, err)
err = CreateMockNodeInRegistry(node3ID, 100, now)
assert.NoError(t, err)
for range 5 {
err = restoreNodesRegistry.IncrementRestoresInProgress(node1ID)
assert.NoError(t, err)
}
for range 2 {
err = restoreNodesRegistry.IncrementRestoresInProgress(node2ID)
assert.NoError(t, err)
}
for range 8 {
err = restoreNodesRegistry.IncrementRestoresInProgress(node3ID)
assert.NoError(t, err)
}
leastBusyNodeID, err := GetRestoresScheduler().calculateLeastBusyNode()
assert.NoError(t, err)
assert.NotNil(t, leastBusyNodeID)
assert.Equal(t, node2ID, *leastBusyNodeID)
})
t.Run("Nodes with different throughput", func(t *testing.T) {
cache_utils.ClearAllCache()
node100MBsID := uuid.New()
node50MBsID := uuid.New()
now := time.Now().UTC()
defer func() {
// Clean up all mock nodes
restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node100MBsID})
restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node50MBsID})
cache_utils.ClearAllCache()
}()
err := CreateMockNodeInRegistry(node100MBsID, 100, now)
assert.NoError(t, err)
err = CreateMockNodeInRegistry(node50MBsID, 50, now)
assert.NoError(t, err)
for range 10 {
err = restoreNodesRegistry.IncrementRestoresInProgress(node100MBsID)
assert.NoError(t, err)
}
err = restoreNodesRegistry.IncrementRestoresInProgress(node50MBsID)
assert.NoError(t, err)
leastBusyNodeID, err := GetRestoresScheduler().calculateLeastBusyNode()
assert.NoError(t, err)
assert.NotNil(t, leastBusyNodeID)
assert.Equal(t, node50MBsID, *leastBusyNodeID)
})
}
func Test_FailRestoresInProgress_SchedulerStarts_UpdatesStatus(t *testing.T) {
cache_utils.ClearAllCache()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backupRepo := backups_core.BackupRepository{}
backups, _ := backupRepo.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepo.DeleteByID(backup.ID)
}
restoreRepo := restores_core.RestoreRepository{}
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
for _, restore := range restores {
restoreRepo.DeleteByID(restore.ID)
}
restores, _ = restoreRepo.FindByStatus(restores_core.RestoreStatusFailed)
for _, restore := range restores {
restoreRepo.DeleteByID(restore.ID)
}
restores, _ = restoreRepo.FindByStatus(restores_core.RestoreStatusCompleted)
for _, restore := range restores {
restoreRepo.DeleteByID(restore.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
cache_utils.ClearAllCache()
}()
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
// Create a test backup
backup := backups.CreateTestBackup(database.ID, storage.ID)
// Create two in-progress restores that should be failed on scheduler restart
restore1 := &restores_core.Restore{
BackupID: backup.ID,
Status: restores_core.RestoreStatusInProgress,
CreatedAt: time.Now().UTC().Add(-30 * time.Minute),
}
err := restoreRepository.Save(restore1)
assert.NoError(t, err)
restore2 := &restores_core.Restore{
BackupID: backup.ID,
Status: restores_core.RestoreStatusInProgress,
CreatedAt: time.Now().UTC().Add(-15 * time.Minute),
}
err = restoreRepository.Save(restore2)
assert.NoError(t, err)
// Create a completed restore to verify it's not affected by failRestoresInProgress
completedRestore := &restores_core.Restore{
BackupID: backup.ID,
Status: restores_core.RestoreStatusCompleted,
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
}
err = restoreRepository.Save(completedRestore)
assert.NoError(t, err)
// Trigger the scheduler's failRestoresInProgress logic
// This should mark in-progress restores as failed
err = GetRestoresScheduler().failRestoresInProgress()
assert.NoError(t, err)
// Verify all restores exist and were processed correctly
allRestores1, err := restoreRepository.FindByID(restore1.ID)
assert.NoError(t, err)
allRestores2, err := restoreRepository.FindByID(restore2.ID)
assert.NoError(t, err)
allRestores3, err := restoreRepository.FindByID(completedRestore.ID)
assert.NoError(t, err)
var failedCount int
var completedCount int
restoresToCheck := []*restores_core.Restore{allRestores1, allRestores2, allRestores3}
for _, restore := range restoresToCheck {
switch restore.Status {
case restores_core.RestoreStatusFailed:
failedCount++
// Verify fail message indicates application restart
assert.NotNil(t, restore.FailMessage)
assert.Equal(t, "Restore failed due to application restart", *restore.FailMessage)
case restores_core.RestoreStatusCompleted:
completedCount++
}
}
// Verify correct number of restores in each state
assert.Equal(t, 2, failedCount, "Should have 2 failed restores (originally in progress)")
assert.Equal(t, 1, completedCount, "Should have 1 completed restore (unchanged)")
time.Sleep(200 * time.Millisecond)
}
func Test_StartRestore_RestoreCompletes_DecrementsActiveTaskCount(t *testing.T) {
cache_utils.ClearAllCache()
// Start scheduler so it can handle task completions
scheduler := CreateTestRestoresScheduler()
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
restorerNode := CreateTestRestorerNode()
restorerNode.restoreBackupUsecase = &MockSuccessRestoreUsecase{}
cancel := StartRestorerNodeForTest(t, restorerNode)
defer StopRestorerNodeForTest(t, cancel, restorerNode)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backupRepo := backups_core.BackupRepository{}
backups, _ := backupRepo.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepo.DeleteByID(backup.ID)
}
restoreRepo := restores_core.RestoreRepository{}
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusCompleted)
for _, restore := range restores {
restoreRepo.DeleteByID(restore.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
// Create a test backup
backup := backups.CreateTestBackup(database.ID, storage.ID)
// Get initial active task count
stats, err := restoreNodesRegistry.GetRestoreNodesStats()
assert.NoError(t, err)
var initialActiveTasks int
for _, stat := range stats {
if stat.ID == restorerNode.nodeID {
initialActiveTasks = stat.ActiveRestores
break
}
}
t.Logf("Initial active tasks: %d", initialActiveTasks)
// Create and start restore
restore := &restores_core.Restore{
BackupID: backup.ID,
Status: restores_core.RestoreStatusInProgress,
}
err = restoreRepository.Save(restore)
assert.NoError(t, err)
err = scheduler.StartRestore(restore.ID, nil)
assert.NoError(t, err)
// Wait for restore to complete
WaitForRestoreCompletion(t, restore.ID, 10*time.Second)
// Verify restore was completed
completedRestore, err := restoreRepository.FindByID(restore.ID)
assert.NoError(t, err)
assert.Equal(t, restores_core.RestoreStatusCompleted, completedRestore.Status)
// Wait for active task count to decrease
decreased := WaitForActiveTasksDecrease(
t,
restorerNode.nodeID,
initialActiveTasks+1,
10*time.Second,
)
assert.True(t, decreased, "Active task count should have decreased after restore completion")
// Verify final active task count equals initial count
finalStats, err := restoreNodesRegistry.GetRestoreNodesStats()
assert.NoError(t, err)
for _, stat := range finalStats {
if stat.ID == restorerNode.nodeID {
t.Logf("Final active tasks: %d", stat.ActiveRestores)
assert.Equal(t, initialActiveTasks, stat.ActiveRestores,
"Active task count should return to initial value after restore completion")
break
}
}
time.Sleep(200 * time.Millisecond)
}
func Test_StartRestore_RestoreFails_DecrementsActiveTaskCount(t *testing.T) {
cache_utils.ClearAllCache()
// Start scheduler so it can handle task completions
scheduler := CreateTestRestoresScheduler()
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
restorerNode := CreateTestRestorerNode()
restorerNode.restoreBackupUsecase = &MockFailedRestoreUsecase{}
cancel := StartRestorerNodeForTest(t, restorerNode)
defer StopRestorerNodeForTest(t, cancel, restorerNode)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backupRepo := backups_core.BackupRepository{}
backups, _ := backupRepo.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepo.DeleteByID(backup.ID)
}
restoreRepo := restores_core.RestoreRepository{}
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusFailed)
for _, restore := range restores {
restoreRepo.DeleteByID(restore.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
// Create a test backup
backup := backups.CreateTestBackup(database.ID, storage.ID)
// Get initial active task count
stats, err := restoreNodesRegistry.GetRestoreNodesStats()
assert.NoError(t, err)
var initialActiveTasks int
for _, stat := range stats {
if stat.ID == restorerNode.nodeID {
initialActiveTasks = stat.ActiveRestores
break
}
}
t.Logf("Initial active tasks: %d", initialActiveTasks)
// Create and start restore
restore := &restores_core.Restore{
BackupID: backup.ID,
Status: restores_core.RestoreStatusInProgress,
}
err = restoreRepository.Save(restore)
assert.NoError(t, err)
err = scheduler.StartRestore(restore.ID, nil)
assert.NoError(t, err)
// Wait for restore to fail
WaitForRestoreCompletion(t, restore.ID, 10*time.Second)
// Verify restore failed
failedRestore, err := restoreRepository.FindByID(restore.ID)
assert.NoError(t, err)
assert.Equal(t, restores_core.RestoreStatusFailed, failedRestore.Status)
// Wait for active task count to decrease
decreased := WaitForActiveTasksDecrease(
t,
restorerNode.nodeID,
initialActiveTasks+1,
10*time.Second,
)
assert.True(t, decreased, "Active task count should have decreased after restore failure")
// Verify final active task count equals initial count
finalStats, err := restoreNodesRegistry.GetRestoreNodesStats()
assert.NoError(t, err)
for _, stat := range finalStats {
if stat.ID == restorerNode.nodeID {
t.Logf("Final active tasks: %d", stat.ActiveRestores)
assert.Equal(t, initialActiveTasks, stat.ActiveRestores,
"Active task count should return to initial value after restore failure")
break
}
}
time.Sleep(200 * time.Millisecond)
}
func Test_StartRestore_CredentialsStoredEncryptedInCache(t *testing.T) {
cache_utils.ClearAllCache()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
var mockNodeID uuid.UUID
defer func() {
backupRepo := backups_core.BackupRepository{}
backups, _ := backupRepo.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepo.DeleteByID(backup.ID)
}
restoreRepo := restores_core.RestoreRepository{}
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
for _, restore := range restores {
restoreRepo.DeleteByID(restore.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
// Clean up mock node
if mockNodeID != uuid.Nil {
restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: mockNodeID})
}
cache_utils.ClearAllCache()
}()
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
// Create a test backup
backup := backups.CreateTestBackup(database.ID, storage.ID)
// Register mock node so scheduler can assign restore to it
mockNodeID = uuid.New()
err := CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
assert.NoError(t, err)
// Create restore with plaintext credentials
plaintextPassword := "test_password_123"
restore := &restores_core.Restore{
BackupID: backup.ID,
Status: restores_core.RestoreStatusInProgress,
}
err = restoreRepository.Save(restore)
assert.NoError(t, err)
// Create PostgreSQL database credentials with plaintext password
postgresDB := &postgresql.PostgresqlDatabase{
Host: config.GetEnv().TestLocalhost,
Port: 5432,
Username: "testuser",
Password: plaintextPassword,
Database: stringPtr("testdb"),
Version: "16",
}
// Encrypt password using FieldEncryptor (same as production flow)
encryptor := encryption.GetFieldEncryptor()
err = postgresDB.EncryptSensitiveFields(database.ID, encryptor)
assert.NoError(t, err)
// Verify password was encrypted (different from plaintext)
assert.NotEqual(t, plaintextPassword, postgresDB.Password,
"Password should be encrypted, not plaintext")
// Create cache with encrypted credentials
dbCache := &RestoreDatabaseCache{
PostgresqlDatabase: postgresDB,
}
// Call StartRestore to cache credentials (do NOT start restore node)
err = GetRestoresScheduler().StartRestore(restore.ID, dbCache)
assert.NoError(t, err)
// Directly read from cache
cachedData := restoreDatabaseCache.Get(restore.ID.String())
assert.NotNil(t, cachedData, "Cache entry should exist")
assert.NotNil(t, cachedData.PostgresqlDatabase, "PostgreSQL credentials should be cached")
// Verify password in cache is encrypted (not plaintext)
assert.NotEqual(t, plaintextPassword, cachedData.PostgresqlDatabase.Password,
"Cached password should be encrypted, not plaintext")
assert.Equal(t, postgresDB.Password, cachedData.PostgresqlDatabase.Password,
"Cached password should match the encrypted version")
// Verify other fields are present
assert.Equal(t, config.GetEnv().TestLocalhost, cachedData.PostgresqlDatabase.Host)
assert.Equal(t, 5432, cachedData.PostgresqlDatabase.Port)
assert.Equal(t, "testuser", cachedData.PostgresqlDatabase.Username)
assert.Equal(t, "testdb", *cachedData.PostgresqlDatabase.Database)
time.Sleep(200 * time.Millisecond)
}
func Test_StartRestore_CredentialsRemovedAfterRestoreStarts(t *testing.T) {
cache_utils.ClearAllCache()
// Start scheduler so it can handle task assignments
scheduler := CreateTestRestoresScheduler()
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
// Create mock restorer node with credential capture usecase
restorerNode := CreateTestRestorerNode()
calledChan := make(chan *databases.Database, 1)
restorerNode.restoreBackupUsecase = &MockCaptureCredentialsRestoreUsecase{
CalledChan: calledChan,
ShouldSucceed: true,
}
cancel := StartRestorerNodeForTest(t, restorerNode)
defer StopRestorerNodeForTest(t, cancel, restorerNode)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backupRepo := backups_core.BackupRepository{}
backups, _ := backupRepo.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepo.DeleteByID(backup.ID)
}
restoreRepo := restores_core.RestoreRepository{}
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusCompleted)
for _, restore := range restores {
restoreRepo.DeleteByID(restore.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
cache_utils.ClearAllCache()
}()
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
// Create a test backup
backup := backups.CreateTestBackup(database.ID, storage.ID)
// Create restore with credentials
plaintextPassword := "test_password_456"
restore := &restores_core.Restore{
BackupID: backup.ID,
Status: restores_core.RestoreStatusInProgress,
}
err := restoreRepository.Save(restore)
assert.NoError(t, err)
// Create PostgreSQL database credentials
// Database field is nil to avoid PopulateDbData trying to connect
postgresDB := &postgresql.PostgresqlDatabase{
Host: config.GetEnv().TestLocalhost,
Port: 5432,
Username: "testuser",
Password: plaintextPassword,
Database: nil,
Version: "16",
}
// Encrypt password (same as production flow)
encryptor := encryption.GetFieldEncryptor()
err = postgresDB.EncryptSensitiveFields(database.ID, encryptor)
assert.NoError(t, err)
encryptedPassword := postgresDB.Password
// Create cache with encrypted credentials
dbCache := &RestoreDatabaseCache{
PostgresqlDatabase: postgresDB,
}
// Call StartRestore to cache credentials and trigger restore
err = scheduler.StartRestore(restore.ID, dbCache)
assert.NoError(t, err)
// Wait for mock usecase to be called (with timeout)
var capturedDB *databases.Database
select {
case capturedDB = <-calledChan:
t.Log("Mock usecase was called, credentials captured")
case <-time.After(10 * time.Second):
t.Fatal("Timeout waiting for mock usecase to be called")
}
// Verify cache is empty after restore starts (credentials were deleted)
cacheAfterExecution := restoreDatabaseCache.Get(restore.ID.String())
assert.Nil(t, cacheAfterExecution, "Cache should be empty after restore execution starts")
// Verify mock received valid credentials
assert.NotNil(t, capturedDB, "Captured database should not be nil")
assert.NotNil(t, capturedDB.Postgresql, "PostgreSQL credentials should be provided to usecase")
assert.Equal(t, config.GetEnv().TestLocalhost, capturedDB.Postgresql.Host)
assert.Equal(t, 5432, capturedDB.Postgresql.Port)
assert.Equal(t, "testuser", capturedDB.Postgresql.Username)
assert.NotEmpty(t, capturedDB.Postgresql.Password, "Password should be provided to usecase")
// Note: Password at this point may still be encrypted because PopulateDbData
// is called after the mock captures it. The important thing is that credentials
// were provided to the usecase despite cache being deleted.
t.Logf("Encrypted password in cache: %s", encryptedPassword)
t.Logf("Password received by usecase: %s", capturedDB.Postgresql.Password)
// Wait for restore to complete
WaitForRestoreCompletion(t, restore.ID, 10*time.Second)
// Verify restore was completed
completedRestore, err := restoreRepository.FindByID(restore.ID)
assert.NoError(t, err)
assert.Equal(t, restores_core.RestoreStatusCompleted, completedRestore.Status)
time.Sleep(200 * time.Millisecond)
}


@@ -0,0 +1,342 @@
package restoring
import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/postgresql"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/restores/usecases"
"databasus-backend/internal/features/storages"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
)
func CreateTestRouter() *gin.Engine {
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
databases.GetDatabaseController(),
backups_config.GetBackupConfigController(),
)
return router
}
func CreateTestRestorerNode() *RestorerNode {
return &RestorerNode{
nodeID: uuid.New(),
databaseService: databases.GetDatabaseService(),
backupService: backups.GetBackupService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
restoreRepository: restoreRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
restoreNodesRegistry: restoreNodesRegistry,
logger: logger.GetLogger(),
restoreBackupUsecase: usecases.GetRestoreBackupUsecase(),
cacheUtil: restoreDatabaseCache,
restoreCancelManager: tasks_cancellation.GetTaskCancelManager(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}
func CreateTestRestorerNodeWithUsecase(usecase restores_core.RestoreBackupUsecase) *RestorerNode {
return &RestorerNode{
nodeID: uuid.New(),
databaseService: databases.GetDatabaseService(),
backupService: backups.GetBackupService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
restoreRepository: restoreRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
restoreNodesRegistry: restoreNodesRegistry,
logger: logger.GetLogger(),
restoreBackupUsecase: usecase,
cacheUtil: restoreDatabaseCache,
restoreCancelManager: tasks_cancellation.GetTaskCancelManager(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}
func CreateTestRestoresScheduler() *RestoresScheduler {
return &RestoresScheduler{
restoreRepository,
backups.GetBackupService(),
storages.GetStorageService(),
backups_config.GetBackupConfigService(),
restoreNodesRegistry,
time.Now().UTC(),
logger.GetLogger(),
make(map[uuid.UUID]RestoreToNodeRelation),
restorerNode,
restoreDatabaseCache,
uuid.Nil,
sync.Once{},
atomic.Bool{},
}
}
// WaitForRestoreCompletion waits for a restore to be completed (or failed)
func WaitForRestoreCompletion(
t *testing.T,
restoreID uuid.UUID,
timeout time.Duration,
) {
deadline := time.Now().UTC().Add(timeout)
for time.Now().UTC().Before(deadline) {
restore, err := restoreRepository.FindByID(restoreID)
if err != nil {
t.Logf("WaitForRestoreCompletion: error finding restore: %v", err)
time.Sleep(50 * time.Millisecond)
continue
}
t.Logf("WaitForRestoreCompletion: restore status: %s", restore.Status)
if restore.Status == restores_core.RestoreStatusCompleted ||
restore.Status == restores_core.RestoreStatusFailed {
t.Logf(
"WaitForRestoreCompletion: restore finished with status %s",
restore.Status,
)
return
}
time.Sleep(50 * time.Millisecond)
}
t.Logf("WaitForRestoreCompletion: timeout waiting for restore to complete")
}
// StartRestorerNodeForTest starts a RestorerNode in a goroutine for testing.
// The node registers itself in the registry and subscribes to restore assignments.
// Returns a context cancel function that should be deferred to stop the node.
func StartRestorerNodeForTest(t *testing.T, restorerNode *RestorerNode) context.CancelFunc {
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
restorerNode.Run(ctx)
close(done)
}()
// Poll registry for node presence instead of fixed sleep
deadline := time.Now().UTC().Add(5 * time.Second)
for time.Now().UTC().Before(deadline) {
nodes, err := restoreNodesRegistry.GetAvailableNodes()
if err == nil {
for _, node := range nodes {
if node.ID == restorerNode.nodeID {
t.Logf("RestorerNode registered in registry: %s", restorerNode.nodeID)
return func() {
cancel()
select {
case <-done:
t.Log("RestorerNode stopped gracefully")
case <-time.After(2 * time.Second):
t.Log("RestorerNode stop timeout")
}
}
}
}
}
time.Sleep(50 * time.Millisecond)
}
t.Fatalf("RestorerNode failed to register in registry within timeout")
return nil
}
// StartSchedulerForTest starts the RestoresScheduler in a goroutine for testing.
// The scheduler subscribes to task completions and manages restore lifecycle.
// Returns a context cancel function that should be deferred to stop the scheduler.
func StartSchedulerForTest(t *testing.T, scheduler *RestoresScheduler) context.CancelFunc {
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
scheduler.Run(ctx)
close(done)
}()
// Give scheduler time to subscribe to completions
time.Sleep(100 * time.Millisecond)
t.Log("RestoresScheduler started")
return func() {
cancel()
select {
case <-done:
t.Log("RestoresScheduler stopped gracefully")
case <-time.After(2 * time.Second):
t.Log("RestoresScheduler stop timeout")
}
}
}
// StopRestorerNodeForTest stops the RestorerNode by canceling its context.
// It waits for the node to unregister from the registry.
func StopRestorerNodeForTest(t *testing.T, cancel context.CancelFunc, restorerNode *RestorerNode) {
cancel()
// Wait for node to unregister from registry
deadline := time.Now().UTC().Add(2 * time.Second)
for time.Now().UTC().Before(deadline) {
nodes, err := restoreNodesRegistry.GetAvailableNodes()
if err == nil {
found := false
for _, node := range nodes {
if node.ID == restorerNode.nodeID {
found = true
break
}
}
if !found {
t.Logf("RestorerNode unregistered from registry: %s", restorerNode.nodeID)
return
}
}
time.Sleep(50 * time.Millisecond)
}
t.Logf("RestorerNode stop completed for %s", restorerNode.nodeID)
}
func CreateMockNodeInRegistry(nodeID uuid.UUID, throughputMBs int, lastHeartbeat time.Time) error {
restoreNode := RestoreNode{
ID: nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: lastHeartbeat,
}
return restoreNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, restoreNode)
}
func UpdateNodeHeartbeatDirectly(
nodeID uuid.UUID,
throughputMBs int,
lastHeartbeat time.Time,
) error {
restoreNode := RestoreNode{
ID: nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: lastHeartbeat,
}
return restoreNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, restoreNode)
}
func GetNodeFromRegistry(nodeID uuid.UUID) (*RestoreNode, error) {
nodes, err := restoreNodesRegistry.GetAvailableNodes()
if err != nil {
return nil, err
}
for _, node := range nodes {
if node.ID == nodeID {
return &node, nil
}
}
return nil, fmt.Errorf("node not found")
}
// WaitForActiveTasksDecrease waits for the active task count to decrease below the initial count.
// It polls the registry every 500ms until the count decreases or the timeout is reached.
// Returns true if the count decreased, false if timeout was reached.
func WaitForActiveTasksDecrease(
t *testing.T,
nodeID uuid.UUID,
initialCount int,
timeout time.Duration,
) bool {
deadline := time.Now().UTC().Add(timeout)
for time.Now().UTC().Before(deadline) {
stats, err := restoreNodesRegistry.GetRestoreNodesStats()
if err != nil {
t.Logf("WaitForActiveTasksDecrease: error getting node stats: %v", err)
time.Sleep(500 * time.Millisecond)
continue
}
for _, stat := range stats {
if stat.ID == nodeID {
t.Logf(
"WaitForActiveTasksDecrease: current active tasks = %d (initial = %d)",
stat.ActiveRestores,
initialCount,
)
if stat.ActiveRestores < initialCount {
t.Logf(
"WaitForActiveTasksDecrease: active tasks decreased from %d to %d",
initialCount,
stat.ActiveRestores,
)
return true
}
break
}
}
time.Sleep(500 * time.Millisecond)
}
t.Logf("WaitForActiveTasksDecrease: timeout waiting for active tasks to decrease")
return false
}
// CreateTestRestore creates a test restore with the given backup and status
func CreateTestRestore(
t *testing.T,
backup *backups_core.Backup,
status restores_core.RestoreStatus,
) *restores_core.Restore {
restore := &restores_core.Restore{
BackupID: backup.ID,
Status: status,
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Host: config.GetEnv().TestLocalhost,
Port: 5432,
Username: "test",
Password: "test",
Database: stringPtr("testdb"),
Version: "16",
},
}
err := restoreRepository.Save(restore)
if err != nil {
t.Fatalf("Failed to create test restore: %v", err)
}
return restore
}
func stringPtr(s string) *string {
return &s
}
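
Taken together, these helpers are meant to be composed inside an integration test: start a restorer node and the scheduler, persist a restore, hand it to the scheduler, then poll for completion. A minimal sketch of that flow, assuming a completed backup is available from an existing backup test helper (createCompletedTestBackup below is hypothetical, not part of this diff):

package restoring

import (
	"testing"
	"time"

	restores_core "databasus-backend/internal/features/restores/core"
)

func Test_RestoreLifecycle_Sketch(t *testing.T) {
	// Start a restorer node and the scheduler; stop both when the test ends.
	stopNode := StartRestorerNodeForTest(t, CreateTestRestorerNode())
	defer stopNode()
	scheduler := CreateTestRestoresScheduler()
	stopScheduler := StartSchedulerForTest(t, scheduler)
	defer stopScheduler()

	// createCompletedTestBackup is a hypothetical helper; any completed backup works here.
	backup := createCompletedTestBackup(t)

	// Persist an in-progress restore, hand it to the scheduler, then poll until done.
	restore := CreateTestRestore(t, backup, restores_core.RestoreStatusInProgress)
	if err := scheduler.StartRestore(restore.ID, &RestoreDatabaseCache{
		PostgresqlDatabase: restore.PostgresqlDatabase,
	}); err != nil {
		t.Fatalf("failed to schedule restore: %v", err)
	}
	WaitForRestoreCompletion(t, restore.ID, 30*time.Second)
}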

View File

@@ -7,10 +7,11 @@ import (
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/disk"
"databasus-backend/internal/features/restores/enums"
"databasus-backend/internal/features/restores/models"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/restores/restoring"
"databasus-backend/internal/features/restores/usecases"
"databasus-backend/internal/features/storages"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
users_models "databasus-backend/internal/features/users/models"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/encryption"
@@ -25,7 +26,7 @@ import (
type RestoreService struct {
backupService *backups.BackupService
restoreRepository *RestoreRepository
restoreRepository *restores_core.RestoreRepository
storageService *storages.StorageService
backupConfigService *backups_config.BackupConfigService
restoreBackupUsecase *usecases.RestoreBackupUsecase
@@ -35,6 +36,7 @@ type RestoreService struct {
auditLogService *audit_logs.AuditLogService
fieldEncryptor encryption.FieldEncryptor
diskService *disk.DiskService
taskCancelManager *tasks_cancellation.TaskCancelManager
}
func (s *RestoreService) OnBeforeBackupRemove(backup *backups_core.Backup) error {
@@ -44,7 +46,7 @@ func (s *RestoreService) OnBeforeBackupRemove(backup *backups_core.Backup) error
}
for _, restore := range restores {
if restore.Status == enums.RestoreStatusInProgress {
if restore.Status == restores_core.RestoreStatusInProgress {
return errors.New("restore is in progress, backup cannot be removed")
}
}
@@ -61,7 +63,7 @@ func (s *RestoreService) OnBeforeBackupRemove(backup *backups_core.Backup) error
func (s *RestoreService) GetRestores(
user *users_models.User,
backupID uuid.UUID,
) ([]*models.Restore, error) {
) ([]*restores_core.Restore, error) {
backup, err := s.backupService.GetBackup(backupID)
if err != nil {
return nil, err
@@ -93,7 +95,7 @@ func (s *RestoreService) GetRestores(
func (s *RestoreService) RestoreBackupWithAuth(
user *users_models.User,
backupID uuid.UUID,
requestDTO RestoreBackupRequest,
requestDTO restores_core.RestoreBackupRequest,
) error {
backup, err := s.backupService.GetBackup(backupID)
if err != nil {
@@ -134,11 +136,50 @@ func (s *RestoreService) RestoreBackupWithAuth(
return err
}
go func() {
if err := s.RestoreBackup(backup, requestDTO); err != nil {
s.logger.Error("Failed to restore backup", "error", err)
// Validate no parallel restores for the same database
if err := s.validateNoParallelRestores(backup.DatabaseID); err != nil {
return err
}
// Create restore record with the request configuration
restore := restores_core.Restore{
ID: uuid.New(),
Status: restores_core.RestoreStatusInProgress,
BackupID: backup.ID,
Backup: backup,
CreatedAt: time.Now().UTC(),
RestoreDurationMs: 0,
FailMessage: nil,
PostgresqlDatabase: requestDTO.PostgresqlDatabase,
MysqlDatabase: requestDTO.MysqlDatabase,
MariadbDatabase: requestDTO.MariadbDatabase,
MongodbDatabase: requestDTO.MongodbDatabase,
}
if err := s.restoreRepository.Save(&restore); err != nil {
return err
}
// Prepare database cache with credentials from the request
dbCache := &restoring.RestoreDatabaseCache{
PostgresqlDatabase: requestDTO.PostgresqlDatabase,
MysqlDatabase: requestDTO.MysqlDatabase,
MariadbDatabase: requestDTO.MariadbDatabase,
MongodbDatabase: requestDTO.MongodbDatabase,
}
// Trigger restore via scheduler
scheduler := restoring.GetRestoresScheduler()
if err := scheduler.StartRestore(restore.ID, dbCache); err != nil {
// Mark restore as failed if we can't schedule it
failMsg := fmt.Sprintf("Failed to schedule restore: %v", err)
restore.FailMessage = &failMsg
restore.Status = restores_core.RestoreStatusFailed
if saveErr := s.restoreRepository.Save(&restore); saveErr != nil {
s.logger.Error("Failed to save restore after scheduling error", "error", saveErr)
}
}()
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
@@ -153,127 +194,9 @@ func (s *RestoreService) RestoreBackupWithAuth(
return nil
}
func (s *RestoreService) RestoreBackup(
backup *backups_core.Backup,
requestDTO RestoreBackupRequest,
) error {
if backup.Status != backups_core.BackupStatusCompleted {
return errors.New("backup is not completed")
}
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return err
}
switch database.Type {
case databases.DatabaseTypePostgres:
if requestDTO.PostgresqlDatabase == nil {
return errors.New("postgresql database is required")
}
case databases.DatabaseTypeMysql:
if requestDTO.MysqlDatabase == nil {
return errors.New("mysql database is required")
}
case databases.DatabaseTypeMariadb:
if requestDTO.MariadbDatabase == nil {
return errors.New("mariadb database is required")
}
case databases.DatabaseTypeMongodb:
if requestDTO.MongodbDatabase == nil {
return errors.New("mongodb database is required")
}
}
restore := models.Restore{
ID: uuid.New(),
Status: enums.RestoreStatusInProgress,
BackupID: backup.ID,
Backup: backup,
CreatedAt: time.Now().UTC(),
RestoreDurationMs: 0,
FailMessage: nil,
}
// Save the restore first
if err := s.restoreRepository.Save(&restore); err != nil {
return err
}
// Save the restore again to include the postgresql database
if err := s.restoreRepository.Save(&restore); err != nil {
return err
}
storage, err := s.storageService.GetStorageByID(backup.StorageID)
if err != nil {
return err
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(
database.ID,
)
if err != nil {
return err
}
start := time.Now().UTC()
restoringToDB := &databases.Database{
Type: database.Type,
Postgresql: requestDTO.PostgresqlDatabase,
Mysql: requestDTO.MysqlDatabase,
Mariadb: requestDTO.MariadbDatabase,
Mongodb: requestDTO.MongodbDatabase,
}
if err := restoringToDB.PopulateDbData(s.logger, s.fieldEncryptor); err != nil {
return fmt.Errorf("failed to auto-detect database data: %w", err)
}
isExcludeExtensions := false
if requestDTO.PostgresqlDatabase != nil {
isExcludeExtensions = requestDTO.PostgresqlDatabase.IsExcludeExtensions
}
err = s.restoreBackupUsecase.Execute(
backupConfig,
restore,
database,
restoringToDB,
backup,
storage,
isExcludeExtensions,
)
if err != nil {
errMsg := err.Error()
restore.FailMessage = &errMsg
restore.Status = enums.RestoreStatusFailed
restore.RestoreDurationMs = time.Since(start).Milliseconds()
if err := s.restoreRepository.Save(&restore); err != nil {
return err
}
return err
}
restore.Status = enums.RestoreStatusCompleted
restore.RestoreDurationMs = time.Since(start).Milliseconds()
if err := s.restoreRepository.Save(&restore); err != nil {
return err
}
return nil
}
func (s *RestoreService) validateVersionCompatibility(
backupDatabase *databases.Database,
requestDTO RestoreBackupRequest,
requestDTO restores_core.RestoreBackupRequest,
) error {
// populate version
if requestDTO.MariadbDatabase != nil {
@@ -372,7 +295,7 @@ func (s *RestoreService) validateVersionCompatibility(
func (s *RestoreService) validateDiskSpace(
backup *backups_core.Backup,
requestDTO RestoreBackupRequest,
requestDTO restores_core.RestoreBackupRequest,
) error {
// Only validate disk space for PostgreSQL when file-based restore is needed:
// - CPU > 1 (parallel jobs require file)
@@ -424,3 +347,71 @@ func (s *RestoreService) validateDiskSpace(
return nil
}
func (s *RestoreService) validateNoParallelRestores(databaseID uuid.UUID) error {
inProgressRestores, err := s.restoreRepository.FindInProgressRestoresByDatabaseID(databaseID)
if err != nil {
return fmt.Errorf("failed to check for in-progress restores: %w", err)
}
isInProgress := len(inProgressRestores) > 0
if isInProgress {
return errors.New(
"another restore is already in progress for this database. Please wait for it to complete or cancel it before starting a new restore",
)
}
return nil
}
func (s *RestoreService) CancelRestore(
user *users_models.User,
restoreID uuid.UUID,
) error {
restore, err := s.restoreRepository.FindByID(restoreID)
if err != nil {
return err
}
backup, err := s.backupService.GetBackup(restore.BackupID)
if err != nil {
return err
}
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return err
}
if database.WorkspaceID == nil {
return errors.New("cannot cancel restore for database without workspace")
}
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
if err != nil {
return err
}
if !canManage {
return errors.New("insufficient permissions to cancel restore for this database")
}
if restore.Status != restores_core.RestoreStatusInProgress {
return errors.New("restore is not in progress")
}
if err := s.taskCancelManager.CancelTask(restoreID); err != nil {
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Restore cancelled for database: %s (ID: %s)",
database.Name,
restoreID.String(),
),
&user.ID,
database.WorkspaceID,
)
return nil
}
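
The new validateNoParallelRestores guard is what backs the "do not allow 2 parallel restores for a single DB" behaviour. A rough test sketch of how it surfaces to callers follows; GetRestoreService() is assumed to follow the package's usual GetX() accessor convention, and all fixture helpers here are hypothetical rather than taken from this diff.

package restores

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func Test_RestoreBackupWithAuth_RejectsParallelRestore_Sketch(t *testing.T) {
	// Hypothetical fixtures: a completed backup, a user allowed to manage its
	// database, and a request with a valid restore target for that database type.
	backup := createCompletedTestBackup(t)
	user := createWorkspaceAdmin(t)
	request := buildValidRestoreRequest(t, backup)

	// Simulate a restore that is already running for the same database.
	createInProgressRestore(t, backup)

	// GetRestoreService() is assumed to exist per the package's GetX() accessor pattern.
	err := GetRestoreService().RestoreBackupWithAuth(user, backup.ID, request)
	assert.ErrorContains(t, err, "another restore is already in progress")
}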

View File

@@ -0,0 +1,51 @@
package restores
import (
"context"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"databasus-backend/internal/features/backups/backups"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/restores/restoring"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
)
func CreateTestRouter() *gin.Engine {
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
databases.GetDatabaseController(),
backups_config.GetBackupConfigController(),
backups.GetBackupController(),
GetRestoreController(),
)
v1 := router.Group("/api/v1")
backups.GetBackupController().RegisterPublicRoutes(v1)
return router
}
func SetupMockRestoreNode(t *testing.T) (uuid.UUID, context.CancelFunc) {
nodeID := uuid.New()
err := restoring.CreateMockNodeInRegistry(
nodeID,
100,
time.Now().UTC(),
)
if err != nil {
t.Fatalf("Failed to create mock node: %v", err)
}
cleanup := func() {
// Node will expire naturally from registry
}
return nodeID, cleanup
}
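
A typical caller registers the mock node before driving a restore through the API, so the scheduler has somewhere to assign the work. A minimal sketch in the same package:

func Test_Restore_WithMockNode_Sketch(t *testing.T) {
	// Register a fake restorer node, then drive the restore through the HTTP
	// router as the other tests in this package do.
	nodeID, cleanup := SetupMockRestoreNode(t)
	defer cleanup()
	t.Logf("mock restore node %s is available for scheduling", nodeID)
}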

View File

@@ -24,7 +24,7 @@ import (
"databasus-backend/internal/features/databases"
mariadbtypes "databasus-backend/internal/features/databases/databases/mariadb"
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
"databasus-backend/internal/features/restores/models"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
util_encryption "databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/tools"
@@ -36,10 +36,11 @@ type RestoreMariadbBackupUsecase struct {
}
func (uc *RestoreMariadbBackupUsecase) Execute(
parentCtx context.Context,
originalDB *databases.Database,
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
restore restores_core.Restore,
backup *backups_core.Backup,
storage *storages.Storage,
) error {
@@ -79,6 +80,7 @@ func (uc *RestoreMariadbBackupUsecase) Execute(
}
return uc.restoreFromStorage(
parentCtx,
originalDB,
tools.GetMariadbExecutable(
tools.MariadbExecutableMariadb,
@@ -95,6 +97,7 @@ func (uc *RestoreMariadbBackupUsecase) Execute(
}
func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
parentCtx context.Context,
database *databases.Database,
mariadbBin string,
args []string,
@@ -103,7 +106,7 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
storage *storages.Storage,
mdbConfig *mariadbtypes.MariadbDatabase,
) error {
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
defer cancel()
go func() {
@@ -113,6 +116,9 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
select {
case <-ctx.Done():
return
case <-parentCtx.Done():
cancel()
return
case <-ticker.C:
if config.IsShouldShutdown() {
cancel()
@@ -213,6 +219,15 @@ func (uc *RestoreMariadbBackupUsecase) executeMariadbRestore(
waitErr := cmd.Wait()
stderrOutput := <-stderrCh
// Check for cancellation
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
return fmt.Errorf("restore cancelled")
}
default:
}
if config.IsShouldShutdown() {
return fmt.Errorf("restore cancelled due to shutdown")
}
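
The same cancellation plumbing repeats in the MongoDB, MySQL and PostgreSQL usecases below: the restore command runs under a timeout context derived from the caller's context, a monitor goroutine cancels it on parent cancellation or shutdown, and after the command finishes the code distinguishes cancellation from a genuine failure. A reduced standalone sketch of that pattern (names here are illustrative, not the repository's own):

package main

import (
	"context"
	"errors"
	"fmt"
	"os/exec"
	"time"
)

// runRestoreCmd is an illustrative reduction of the pattern used by the restore
// usecases: derive a timeout context from the parent, watch for parent
// cancellation in a goroutine, and report cancellation distinctly after Wait().
func runRestoreCmd(parentCtx context.Context, name string, args ...string) error {
	ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
	defer cancel()

	go func() {
		ticker := time.NewTicker(1 * time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-parentCtx.Done():
				cancel() // propagate an external cancel (e.g. task cancellation) to the command
				return
			case <-ticker.C:
				// a real implementation would also check a shutdown flag here
			}
		}
	}()

	cmd := exec.CommandContext(ctx, name, args...)
	waitErr := cmd.Run()

	// Distinguish cancellation from a real command failure.
	if errors.Is(ctx.Err(), context.Canceled) {
		return fmt.Errorf("restore cancelled")
	}
	if waitErr != nil {
		return fmt.Errorf("restore command failed: %w", waitErr)
	}
	return nil
}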

View File

@@ -20,7 +20,7 @@ import (
"databasus-backend/internal/features/databases"
mongodbtypes "databasus-backend/internal/features/databases/databases/mongodb"
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
"databasus-backend/internal/features/restores/models"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
util_encryption "databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/tools"
@@ -36,10 +36,11 @@ type RestoreMongodbBackupUsecase struct {
}
func (uc *RestoreMongodbBackupUsecase) Execute(
parentCtx context.Context,
originalDB *databases.Database,
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
restore restores_core.Restore,
backup *backups_core.Backup,
storage *storages.Storage,
) error {
@@ -76,6 +77,7 @@ func (uc *RestoreMongodbBackupUsecase) Execute(
args := uc.buildMongorestoreArgs(mdb, decryptedPassword, sourceDatabase)
return uc.restoreFromStorage(
parentCtx,
tools.GetMongodbExecutable(
tools.MongodbExecutableMongorestore,
config.GetEnv().EnvMode,
@@ -122,12 +124,13 @@ func (uc *RestoreMongodbBackupUsecase) buildMongorestoreArgs(
}
func (uc *RestoreMongodbBackupUsecase) restoreFromStorage(
parentCtx context.Context,
mongorestoreBin string,
args []string,
backup *backups_core.Backup,
storage *storages.Storage,
) error {
ctx, cancel := context.WithTimeout(context.Background(), restoreTimeout)
ctx, cancel := context.WithTimeout(parentCtx, restoreTimeout)
defer cancel()
go func() {
@@ -137,6 +140,9 @@ func (uc *RestoreMongodbBackupUsecase) restoreFromStorage(
select {
case <-ctx.Done():
return
case <-parentCtx.Done():
cancel()
return
case <-ticker.C:
if config.IsShouldShutdown() {
cancel()
@@ -218,6 +224,15 @@ func (uc *RestoreMongodbBackupUsecase) executeMongoRestore(
waitErr := cmd.Wait()
stderrOutput := <-stderrCh
// Check for cancellation
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
return fmt.Errorf("restore cancelled")
}
default:
}
if config.IsShouldShutdown() {
return fmt.Errorf("restore cancelled due to shutdown")
}

View File

@@ -24,7 +24,7 @@ import (
"databasus-backend/internal/features/databases"
mysqltypes "databasus-backend/internal/features/databases/databases/mysql"
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
"databasus-backend/internal/features/restores/models"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
util_encryption "databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/tools"
@@ -36,10 +36,11 @@ type RestoreMysqlBackupUsecase struct {
}
func (uc *RestoreMysqlBackupUsecase) Execute(
parentCtx context.Context,
originalDB *databases.Database,
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
restore restores_core.Restore,
backup *backups_core.Backup,
storage *storages.Storage,
) error {
@@ -78,6 +79,7 @@ func (uc *RestoreMysqlBackupUsecase) Execute(
}
return uc.restoreFromStorage(
parentCtx,
originalDB,
tools.GetMysqlExecutable(
my.Version,
@@ -94,6 +96,7 @@ func (uc *RestoreMysqlBackupUsecase) Execute(
}
func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
parentCtx context.Context,
database *databases.Database,
mysqlBin string,
args []string,
@@ -102,7 +105,7 @@ func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
storage *storages.Storage,
myConfig *mysqltypes.MysqlDatabase,
) error {
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
defer cancel()
go func() {
@@ -112,6 +115,9 @@ func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
select {
case <-ctx.Done():
return
case <-parentCtx.Done():
cancel()
return
case <-ticker.C:
if config.IsShouldShutdown() {
cancel()
@@ -204,6 +210,15 @@ func (uc *RestoreMysqlBackupUsecase) executeMysqlRestore(
waitErr := cmd.Wait()
stderrOutput := <-stderrCh
// Check for cancellation
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
return fmt.Errorf("restore cancelled")
}
default:
}
if config.IsShouldShutdown() {
return fmt.Errorf("restore cancelled due to shutdown")
}

View File

@@ -21,7 +21,7 @@ import (
"databasus-backend/internal/features/databases"
pgtypes "databasus-backend/internal/features/databases/databases/postgresql"
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
"databasus-backend/internal/features/restores/models"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
util_encryption "databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/tools"
@@ -35,10 +35,11 @@ type RestorePostgresqlBackupUsecase struct {
}
func (uc *RestorePostgresqlBackupUsecase) Execute(
parentCtx context.Context,
originalDB *databases.Database,
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
restore restores_core.Restore,
backup *backups_core.Backup,
storage *storages.Storage,
isExcludeExtensions bool,
@@ -73,6 +74,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
// All PostgreSQL backups are now custom format (-Fc)
return uc.restoreCustomType(
parentCtx,
originalDB,
pgBin,
backup,
@@ -84,6 +86,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
// restoreCustomType restores a backup in custom type (-Fc)
func (uc *RestorePostgresqlBackupUsecase) restoreCustomType(
parentCtx context.Context,
originalDB *databases.Database,
pgBin string,
backup *backups_core.Backup,
@@ -102,15 +105,24 @@ func (uc *RestorePostgresqlBackupUsecase) restoreCustomType(
// If excluding extensions, we must use file-based restore (requires TOC file generation)
// Also use file-based restore for parallel jobs (multiple CPUs)
if isExcludeExtensions || pg.CpuCount > 1 {
return uc.restoreViaFile(originalDB, pgBin, backup, storage, pg, isExcludeExtensions)
return uc.restoreViaFile(
parentCtx,
originalDB,
pgBin,
backup,
storage,
pg,
isExcludeExtensions,
)
}
// Single CPU without extension exclusion: stream directly via stdin
return uc.restoreViaStdin(originalDB, pgBin, backup, storage, pg)
return uc.restoreViaStdin(parentCtx, originalDB, pgBin, backup, storage, pg)
}
// restoreViaStdin streams backup via stdin for single CPU restore
func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
parentCtx context.Context,
originalDB *databases.Database,
pgBin string,
backup *backups_core.Backup,
@@ -133,10 +145,10 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
"--no-acl",
}
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
defer cancel()
// Monitor for shutdown and cancel context if needed
// Monitor for shutdown and parent cancellation
go func() {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
@@ -145,6 +157,9 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
select {
case <-ctx.Done():
return
case <-parentCtx.Done():
cancel()
return
case <-ticker.C:
if config.IsShouldShutdown() {
cancel()
@@ -296,6 +311,15 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
stderrOutput := <-stderrCh
copyErr := <-copyErrCh
// Check for cancellation
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
return fmt.Errorf("restore cancelled")
}
default:
}
// Check for shutdown before finalizing
if config.IsShouldShutdown() {
return fmt.Errorf("restore cancelled due to shutdown")
@@ -307,6 +331,15 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
}
if waitErr != nil {
// Check for cancellation again
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
return fmt.Errorf("restore cancelled")
}
default:
}
if config.IsShouldShutdown() {
return fmt.Errorf("restore cancelled due to shutdown")
}
@@ -319,6 +352,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
// restoreViaFile downloads backup and uses parallel jobs for multi-CPU restore
func (uc *RestorePostgresqlBackupUsecase) restoreViaFile(
parentCtx context.Context,
originalDB *databases.Database,
pgBin string,
backup *backups_core.Backup,
@@ -354,6 +388,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaFile(
}
return uc.restoreFromStorage(
parentCtx,
originalDB,
pgBin,
args,
@@ -367,6 +402,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaFile(
// restoreFromStorage restores backup data from storage using pg_restore
func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
parentCtx context.Context,
database *databases.Database,
pgBin string,
args []string,
@@ -386,10 +422,10 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
isExcludeExtensions,
)
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
defer cancel()
// Monitor for shutdown and cancel context if needed
// Monitor for shutdown and parent cancellation
go func() {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
@@ -398,6 +434,9 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
select {
case <-ctx.Done():
return
case <-parentCtx.Done():
cancel()
return
case <-ticker.C:
if config.IsShouldShutdown() {
cancel()
@@ -624,12 +663,30 @@ func (uc *RestorePostgresqlBackupUsecase) executePgRestore(
waitErr := cmd.Wait()
stderrOutput := <-stderrCh
// Check for cancellation
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
return fmt.Errorf("restore cancelled")
}
default:
}
// Check for shutdown before finalizing
if config.IsShouldShutdown() {
return fmt.Errorf("restore cancelled due to shutdown")
}
if waitErr != nil {
// Check for cancellation again
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
return fmt.Errorf("restore cancelled")
}
default:
}
if config.IsShouldShutdown() {
return fmt.Errorf("restore cancelled due to shutdown")
}

View File

@@ -1,12 +1,13 @@
package usecases
import (
"context"
"errors"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/restores/models"
restores_core "databasus-backend/internal/features/restores/core"
usecases_mariadb "databasus-backend/internal/features/restores/usecases/mariadb"
usecases_mongodb "databasus-backend/internal/features/restores/usecases/mongodb"
usecases_mysql "databasus-backend/internal/features/restores/usecases/mysql"
@@ -22,8 +23,9 @@ type RestoreBackupUsecase struct {
}
func (uc *RestoreBackupUsecase) Execute(
ctx context.Context,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
restore restores_core.Restore,
originalDB *databases.Database,
restoringToDB *databases.Database,
backup *backups_core.Backup,
@@ -33,6 +35,7 @@ func (uc *RestoreBackupUsecase) Execute(
switch originalDB.Type {
case databases.DatabaseTypePostgres:
return uc.restorePostgresqlBackupUsecase.Execute(
ctx,
originalDB,
restoringToDB,
backupConfig,
@@ -43,6 +46,7 @@ func (uc *RestoreBackupUsecase) Execute(
)
case databases.DatabaseTypeMysql:
return uc.restoreMysqlBackupUsecase.Execute(
ctx,
originalDB,
restoringToDB,
backupConfig,
@@ -52,6 +56,7 @@ func (uc *RestoreBackupUsecase) Execute(
)
case databases.DatabaseTypeMariadb:
return uc.restoreMariadbBackupUsecase.Execute(
ctx,
originalDB,
restoringToDB,
backupConfig,
@@ -61,6 +66,7 @@ func (uc *RestoreBackupUsecase) Execute(
)
case databases.DatabaseTypeMongodb:
return uc.restoreMongodbBackupUsecase.Execute(
ctx,
originalDB,
restoringToDB,
backupConfig,

View File

@@ -6,6 +6,7 @@ import (
"strings"
"testing"
"databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
azure_blob_storage "databasus-backend/internal/features/storages/models/azure_blob"
ftp_storage "databasus-backend/internal/features/storages/models/ftp"
@@ -902,6 +903,12 @@ func Test_StorageSensitiveDataLifecycle_AllTypes(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Skip Google Drive tests if external resources tests are disabled
if tc.storageType == StorageTypeGoogleDrive &&
config.GetEnv().IsSkipExternalResourcesTests {
t.Skip("Skipping Google Drive storage test: IS_SKIP_EXTERNAL_RESOURCES_TESTS=true")
}
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)

View File

@@ -1,9 +1,13 @@
package storages
import (
"sync"
"sync/atomic"
audit_logs "databasus-backend/internal/features/audit_logs"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
)
var storageRepository = &StorageRepository{}
@@ -27,6 +31,21 @@ func GetStorageController() *StorageController {
return storageController
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(storageService)
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(storageService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}

View File

@@ -113,7 +113,7 @@ func Test_Storage_BasicOperations(t *testing.T) {
name: "NASStorage",
storage: &nas_storage.NASStorage{
StorageID: uuid.New(),
Host: "localhost",
Host: config.GetEnv().TestLocalhost,
Port: nasPort,
Share: "backups",
Username: "testuser",
@@ -147,7 +147,7 @@ func Test_Storage_BasicOperations(t *testing.T) {
name: "FTPStorage",
storage: &ftp_storage.FTPStorage{
StorageID: uuid.New(),
Host: "localhost",
Host: config.GetEnv().TestLocalhost,
Port: ftpPort,
Username: "testuser",
Password: "testpassword",
@@ -159,7 +159,7 @@ func Test_Storage_BasicOperations(t *testing.T) {
name: "SFTPStorage",
storage: &sftp_storage.SFTPStorage{
StorageID: uuid.New(),
Host: "localhost",
Host: config.GetEnv().TestLocalhost,
Port: sftpPort,
Username: "testuser",
Password: "testpassword",
@@ -185,7 +185,9 @@ acl = private`, s3Container.accessKey, s3Container.secretKey, s3Container.endpoi
// Add Google Drive storage test only if environment variables are available
env := config.GetEnv()
if env.TestGoogleDriveClientID != "" && env.TestGoogleDriveClientSecret != "" &&
if env.IsSkipExternalResourcesTests {
t.Log("Skipping Google Drive storage test: IS_SKIP_EXTERNAL_RESOURCES_TESTS=true")
} else if env.TestGoogleDriveClientID != "" && env.TestGoogleDriveClientSecret != "" &&
env.TestGoogleDriveTokenJSON != "" {
testCases = append(testCases, struct {
name string
@@ -297,7 +299,7 @@ func setupS3Container(ctx context.Context) (*S3Container, error) {
secretKey := "testpassword"
bucketName := "test-bucket"
region := "us-east-1"
endpoint := fmt.Sprintf("127.0.0.1:%s", env.TestMinioPort)
endpoint := fmt.Sprintf("%s:%s", env.TestLocalhost, env.TestMinioPort)
// Create MinIO client and ensure bucket exists
minioClient, err := minio.New(endpoint, &minio.Options{
@@ -343,15 +345,21 @@ func setupAzuriteContainer(ctx context.Context) (*AzuriteContainer, error) {
accountName := "devstoreaccount1"
// this is the well-known Azurite test account key, not a real secret
accountKey := "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="
serviceURL := fmt.Sprintf("http://127.0.0.1:%s/%s", env.TestAzuriteBlobPort, accountName)
serviceURL := fmt.Sprintf(
"http://%s:%s/%s",
env.TestLocalhost,
env.TestAzuriteBlobPort,
accountName,
)
containerNameKey := "test-container-key"
containerNameStr := "test-container-connstr"
// Build explicit connection string for Azurite
connectionString := fmt.Sprintf(
"DefaultEndpointsProtocol=http;AccountName=%s;AccountKey=%s;BlobEndpoint=http://127.0.0.1:%s/%s",
"DefaultEndpointsProtocol=http;AccountName=%s;AccountKey=%s;BlobEndpoint=http://%s:%s/%s",
accountName,
accountKey,
env.TestLocalhost,
env.TestAzuriteBlobPort,
accountName,
)

View File

@@ -285,30 +285,30 @@ func (f *FTPStorage) ensureDirectory(conn *ftp.ServerConn, path string) error {
}
parts := strings.Split(path, "/")
currentPath := ""
currentDir, err := conn.CurrentDir()
if err != nil {
return fmt.Errorf("failed to get current directory: %w", err)
}
defer func() {
_ = conn.ChangeDir(currentDir)
}()
for _, part := range parts {
if part == "" || part == "." {
continue
}
if currentPath == "" {
currentPath = part
} else {
currentPath = currentPath + "/" + part
}
err := conn.ChangeDir(currentPath)
err := conn.ChangeDir(part)
if err != nil {
err = conn.MakeDir(currentPath)
err = conn.MakeDir(part)
if err != nil {
return fmt.Errorf("failed to create directory '%s': %w", currentPath, err)
return fmt.Errorf("failed to create directory '%s': %w", part, err)
}
err = conn.ChangeDir(part)
if err != nil {
return fmt.Errorf("failed to change into directory '%s': %w", part, err)
}
}
err = conn.ChangeDirToParent()
if err != nil {
return fmt.Errorf("failed to change to parent directory: %w", err)
}
}

View File

@@ -37,7 +37,7 @@ func (s *HealthcheckService) IsHealthy() error {
}
}
if config.GetEnv().IsBackupNode {
if config.GetEnv().IsProcessingNode {
if !s.backuperNode.IsBackuperRunning() {
return errors.New("backuper node is not running for more than 5 minutes")
}

View File

@@ -0,0 +1,75 @@
package task_cancellation
import (
"context"
cache_utils "databasus-backend/internal/util/cache"
"log/slog"
"sync"
"github.com/google/uuid"
)
const taskCancelChannel = "task:cancel"
type TaskCancelManager struct {
mu sync.RWMutex
cancelFuncs map[uuid.UUID]context.CancelFunc
pubsub *cache_utils.PubSubManager
logger *slog.Logger
}
func (m *TaskCancelManager) StartSubscription() {
ctx := context.Background()
handler := func(message string) {
taskID, err := uuid.Parse(message)
if err != nil {
m.logger.Error("Invalid task ID in cancel message", "message", message, "error", err)
return
}
m.mu.Lock()
defer m.mu.Unlock()
cancelFunc, exists := m.cancelFuncs[taskID]
if exists {
cancelFunc()
delete(m.cancelFuncs, taskID)
m.logger.Info("Cancelled task via Pub/Sub", "taskID", taskID)
}
}
err := m.pubsub.Subscribe(ctx, taskCancelChannel, handler)
if err != nil {
m.logger.Error("Failed to subscribe to task cancel channel", "error", err)
} else {
m.logger.Info("Successfully subscribed to task cancel channel")
}
}
func (m *TaskCancelManager) RegisterTask(task uuid.UUID, cancelFunc context.CancelFunc) {
m.mu.Lock()
defer m.mu.Unlock()
m.cancelFuncs[task] = cancelFunc
m.logger.Debug("Registered task", "taskID", task)
}
func (m *TaskCancelManager) CancelTask(taskID uuid.UUID) error {
ctx := context.Background()
err := m.pubsub.Publish(ctx, taskCancelChannel, taskID.String())
if err != nil {
m.logger.Error("Failed to publish cancel message", "taskID", taskID, "error", err)
return err
}
m.logger.Info("Published task cancel message", "taskID", taskID)
return nil
}
func (m *TaskCancelManager) UnregisterTask(taskID uuid.UUID) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.cancelFuncs, taskID)
m.logger.Debug("Unregistered task", "taskID", taskID)
}
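
This manager is the bridge between the API side, which only publishes a task ID (CancelRestore above calls CancelTask), and whichever instance is actually executing the work. The executing side is expected to register the task's cancel function and unregister it when done; a sketch of that lifecycle, written as external usage and not taken from the repository's node code:

package workers // illustrative package; not part of this diff

import (
	"context"

	"github.com/google/uuid"

	tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
)

// runCancellableRestore sketches the expected executor-side lifecycle: register
// the task's cancel func before the long-running work, always unregister after.
func runCancellableRestore(parentCtx context.Context, restoreID uuid.UUID, do func(context.Context) error) error {
	ctx, cancel := context.WithCancel(parentCtx)
	defer cancel()

	manager := tasks_cancellation.GetTaskCancelManager()
	manager.RegisterTask(restoreID, cancel)
	defer manager.UnregisterTask(restoreID)

	// If any instance publishes CancelTask(restoreID), the subscription handler
	// invokes cancel() and ctx.Done() fires inside do().
	return do(ctx)
}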

View File

@@ -1,4 +1,4 @@
package backups_cancellation
package task_cancellation
import (
"context"
@@ -10,41 +10,41 @@ import (
"github.com/stretchr/testify/assert"
)
func Test_RegisterBackup_BackupRegisteredSuccessfully(t *testing.T) {
manager := backupCancelManager
func Test_RegisterTask_TaskRegisteredSuccessfully(t *testing.T) {
manager := taskCancelManager
backupID := uuid.New()
taskID := uuid.New()
_, cancel := context.WithCancel(context.Background())
defer cancel()
manager.RegisterBackup(backupID, cancel)
manager.RegisterTask(taskID, cancel)
manager.mu.RLock()
_, exists := manager.cancelFuncs[backupID]
_, exists := manager.cancelFuncs[taskID]
manager.mu.RUnlock()
assert.True(t, exists, "Backup should be registered")
assert.True(t, exists, "Task should be registered")
}
func Test_UnregisterBackup_BackupUnregisteredSuccessfully(t *testing.T) {
manager := backupCancelManager
func Test_UnregisterTask_TaskUnregisteredSuccessfully(t *testing.T) {
manager := taskCancelManager
backupID := uuid.New()
taskID := uuid.New()
_, cancel := context.WithCancel(context.Background())
defer cancel()
manager.RegisterBackup(backupID, cancel)
manager.UnregisterBackup(backupID)
manager.RegisterTask(taskID, cancel)
manager.UnregisterTask(taskID)
manager.mu.RLock()
_, exists := manager.cancelFuncs[backupID]
_, exists := manager.cancelFuncs[taskID]
manager.mu.RUnlock()
assert.False(t, exists, "Backup should be unregistered")
assert.False(t, exists, "Task should be unregistered")
}
func Test_CancelBackup_OnSameInstance_BackupCancelledViaPubSub(t *testing.T) {
manager := backupCancelManager
func Test_CancelTask_OnSameInstance_TaskCancelledViaPubSub(t *testing.T) {
manager := taskCancelManager
backupID := uuid.New()
taskID := uuid.New()
ctx, cancel := context.WithCancel(context.Background())
cancelled := false
@@ -57,11 +57,11 @@ func Test_CancelBackup_OnSameInstance_BackupCancelledViaPubSub(t *testing.T) {
cancel()
}
manager.RegisterBackup(backupID, wrappedCancel)
manager.RegisterTask(taskID, wrappedCancel)
manager.StartSubscription()
time.Sleep(100 * time.Millisecond)
err := manager.CancelBackup(backupID)
err := manager.CancelTask(taskID)
assert.NoError(t, err, "Cancel should not return error")
time.Sleep(500 * time.Millisecond)
@@ -74,11 +74,11 @@ func Test_CancelBackup_OnSameInstance_BackupCancelledViaPubSub(t *testing.T) {
assert.Error(t, ctx.Err(), "Context should be cancelled")
}
func Test_CancelBackup_FromDifferentInstance_BackupCancelledOnRunningInstance(t *testing.T) {
manager1 := backupCancelManager
manager2 := backupCancelManager
func Test_CancelTask_FromDifferentInstance_TaskCancelledOnRunningInstance(t *testing.T) {
manager1 := taskCancelManager
manager2 := taskCancelManager
backupID := uuid.New()
taskID := uuid.New()
ctx, cancel := context.WithCancel(context.Background())
cancelled := false
@@ -91,13 +91,13 @@ func Test_CancelBackup_FromDifferentInstance_BackupCancelledOnRunningInstance(t
cancel()
}
manager1.RegisterBackup(backupID, wrappedCancel)
manager1.RegisterTask(taskID, wrappedCancel)
manager1.StartSubscription()
manager2.StartSubscription()
time.Sleep(100 * time.Millisecond)
err := manager2.CancelBackup(backupID)
err := manager2.CancelTask(taskID)
assert.NoError(t, err, "Cancel should not return error")
time.Sleep(500 * time.Millisecond)
@@ -110,29 +110,29 @@ func Test_CancelBackup_FromDifferentInstance_BackupCancelledOnRunningInstance(t
assert.Error(t, ctx.Err(), "Context should be cancelled")
}
func Test_CancelBackup_WhenBackupDoesNotExist_NoErrorReturned(t *testing.T) {
manager := backupCancelManager
func Test_CancelTask_WhenTaskDoesNotExist_NoErrorReturned(t *testing.T) {
manager := taskCancelManager
manager.StartSubscription()
time.Sleep(100 * time.Millisecond)
nonExistentID := uuid.New()
err := manager.CancelBackup(nonExistentID)
assert.NoError(t, err, "Cancelling non-existent backup should not error")
err := manager.CancelTask(nonExistentID)
assert.NoError(t, err, "Cancelling non-existent task should not error")
}
func Test_CancelBackup_WithMultipleBackups_AllBackupsCancelled(t *testing.T) {
manager := backupCancelManager
func Test_CancelTask_WithMultipleTasks_AllTasksCancelled(t *testing.T) {
manager := taskCancelManager
numBackups := 5
backupIDs := make([]uuid.UUID, numBackups)
contexts := make([]context.Context, numBackups)
cancels := make([]context.CancelFunc, numBackups)
cancelledFlags := make([]bool, numBackups)
numTasks := 5
taskIDs := make([]uuid.UUID, numTasks)
contexts := make([]context.Context, numTasks)
cancels := make([]context.CancelFunc, numTasks)
cancelledFlags := make([]bool, numTasks)
var mu sync.Mutex
for i := 0; i < numBackups; i++ {
backupIDs[i] = uuid.New()
for i := 0; i < numTasks; i++ {
taskIDs[i] = uuid.New()
contexts[i], cancels[i] = context.WithCancel(context.Background())
idx := i
@@ -143,31 +143,31 @@ func Test_CancelBackup_WithMultipleBackups_AllBackupsCancelled(t *testing.T) {
cancels[idx]()
}
manager.RegisterBackup(backupIDs[i], wrappedCancel)
manager.RegisterTask(taskIDs[i], wrappedCancel)
}
manager.StartSubscription()
time.Sleep(100 * time.Millisecond)
for i := 0; i < numBackups; i++ {
err := manager.CancelBackup(backupIDs[i])
for i := 0; i < numTasks; i++ {
err := manager.CancelTask(taskIDs[i])
assert.NoError(t, err, "Cancel should not return error")
}
time.Sleep(1 * time.Second)
mu.Lock()
for i := 0; i < numBackups; i++ {
assert.True(t, cancelledFlags[i], "Backup %d should be cancelled", i)
for i := 0; i < numTasks; i++ {
assert.True(t, cancelledFlags[i], "Task %d should be cancelled", i)
assert.Error(t, contexts[i].Err(), "Context %d should be cancelled", i)
}
mu.Unlock()
}
func Test_CancelBackup_AfterUnregister_BackupNotCancelled(t *testing.T) {
manager := backupCancelManager
func Test_CancelTask_AfterUnregister_TaskNotCancelled(t *testing.T) {
manager := taskCancelManager
backupID := uuid.New()
taskID := uuid.New()
_, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -181,13 +181,13 @@ func Test_CancelBackup_AfterUnregister_BackupNotCancelled(t *testing.T) {
cancel()
}
manager.RegisterBackup(backupID, wrappedCancel)
manager.RegisterTask(taskID, wrappedCancel)
manager.StartSubscription()
time.Sleep(100 * time.Millisecond)
manager.UnregisterBackup(backupID)
manager.UnregisterTask(taskID)
err := manager.CancelBackup(backupID)
err := manager.CancelTask(taskID)
assert.NoError(t, err, "Cancel should not return error")
time.Sleep(500 * time.Millisecond)

View File

@@ -0,0 +1,42 @@
package task_cancellation
import (
"context"
"sync"
"sync/atomic"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/logger"
"github.com/google/uuid"
)
var taskCancelManager = &TaskCancelManager{
sync.RWMutex{},
make(map[uuid.UUID]context.CancelFunc),
cache_utils.NewPubSubManager(),
logger.GetLogger(),
}
func GetTaskCancelManager() *TaskCancelManager {
return taskCancelManager
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
taskCancelManager.StartSubscription()
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}
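
The subscription has to exist before any CancelTask publish can take effect on a given instance, so SetupDependencies is expected to run once during startup, ahead of the scheduler and worker nodes. An assumed wiring sketch (the function name and exact startup sequence are not part of this diff):

import (
	"context"

	"databasus-backend/internal/features/restores/restoring"
	task_cancellation "databasus-backend/internal/features/tasks/cancellation"
)

// Assumed startup wiring (sketch): subscribe first, then start the schedulers
// and nodes that may later need to be cancelled.
func initBackgroundServices(ctx context.Context) {
	task_cancellation.SetupDependencies() // idempotent; starts the pub/sub subscription once
	go restoring.GetRestoresScheduler().Run(ctx)
	// backuper / restorer nodes would be started here as well
}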

View File

@@ -0,0 +1,159 @@
package features
import (
"context"
"fmt"
"sync"
"testing"
"time"
"databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/features/backups/backups/backuping"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/restores"
"databasus-backend/internal/features/restores/restoring"
"databasus-backend/internal/features/storages"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
)
// Test_SetupDependencies_CalledTwice_LogsWarning verifies SetupDependencies is idempotent
func Test_SetupDependencies_CalledTwice_LogsWarning(t *testing.T) {
// Call each SetupDependencies twice - should not panic, only log warnings
audit_logs.SetupDependencies()
audit_logs.SetupDependencies()
backups.SetupDependencies()
backups.SetupDependencies()
backups_config.SetupDependencies()
backups_config.SetupDependencies()
databases.SetupDependencies()
databases.SetupDependencies()
healthcheck_config.SetupDependencies()
healthcheck_config.SetupDependencies()
notifiers.SetupDependencies()
notifiers.SetupDependencies()
restores.SetupDependencies()
restores.SetupDependencies()
storages.SetupDependencies()
storages.SetupDependencies()
task_cancellation.SetupDependencies()
task_cancellation.SetupDependencies()
// If we reach here without panic, test passes
t.Log("All SetupDependencies calls completed successfully (idempotent)")
}
// Test_SetupDependencies_ConcurrentCalls_Safe verifies thread safety
func Test_SetupDependencies_ConcurrentCalls_Safe(t *testing.T) {
var wg sync.WaitGroup
// Call SetupDependencies concurrently from 10 goroutines
for range 10 {
wg.Add(1)
go func() {
defer wg.Done()
audit_logs.SetupDependencies()
}()
}
wg.Wait()
t.Log("Concurrent SetupDependencies calls completed successfully")
}
// Test_BackgroundService_Run_CalledTwice_Panics verifies Run() panics on duplicate calls
func Test_BackgroundService_Run_CalledTwice_Panics(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Create a test background service
backgroundService := audit_logs.GetAuditLogBackgroundService()
// Start first Run() in goroutine
go func() {
backgroundService.Run(ctx)
}()
// Give first call time to initialize
time.Sleep(100 * time.Millisecond)
// Second call should panic
defer func() {
if r := recover(); r != nil {
expectedMsg := "*audit_logs.AuditLogBackgroundService.Run() called multiple times"
panicMsg := fmt.Sprintf("%v", r)
if panicMsg == expectedMsg {
t.Logf("Successfully caught panic: %v", r)
} else {
t.Errorf("Expected panic message '%s', got '%s'", expectedMsg, panicMsg)
}
} else {
t.Error("Expected panic on second Run() call, but did not panic")
}
}()
backgroundService.Run(ctx)
}
// Test_BackupsScheduler_Run_CalledTwice_Panics verifies scheduler panics on duplicate calls
func Test_BackupsScheduler_Run_CalledTwice_Panics(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
scheduler := backuping.GetBackupsScheduler()
// Start first Run() in goroutine
go func() {
scheduler.Run(ctx)
}()
// Give first call time to initialize
time.Sleep(100 * time.Millisecond)
// Second call should panic
defer func() {
if r := recover(); r != nil {
t.Logf("Successfully caught panic: %v", r)
} else {
t.Error("Expected panic on second Run() call, but did not panic")
}
}()
scheduler.Run(ctx)
}
// Test_RestoresScheduler_Run_CalledTwice_Panics verifies restore scheduler panics on duplicate calls
func Test_RestoresScheduler_Run_CalledTwice_Panics(t *testing.T) {
ctx := t.Context()
scheduler := restoring.GetRestoresScheduler()
// Start first Run() in goroutine
go func() {
scheduler.Run(ctx)
}()
// Give first call time to initialize
time.Sleep(100 * time.Millisecond)
// Second call should panic
defer func() {
if r := recover(); r != nil {
t.Logf("Successfully caught panic: %v", r)
} else {
t.Error("Expected panic on second Run() call, but did not panic")
}
}()
scheduler.Run(ctx)
}
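
These tests only assert the panic; the Run() guards themselves are not shown in this diff. Given the runOnce sync.Once and hasRun atomic.Bool fields visible in the restorer test helpers earlier, the guard presumably looks something like the following sketch (an assumption, not the repository's code):

package sketch

import (
	"context"
	"sync"
	"sync/atomic"
)

// exampleService illustrates one guard shape consistent with the runOnce /
// hasRun fields seen in the test helpers; the real Run() bodies are not
// included in this diff, so this is an assumption.
type exampleService struct {
	runOnce sync.Once
	hasRun  atomic.Bool
}

func (s *exampleService) Run(ctx context.Context) {
	firstCall := false
	s.runOnce.Do(func() {
		firstCall = true
		s.hasRun.Store(true)
	})
	if !firstCall {
		panic("*sketch.exampleService.Run() called multiple times")
	}
	<-ctx.Done() // stand-in for the service's main loop
}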

View File

@@ -21,9 +21,7 @@ import (
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
mariadbtypes "databasus-backend/internal/features/databases/databases/mariadb"
"databasus-backend/internal/features/restores"
restores_enums "databasus-backend/internal/features/restores/enums"
restores_models "databasus-backend/internal/features/restores/models"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
@@ -213,7 +211,7 @@ func testMariadbBackupRestoreForVersion(
)
restore := waitForMariadbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
var tableExists int
err = newDB.Get(
@@ -311,7 +309,7 @@ func testMariadbBackupRestoreWithEncryptionForVersion(
)
restore := waitForMariadbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
var tableExists int
err = newDB.Get(
@@ -418,7 +416,7 @@ func testMariadbBackupRestoreWithReadOnlyUserForVersion(
)
restore := waitForMariadbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
var tableExists int
err = newDB.Get(
@@ -506,7 +504,7 @@ func createMariadbRestoreViaAPI(
version tools.MariadbVersion,
token string,
) {
request := restores.RestoreBackupRequest{
request := restores_core.RestoreBackupRequest{
MariadbDatabase: &mariadbtypes.MariadbDatabase{
Host: host,
Port: port,
@@ -533,7 +531,7 @@ func waitForMariadbRestoreCompletion(
backupID uuid.UUID,
token string,
timeout time.Duration,
) *restores_models.Restore {
) *restores_core.Restore {
startTime := time.Now()
pollInterval := 500 * time.Millisecond
@@ -542,7 +540,7 @@ func waitForMariadbRestoreCompletion(
t.Fatalf("Timeout waiting for MariaDB restore completion after %v", timeout)
}
var restoresList []*restores_models.Restore
var restoresList []*restores_core.Restore
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
@@ -553,10 +551,10 @@ func waitForMariadbRestoreCompletion(
)
for _, restore := range restoresList {
if restore.Status == restores_enums.RestoreStatusCompleted {
if restore.Status == restores_core.RestoreStatusCompleted {
return restore
}
if restore.Status == restores_enums.RestoreStatusFailed {
if restore.Status == restores_core.RestoreStatusFailed {
failMsg := "unknown error"
if restore.FailMessage != nil {
failMsg = *restore.FailMessage
@@ -607,7 +605,7 @@ func connectToMariadbContainer(
dbName := "testdb"
password := "rootpassword"
username := "root"
host := "127.0.0.1"
host := config.GetEnv().TestLocalhost
portInt, err := strconv.Atoi(port)
if err != nil {

View File

@@ -23,9 +23,7 @@ import (
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
mongodbtypes "databasus-backend/internal/features/databases/databases/mongodb"
"databasus-backend/internal/features/restores"
restores_enums "databasus-backend/internal/features/restores/enums"
restores_models "databasus-backend/internal/features/restores/models"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
@@ -175,7 +173,7 @@ func testMongodbBackupRestoreForVersion(
)
restore := waitForMongodbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
verifyMongodbDataIntegrity(t, container, newDBName)
@@ -254,7 +252,7 @@ func testMongodbBackupRestoreWithEncryptionForVersion(
)
restore := waitForMongodbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
verifyMongodbDataIntegrity(t, container, newDBName)
@@ -342,7 +340,7 @@ func testMongodbBackupRestoreWithReadOnlyUserForVersion(
)
restore := waitForMongodbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
verifyMongodbDataIntegrity(t, container, newDBName)
@@ -431,7 +429,7 @@ func createMongodbRestoreViaAPI(
version tools.MongodbVersion,
token string,
) {
request := restores.RestoreBackupRequest{
request := restores_core.RestoreBackupRequest{
MongodbDatabase: &mongodbtypes.MongodbDatabase{
Host: host,
Port: port,
@@ -461,7 +459,7 @@ func waitForMongodbRestoreCompletion(
backupID uuid.UUID,
token string,
timeout time.Duration,
) *restores_models.Restore {
) *restores_core.Restore {
startTime := time.Now()
pollInterval := 500 * time.Millisecond
@@ -470,7 +468,7 @@ func waitForMongodbRestoreCompletion(
t.Fatalf("Timeout waiting for MongoDB restore completion after %v", timeout)
}
var restoresList []*restores_models.Restore
var restoresList []*restores_core.Restore
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
@@ -481,10 +479,10 @@ func waitForMongodbRestoreCompletion(
)
for _, restore := range restoresList {
if restore.Status == restores_enums.RestoreStatusCompleted {
if restore.Status == restores_core.RestoreStatusCompleted {
return restore
}
if restore.Status == restores_enums.RestoreStatusFailed {
if restore.Status == restores_core.RestoreStatusFailed {
failMsg := "unknown error"
if restore.FailMessage != nil {
failMsg = *restore.FailMessage
@@ -552,7 +550,7 @@ func connectToMongodbContainer(
password := "rootpassword"
username := "root"
authDatabase := "admin"
host := "127.0.0.1"
host := config.GetEnv().TestLocalhost
portInt, err := strconv.Atoi(port)
if err != nil {
@@ -560,8 +558,13 @@ func connectToMongodbContainer(
}
uri := fmt.Sprintf(
"mongodb://%s:%s@%s:%d/%s?authSource=%s",
username, password, host, portInt, dbName, authDatabase,
"mongodb://%s:%s@%s:%d/%s?authSource=%s&serverSelectionTimeoutMS=5000&connectTimeoutMS=5000",
username,
password,
host,
portInt,
dbName,
authDatabase,
)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)

View File

@@ -21,9 +21,7 @@ import (
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
mysqltypes "databasus-backend/internal/features/databases/databases/mysql"
"databasus-backend/internal/features/restores"
restores_enums "databasus-backend/internal/features/restores/enums"
restores_models "databasus-backend/internal/features/restores/models"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
@@ -188,7 +186,7 @@ func testMysqlBackupRestoreForVersion(t *testing.T, mysqlVersion tools.MysqlVers
)
restore := waitForMysqlRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
var tableExists int
err = newDB.Get(
@@ -286,7 +284,7 @@ func testMysqlBackupRestoreWithEncryptionForVersion(
)
restore := waitForMysqlRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
var tableExists int
err = newDB.Get(
@@ -393,7 +391,7 @@ func testMysqlBackupRestoreWithReadOnlyUserForVersion(
)
restore := waitForMysqlRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
var tableExists int
err = newDB.Get(
@@ -481,7 +479,7 @@ func createMysqlRestoreViaAPI(
version tools.MysqlVersion,
token string,
) {
request := restores.RestoreBackupRequest{
request := restores_core.RestoreBackupRequest{
MysqlDatabase: &mysqltypes.MysqlDatabase{
Host: host,
Port: port,
@@ -508,7 +506,7 @@ func waitForMysqlRestoreCompletion(
backupID uuid.UUID,
token string,
timeout time.Duration,
) *restores_models.Restore {
) *restores_core.Restore {
startTime := time.Now()
pollInterval := 500 * time.Millisecond
@@ -517,7 +515,7 @@ func waitForMysqlRestoreCompletion(
t.Fatalf("Timeout waiting for MySQL restore completion after %v", timeout)
}
var restoresList []*restores_models.Restore
var restoresList []*restores_core.Restore
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
@@ -528,10 +526,10 @@ func waitForMysqlRestoreCompletion(
)
for _, restore := range restoresList {
if restore.Status == restores_enums.RestoreStatusCompleted {
if restore.Status == restores_core.RestoreStatusCompleted {
return restore
}
if restore.Status == restores_enums.RestoreStatusFailed {
if restore.Status == restores_core.RestoreStatusFailed {
failMsg := "unknown error"
if restore.FailMessage != nil {
failMsg = *restore.FailMessage
@@ -579,7 +577,7 @@ func connectToMysqlContainer(version tools.MysqlVersion, port string) (*MysqlCon
dbName := "testdb"
password := "rootpassword"
username := "root"
host := "127.0.0.1"
host := config.GetEnv().TestLocalhost
portInt, err := strconv.Atoi(port)
if err != nil {


@@ -23,8 +23,7 @@ import (
"databasus-backend/internal/features/databases"
pgtypes "databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/restores"
restores_enums "databasus-backend/internal/features/restores/enums"
restores_models "databasus-backend/internal/features/restores/models"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
@@ -125,6 +124,10 @@ func Test_BackupAndRestorePostgresqlWithEncryption_RestoreIsSuccessful(t *testin
}
func Test_BackupAndRestoreSupabase_PublicSchemaOnly_RestoreIsSuccessful(t *testing.T) {
if config.GetEnv().IsSkipExternalResourcesTests {
t.Skip("Skipping Supabase test: IS_SKIP_EXTERNAL_RESOURCES_TESTS is true")
}
env := config.GetEnv()
if env.TestSupabaseHost == "" {
@@ -212,7 +215,7 @@ func Test_BackupAndRestoreSupabase_PublicSchemaOnly_RestoreIsSuccessful(t *testi
)
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
var countAfterRestore int
err = supabaseDB.Get(
@@ -439,7 +442,7 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string, cp
)
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
var tableExists bool
err = newDB.Get(
@@ -555,7 +558,7 @@ func testSchemaSelectionAllSchemasForVersion(t *testing.T, pgVersion string, por
)
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
var publicTableExists bool
err = newDB.Get(&publicTableExists, `
@@ -689,7 +692,7 @@ func testBackupRestoreWithExcludeExtensionsForVersion(t *testing.T, pgVersion st
)
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
// Verify the table was restored
var tableExists bool
@@ -829,7 +832,7 @@ func testBackupRestoreWithoutExcludeExtensionsForVersion(
)
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
// Verify the extension was recovered
var extensionExists bool
@@ -956,7 +959,7 @@ func testBackupRestoreWithReadOnlyUserForVersion(t *testing.T, pgVersion string,
)
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
var tableExists bool
err = newDB.Get(
@@ -1076,7 +1079,7 @@ func testSchemaSelectionOnlySpecifiedSchemasForVersion(
)
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
var publicTableExists bool
err = newDB.Get(&publicTableExists, `
@@ -1190,7 +1193,7 @@ func testBackupRestoreWithEncryptionForVersion(t *testing.T, pgVersion string, p
)
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
var tableExists bool
err = newDB.Get(
@@ -1286,7 +1289,7 @@ func waitForRestoreCompletion(
backupID uuid.UUID,
token string,
timeout time.Duration,
) *restores_models.Restore {
) *restores_core.Restore {
startTime := time.Now()
pollInterval := 500 * time.Millisecond
@@ -1295,7 +1298,7 @@ func waitForRestoreCompletion(
t.Fatalf("Timeout waiting for restore completion after %v", timeout)
}
var restores []*restores_models.Restore
var restores []*restores_core.Restore
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
@@ -1306,10 +1309,10 @@ func waitForRestoreCompletion(
)
for _, restore := range restores {
if restore.Status == restores_enums.RestoreStatusCompleted {
if restore.Status == restores_core.RestoreStatusCompleted {
return restore
}
if restore.Status == restores_enums.RestoreStatusFailed {
if restore.Status == restores_core.RestoreStatusFailed {
failMsg := "unknown error"
if restore.FailMessage != nil {
failMsg = *restore.FailMessage
@@ -1476,7 +1479,7 @@ func createRestoreWithCpuCountViaAPI(
cpuCount int,
token string,
) {
request := restores.RestoreBackupRequest{
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &pgtypes.PostgresqlDatabase{
Host: host,
Port: port,
@@ -1509,7 +1512,7 @@ func createRestoreWithOptionsViaAPI(
isExcludeExtensions bool,
token string,
) {
request := restores.RestoreBackupRequest{
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &pgtypes.PostgresqlDatabase{
Host: host,
Port: port,
@@ -1647,7 +1650,7 @@ func createSupabaseRestoreViaAPI(
database string,
token string,
) {
request := restores.RestoreBackupRequest{
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &pgtypes.PostgresqlDatabase{
Host: host,
Port: port,
@@ -1755,7 +1758,7 @@ func connectToPostgresContainer(version string, port string) (*PostgresContainer
dbName := "testdb"
password := "testpassword"
username := "testuser"
host := "localhost"
host := config.GetEnv().TestLocalhost
portInt, err := strconv.Atoi(port)
if err != nil {


@@ -5,6 +5,7 @@ import (
"testing"
"databasus-backend/internal/features/backups/backups/backuping"
"databasus-backend/internal/features/restores/restoring"
cache_utils "databasus-backend/internal/util/cache"
)
@@ -12,11 +13,15 @@ func TestMain(m *testing.M) {
cache_utils.ClearAllCache()
backuperNode := backuping.CreateTestBackuperNode()
cancel := backuping.StartBackuperNodeForTest(&testing.T{}, backuperNode)
cancelBackup := backuping.StartBackuperNodeForTest(&testing.T{}, backuperNode)
restorerNode := restoring.CreateTestRestorerNode()
cancelRestore := restoring.StartRestorerNodeForTest(&testing.T{}, restorerNode)
exitCode := m.Run()
backuping.StopBackuperNodeForTest(&testing.T{}, cancel, backuperNode)
backuping.StopBackuperNodeForTest(&testing.T{}, cancelBackup, backuperNode)
restoring.StopRestorerNodeForTest(&testing.T{}, cancelRestore, restorerNode)
os.Exit(exitCode)
}


@@ -0,0 +1,93 @@
package cache_utils
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func Test_ClearAllCache_AfterClear_CacheIsEmpty(t *testing.T) {
client := getCache()
// Arrange: Set up multiple cache entries with different prefixes
testKeys := []struct {
prefix string
key string
value string
}{
{"test:user:", "user1", "John Doe"},
{"test:user:", "user2", "Jane Smith"},
{"test:session:", "session1", "abc123"},
{"test:session:", "session2", "def456"},
{"test:data:", "item1", "value1"},
}
// Set all test keys
for _, tk := range testKeys {
cacheUtil := NewCacheUtil[string](client, tk.prefix)
cacheUtil.Set(tk.key, &tk.value)
}
// Verify keys were set correctly before clearing
for _, tk := range testKeys {
cacheUtil := NewCacheUtil[string](client, tk.prefix)
retrieved := cacheUtil.Get(tk.key)
assert.NotNil(t, retrieved, "Key %s should exist before clearing", tk.prefix+tk.key)
assert.Equal(t, tk.value, *retrieved, "Retrieved value should match set value")
}
// Act: Clear all cache
err := ClearAllCache()
// Assert: No error returned
assert.NoError(t, err, "ClearAllCache should not return an error")
// Assert: All keys should be deleted
for _, tk := range testKeys {
cacheUtil := NewCacheUtil[string](client, tk.prefix)
retrieved := cacheUtil.Get(tk.key)
assert.Nil(t, retrieved, "Key %s should be deleted after clearing", tk.prefix+tk.key)
}
}
func Test_SetWithExpiration_SetsCorrectTTL(t *testing.T) {
client := getCache()
// Create a cache utility
testPrefix := "test:ttl:"
cacheUtil := NewCacheUtil[string](client, testPrefix)
// Set a value with 1-hour expiration
testKey := "key1"
testValue := "test value"
oneHour := 1 * time.Hour
cacheUtil.SetWithExpiration(testKey, &testValue, oneHour)
// Verify the value was set
retrieved := cacheUtil.Get(testKey)
assert.NotNil(t, retrieved, "Value should be stored")
assert.Equal(t, testValue, *retrieved, "Retrieved value should match")
// Check the TTL using Valkey TTL command
ctx, cancel := context.WithTimeout(context.Background(), DefaultCacheTimeout)
defer cancel()
fullKey := testPrefix + testKey
ttlResult := client.Do(ctx, client.B().Ttl().Key(fullKey).Build())
assert.NoError(t, ttlResult.Error(), "TTL command should not error")
ttlSeconds, err := ttlResult.AsInt64()
assert.NoError(t, err, "TTL should be retrievable as int64")
// TTL should be approximately 1 hour (3600 seconds)
// Allow for a small margin (within 10 seconds of 3600)
expectedTTL := int64(3600)
assert.GreaterOrEqual(t, ttlSeconds, expectedTTL-10, "TTL should be close to 1 hour")
assert.LessOrEqual(t, ttlSeconds, expectedTTL, "TTL should not exceed 1 hour")
// Clean up
cacheUtil.Invalidate(testKey)
}


@@ -67,6 +67,43 @@ func (c *CacheUtil[T]) Set(key string, item *T) {
c.client.Do(ctx, c.client.B().Set().Key(fullKey).Value(string(data)).Ex(c.expiry).Build())
}
func (c *CacheUtil[T]) SetWithExpiration(key string, item *T, expiry time.Duration) {
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
defer cancel()
data, err := json.Marshal(item)
if err != nil {
return
}
fullKey := c.prefix + key
c.client.Do(ctx, c.client.B().Set().Key(fullKey).Value(string(data)).Ex(expiry).Build())
}
func (c *CacheUtil[T]) GetAndDelete(key string) *T {
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
defer cancel()
fullKey := c.prefix + key
result := c.client.Do(ctx, c.client.B().Getdel().Key(fullKey).Build())
if result.Error() != nil {
return nil
}
data, err := result.AsBytes()
if err != nil {
return nil
}
var item T
if err := json.Unmarshal(data, &item); err != nil {
return nil
}
return &item
}
func (c *CacheUtil[T]) Invalidate(key string) {
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
defer cancel()

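The hunk above adds SetWithExpiration and GetAndDelete to CacheUtil. A minimal usage sketch, written as if inside the cache_utils package and mirroring the tests earlier in this diff (getCache, NewCacheUtil and the method signatures are taken from the diff itself; the example function is illustrative and not part of any commit):

package cache_utils

import (
	"fmt"
	"time"
)

// Illustrative only: exercises the helpers added in this diff.
func exampleCacheUsage() {
	client := getCache()
	cache := NewCacheUtil[string](client, "example:")

	value := "hello"
	// Store with an explicit TTL instead of the util's default expiry.
	cache.SetWithExpiration("greeting", &value, 30*time.Minute)

	// Atomically read and remove the entry (GETDEL under the hood).
	if got := cache.GetAndDelete("greeting"); got != nil {
		fmt.Println(*got)
	}
}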
Some files were not shown because too many files have changed in this diff.