mirror of
https://github.com/databasus/databasus.git
synced 2026-04-06 00:32:03 +02:00
Compare commits
24 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
755c420157 | ||
|
|
ff73627287 | ||
|
|
9c9ab00ace | ||
|
|
7366e21a1a | ||
|
|
a327d1aa57 | ||
|
|
f152b16ea3 | ||
|
|
85dbe80d3d | ||
|
|
edf4028fd1 | ||
|
|
8d85c45a90 | ||
|
|
d9c176d19a | ||
|
|
7a6f72a456 | ||
|
|
9a1471b88b | ||
|
|
386ea1d708 | ||
|
|
a4b23936ee | ||
|
|
b36aa9d48b | ||
|
|
13cb8e5bd2 | ||
|
|
2db4b6e075 | ||
|
|
f2b0b2bf1f | ||
|
|
7142ce295e | ||
|
|
04621b9b2d | ||
|
|
bd329a68cf | ||
|
|
f957abc9db | ||
|
|
c0fd6be1a9 | ||
|
|
c39bd34d5e |
219
.github/workflows/ci-release.yml
vendored
219
.github/workflows/ci-release.yml
vendored
@@ -9,15 +9,26 @@ on:
|
||||
|
||||
jobs:
|
||||
lint-backend:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
container:
|
||||
image: golang:1.24.9
|
||||
volumes:
|
||||
- /runner-cache/go-pkg:/go/pkg/mod
|
||||
- /runner-cache/go-build:/root/.cache/go-build
|
||||
- /runner-cache/golangci-lint:/root/.cache/golangci-lint
|
||||
- /runner-cache/apt-archives:/var/cache/apt/archives
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.24.9"
|
||||
- name: Configure Git for container
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Download Go modules
|
||||
run: |
|
||||
cd backend
|
||||
go mod download
|
||||
|
||||
- name: Install golangci-lint
|
||||
run: |
|
||||
@@ -93,34 +104,32 @@ jobs:
|
||||
npm run test
|
||||
|
||||
test-backend:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
needs: [lint-backend]
|
||||
container:
|
||||
image: golang:1.24.9
|
||||
options: --privileged -v /var/run/docker.sock:/var/run/docker.sock --add-host=host.docker.internal:host-gateway
|
||||
volumes:
|
||||
- /runner-cache/go-pkg:/go/pkg/mod
|
||||
- /runner-cache/go-build:/root/.cache/go-build
|
||||
- /runner-cache/apt-archives:/var/cache/apt/archives
|
||||
steps:
|
||||
- name: Free up disk space
|
||||
- name: Install Docker CLI
|
||||
run: |
|
||||
echo "Disk space before cleanup:"
|
||||
df -h
|
||||
# Remove unnecessary pre-installed software
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
sudo rm -rf /usr/local/share/boost
|
||||
sudo rm -rf /usr/share/swift
|
||||
# Clean apt cache
|
||||
sudo apt-get clean
|
||||
# Clean docker images (if any pre-installed)
|
||||
docker system prune -af --volumes || true
|
||||
echo "Disk space after cleanup:"
|
||||
df -h
|
||||
apt-get update -qq
|
||||
apt-get install -y -qq docker.io docker-compose netcat-openbsd wget
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.24.9"
|
||||
- name: Configure Git for container
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Download Go modules
|
||||
run: |
|
||||
cd backend
|
||||
go mod download
|
||||
|
||||
- name: Create .env file for testing
|
||||
run: |
|
||||
@@ -132,14 +141,16 @@ jobs:
|
||||
DEV_DB_PASSWORD=Q1234567
|
||||
#app
|
||||
ENV_MODE=development
|
||||
# db
|
||||
DATABASE_DSN=host=localhost user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
|
||||
DATABASE_URL=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disable
|
||||
# db - using 172.17.0.1 to access host from container
|
||||
DATABASE_DSN=host=172.17.0.1 user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
|
||||
DATABASE_URL=postgres://postgres:Q1234567@172.17.0.1:5437/databasus?sslmode=disable
|
||||
# migrations
|
||||
GOOSE_DRIVER=postgres
|
||||
GOOSE_DBSTRING=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disable
|
||||
GOOSE_DBSTRING=postgres://postgres:Q1234567@172.17.0.1:5437/databasus?sslmode=disable
|
||||
GOOSE_MIGRATION_DIR=./migrations
|
||||
# testing
|
||||
# testing
|
||||
TEST_LOCALHOST=172.17.0.1
|
||||
IS_SKIP_EXTERNAL_RESOURCES_TESTS=true
|
||||
# to get Google Drive env variables: add storage in UI and copy data from added storage here
|
||||
TEST_GOOGLE_DRIVE_CLIENT_ID=${{ secrets.TEST_GOOGLE_DRIVE_CLIENT_ID }}
|
||||
TEST_GOOGLE_DRIVE_CLIENT_SECRET=${{ secrets.TEST_GOOGLE_DRIVE_CLIENT_SECRET }}
|
||||
@@ -197,12 +208,14 @@ jobs:
|
||||
TEST_MONGODB_60_PORT=27060
|
||||
TEST_MONGODB_70_PORT=27070
|
||||
TEST_MONGODB_82_PORT=27082
|
||||
# Valkey (cache)
|
||||
VALKEY_HOST=localhost
|
||||
# Valkey (cache) - using 172.17.0.1
|
||||
VALKEY_HOST=172.17.0.1
|
||||
VALKEY_PORT=6379
|
||||
VALKEY_USERNAME=
|
||||
VALKEY_PASSWORD=
|
||||
VALKEY_IS_SSL=false
|
||||
# Host for test databases (container -> host)
|
||||
TEST_DB_HOST=172.17.0.1
|
||||
EOF
|
||||
|
||||
- name: Start test containers
|
||||
@@ -220,25 +233,25 @@ jobs:
|
||||
timeout 60 bash -c 'until docker exec dev-valkey valkey-cli ping 2>/dev/null | grep -q PONG; do sleep 2; done'
|
||||
echo "Valkey is ready!"
|
||||
|
||||
# Wait for test databases
|
||||
timeout 60 bash -c 'until nc -z localhost 5000; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5001; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5002; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5003; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5004; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5005; do sleep 2; done'
|
||||
# Wait for test databases (using 172.17.0.1 from container)
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5000; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5001; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5002; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5003; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5004; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 5005; do sleep 2; done'
|
||||
|
||||
# Wait for MinIO
|
||||
timeout 60 bash -c 'until nc -z localhost 9000; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 9000; do sleep 2; done'
|
||||
|
||||
# Wait for Azurite
|
||||
timeout 60 bash -c 'until nc -z localhost 10000; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 10000; do sleep 2; done'
|
||||
|
||||
# Wait for FTP
|
||||
timeout 60 bash -c 'until nc -z localhost 7007; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 7007; do sleep 2; done'
|
||||
|
||||
# Wait for SFTP
|
||||
timeout 60 bash -c 'until nc -z localhost 7008; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z 172.17.0.1 7008; do sleep 2; done'
|
||||
|
||||
# Wait for MySQL containers
|
||||
echo "Waiting for MySQL 5.7..."
|
||||
@@ -297,63 +310,63 @@ jobs:
|
||||
mkdir -p databasus-data/backups
|
||||
mkdir -p databasus-data/temp
|
||||
|
||||
- name: Install MySQL dependencies
|
||||
- name: Install database client dependencies
|
||||
run: |
|
||||
sudo apt-get update -qq
|
||||
sudo apt-get install -y -qq libncurses6
|
||||
sudo ln -sf /usr/lib/x86_64-linux-gnu/libncurses.so.6 /usr/lib/x86_64-linux-gnu/libncurses.so.5
|
||||
sudo ln -sf /usr/lib/x86_64-linux-gnu/libtinfo.so.6 /usr/lib/x86_64-linux-gnu/libtinfo.so.5
|
||||
apt-get update -qq
|
||||
apt-get install -y -qq libncurses6 libpq5
|
||||
ln -sf /usr/lib/x86_64-linux-gnu/libncurses.so.6 /usr/lib/x86_64-linux-gnu/libncurses.so.5 || true
|
||||
ln -sf /usr/lib/x86_64-linux-gnu/libtinfo.so.6 /usr/lib/x86_64-linux-gnu/libtinfo.so.5 || true
|
||||
|
||||
- name: Setup PostgreSQL, MySQL and MariaDB client tools from pre-built assets
|
||||
run: |
|
||||
cd backend/tools
|
||||
|
||||
|
||||
# Create directory structure
|
||||
mkdir -p postgresql mysql mariadb mongodb/bin
|
||||
|
||||
|
||||
# Copy PostgreSQL client tools (12-18) from pre-built assets
|
||||
for version in 12 13 14 15 16 17 18; do
|
||||
mkdir -p postgresql/postgresql-$version
|
||||
cp -r ../../assets/tools/x64/postgresql/postgresql-$version/bin postgresql/postgresql-$version/
|
||||
done
|
||||
|
||||
|
||||
# Copy MySQL client tools (5.7, 8.0, 8.4, 9) from pre-built assets
|
||||
for version in 5.7 8.0 8.4 9; do
|
||||
mkdir -p mysql/mysql-$version
|
||||
cp -r ../../assets/tools/x64/mysql/mysql-$version/bin mysql/mysql-$version/
|
||||
done
|
||||
|
||||
|
||||
# Copy MariaDB client tools (10.6, 12.1) from pre-built assets
|
||||
for version in 10.6 12.1; do
|
||||
mkdir -p mariadb/mariadb-$version
|
||||
cp -r ../../assets/tools/x64/mariadb/mariadb-$version/bin mariadb/mariadb-$version/
|
||||
done
|
||||
|
||||
|
||||
# Make all binaries executable
|
||||
chmod +x postgresql/*/bin/*
|
||||
chmod +x mysql/*/bin/*
|
||||
chmod +x mariadb/*/bin/*
|
||||
|
||||
|
||||
echo "Pre-built client tools setup complete"
|
||||
|
||||
- name: Install MongoDB Database Tools
|
||||
run: |
|
||||
cd backend/tools
|
||||
|
||||
|
||||
# MongoDB Database Tools must be downloaded (not in pre-built assets)
|
||||
# They are backward compatible - single version supports all servers (4.0-8.0)
|
||||
MONGODB_TOOLS_URL="https://fastdl.mongodb.org/tools/db/mongodb-database-tools-debian12-x86_64-100.10.0.deb"
|
||||
|
||||
|
||||
echo "Downloading MongoDB Database Tools..."
|
||||
wget -q "$MONGODB_TOOLS_URL" -O /tmp/mongodb-database-tools.deb
|
||||
|
||||
|
||||
echo "Installing MongoDB Database Tools..."
|
||||
sudo dpkg -i /tmp/mongodb-database-tools.deb || sudo apt-get install -f -y --no-install-recommends
|
||||
|
||||
dpkg -i /tmp/mongodb-database-tools.deb || apt-get install -f -y --no-install-recommends
|
||||
|
||||
# Create symlinks to tools directory
|
||||
ln -sf /usr/bin/mongodump mongodb/bin/mongodump
|
||||
ln -sf /usr/bin/mongorestore mongodb/bin/mongorestore
|
||||
|
||||
|
||||
rm -f /tmp/mongodb-database-tools.deb
|
||||
echo "MongoDB Database Tools installed successfully"
|
||||
|
||||
@@ -401,10 +414,28 @@ jobs:
|
||||
if: always()
|
||||
run: |
|
||||
cd backend
|
||||
# Stop and remove containers (keeping images for next run)
|
||||
docker compose -f docker-compose.yml.example down -v
|
||||
|
||||
# Clean up all data directories created by docker-compose
|
||||
echo "Cleaning up data directories..."
|
||||
rm -rf pgdata || true
|
||||
rm -rf valkey-data || true
|
||||
rm -rf mysqldata || true
|
||||
rm -rf mariadbdata || true
|
||||
rm -rf temp/nas || true
|
||||
rm -rf databasus-data || true
|
||||
|
||||
# Also clean root-level databasus-data if exists
|
||||
cd ..
|
||||
rm -rf databasus-data || true
|
||||
|
||||
echo "Cleanup complete"
|
||||
|
||||
determine-version:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
container:
|
||||
image: node:20
|
||||
needs: [test-backend, test-frontend]
|
||||
if: ${{ github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, '[skip-release]') }}
|
||||
outputs:
|
||||
@@ -417,10 +448,9 @@ jobs:
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "20"
|
||||
- name: Configure Git for container
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Install semver
|
||||
run: npm install -g semver
|
||||
@@ -434,6 +464,7 @@ jobs:
|
||||
|
||||
- name: Analyze commits and determine version bump
|
||||
id: version_bump
|
||||
shell: bash
|
||||
run: |
|
||||
CURRENT_VERSION="${{ steps.current_version.outputs.current_version }}"
|
||||
LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
|
||||
@@ -453,7 +484,7 @@ jobs:
|
||||
HAS_FIX=false
|
||||
HAS_BREAKING=false
|
||||
|
||||
# Analyze each commit
|
||||
# Analyze each commit - USE PROCESS SUBSTITUTION to avoid subshell variable scope issues
|
||||
while IFS= read -r commit; do
|
||||
if [[ "$commit" =~ ^FEATURE ]]; then
|
||||
HAS_FEATURE=true
|
||||
@@ -471,7 +502,7 @@ jobs:
|
||||
HAS_BREAKING=true
|
||||
echo "Found BREAKING CHANGE: $commit"
|
||||
fi
|
||||
done <<< "$COMMITS"
|
||||
done < <(printf '%s\n' "$COMMITS")
|
||||
|
||||
# Determine version bump
|
||||
if [ "$HAS_BREAKING" = true ]; then
|
||||
@@ -497,10 +528,15 @@ jobs:
|
||||
fi
|
||||
|
||||
build-only:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
needs: [test-backend, test-frontend]
|
||||
if: ${{ github.ref == 'refs/heads/main' && contains(github.event.head_commit.message, '[skip-release]') }}
|
||||
steps:
|
||||
- name: Clean workspace
|
||||
run: |
|
||||
sudo rm -rf "$GITHUB_WORKSPACE"/* || true
|
||||
sudo rm -rf "$GITHUB_WORKSPACE"/.* || true
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -529,12 +565,17 @@ jobs:
|
||||
databasus/databasus:${{ github.sha }}
|
||||
|
||||
build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
needs: [determine-version]
|
||||
if: ${{ needs.determine-version.outputs.should_release == 'true' }}
|
||||
permissions:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Clean workspace
|
||||
run: |
|
||||
sudo rm -rf "$GITHUB_WORKSPACE"/* || true
|
||||
sudo rm -rf "$GITHUB_WORKSPACE"/.* || true
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -564,21 +605,33 @@ jobs:
|
||||
databasus/databasus:${{ github.sha }}
|
||||
|
||||
release:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
container:
|
||||
image: node:20
|
||||
needs: [determine-version, build-and-push]
|
||||
if: ${{ needs.determine-version.outputs.should_release == 'true' }}
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Clean workspace
|
||||
run: |
|
||||
rm -rf "$GITHUB_WORKSPACE"/* || true
|
||||
rm -rf "$GITHUB_WORKSPACE"/.* || true
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Configure Git for container
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Generate changelog
|
||||
id: changelog
|
||||
shell: bash
|
||||
run: |
|
||||
NEW_VERSION="${{ needs.determine-version.outputs.new_version }}"
|
||||
LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
|
||||
@@ -598,6 +651,7 @@ jobs:
|
||||
FIXES=""
|
||||
REFACTORS=""
|
||||
|
||||
# USE PROCESS SUBSTITUTION to avoid subshell variable scope issues
|
||||
while IFS= read -r line; do
|
||||
if [ -n "$line" ]; then
|
||||
COMMIT_MSG=$(echo "$line" | cut -d'|' -f1)
|
||||
@@ -631,7 +685,7 @@ jobs:
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
done <<< "$COMMITS"
|
||||
done < <(printf '%s\n' "$COMMITS")
|
||||
|
||||
# Build changelog sections
|
||||
if [ -n "$FEATURES" ]; then
|
||||
@@ -670,16 +724,33 @@ jobs:
|
||||
prerelease: false
|
||||
|
||||
publish-helm-chart:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
container:
|
||||
image: alpine:3.19
|
||||
volumes:
|
||||
- /runner-cache/apk-cache:/etc/apk/cache
|
||||
needs: [determine-version, build-and-push]
|
||||
if: ${{ needs.determine-version.outputs.should_release == 'true' }}
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
steps:
|
||||
- name: Clean workspace
|
||||
run: |
|
||||
rm -rf "$GITHUB_WORKSPACE"/* || true
|
||||
rm -rf "$GITHUB_WORKSPACE"/.* || true
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apk add --no-cache git bash curl
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Configure Git for container
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@v4
|
||||
with:
|
||||
@@ -701,4 +772,4 @@ jobs:
|
||||
- name: Push Helm chart to GHCR
|
||||
run: |
|
||||
VERSION="${{ needs.determine-version.outputs.new_version }}"
|
||||
helm push databasus-${VERSION}.tgz oci://ghcr.io/databasus/charts
|
||||
helm push databasus-${VERSION}.tgz oci://ghcr.io/databasus/charts
|
||||
|
||||
250
AGENTS.md
250
AGENTS.md
@@ -7,6 +7,7 @@ This is NOT a strict set of rules, but a set of recommendations to help you writ
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Engineering Philosophy](#engineering-philosophy)
|
||||
- [Backend Guidelines](#backend-guidelines)
|
||||
- [Code Style](#code-style)
|
||||
- [Comments](#comments)
|
||||
@@ -22,6 +23,67 @@ This is NOT a strict set of rules, but a set of recommendations to help you writ
|
||||
|
||||
---
|
||||
|
||||
## Engineering Philosophy
|
||||
|
||||
**Think like a skeptical senior engineer and code reviewer. Don't just do what was asked—also think about what should have been asked.**
|
||||
|
||||
⚠️ **Balance vigilance with pragmatism:** Catch real issues, not theoretical ones. Don't let perfect be the enemy of good.
|
||||
|
||||
### Task Context Assessment:
|
||||
|
||||
**First, assess the task scope:**
|
||||
|
||||
- **Trivial** (typos, formatting, simple field adds): Apply directly with minimal analysis
|
||||
- **Standard** (CRUD, typical features): Brief assumption check, proceed
|
||||
- **Complex** (architecture, security, performance-critical): Full analysis required
|
||||
- **Unclear** (ambiguous requirements): Always clarify assumptions first
|
||||
|
||||
### For Non-Trivial Tasks:
|
||||
|
||||
1. **Restate the objective and list assumptions** (explicit + implicit)
|
||||
- If any assumption is shaky, call it out clearly
|
||||
- Distinguish between what's specified and what you're inferring
|
||||
|
||||
2. **Propose appropriate solutions:**
|
||||
- For complex tasks: 2–3 viable approaches (including a simpler baseline)
|
||||
- Recommend one with clear tradeoffs
|
||||
- Consider: complexity, maintainability, performance, future extensibility
|
||||
|
||||
3. **Identify risks proactively:**
|
||||
- Edge cases and boundary conditions
|
||||
- Security/privacy pitfalls
|
||||
- Performance risks and scalability concerns
|
||||
- Operational concerns (deployment, observability, rollback, monitoring)
|
||||
|
||||
4. **Handle ambiguity:**
|
||||
- If requirements are ambiguous, make a reasonable default and proceed
|
||||
- Clearly label your assumptions
|
||||
- Document what would change under alternative assumptions
|
||||
|
||||
5. **Deliver quality:**
|
||||
- Provide a solution that is correct, testable, and maintainable
|
||||
- Include minimal tests or validation steps
|
||||
- Follow project testing philosophy: prefer controller tests over unit tests
|
||||
- Follow all project guidelines from this document
|
||||
|
||||
6. **Self-review before finalizing:**
|
||||
- Ask: "What could go wrong?"
|
||||
- Patch the answer accordingly
|
||||
- Verify edge cases are handled
|
||||
|
||||
### Application Guidelines:
|
||||
|
||||
**Scale your response to the task:**
|
||||
|
||||
- **Trivial changes:** Steps 5-6 only (deliver quality + self-review)
|
||||
- **Standard features:** Steps 1, 5-6 (restate + deliver + review)
|
||||
- **Complex/risky changes:** All steps 1-6
|
||||
- **Ambiguous requests:** Steps 1, 4 mandatory
|
||||
|
||||
**Be proportionally thorough—brief for simple tasks, comprehensive for risky ones. Avoid analysis paralysis.**
|
||||
|
||||
---
|
||||
|
||||
## Backend Guidelines
|
||||
|
||||
### Code Style
|
||||
@@ -175,6 +237,66 @@ func (c *ProjectController) extractProjectID(ctx *gin.Context) uuid.UUID {
|
||||
|
||||
---
|
||||
|
||||
### Boolean Naming
|
||||
|
||||
**Always prefix boolean variables with verbs like `is`, `has`, `was`, `should`, `can`, etc.**
|
||||
|
||||
This makes the code more readable and clearly indicates that the variable represents a true/false state.
|
||||
|
||||
#### Good Examples:
|
||||
|
||||
```go
|
||||
type User struct {
|
||||
IsActive bool
|
||||
IsVerified bool
|
||||
HasAccess bool
|
||||
WasNotified bool
|
||||
}
|
||||
|
||||
type BackupConfig struct {
|
||||
IsEnabled bool
|
||||
ShouldCompress bool
|
||||
CanRetry bool
|
||||
}
|
||||
|
||||
// Variables
|
||||
isInProgress := true
|
||||
wasCompleted := false
|
||||
hasPermission := checkPermissions()
|
||||
```
|
||||
|
||||
#### Bad Examples:
|
||||
|
||||
```go
|
||||
type User struct {
|
||||
Active bool // Should be: IsActive
|
||||
Verified bool // Should be: IsVerified
|
||||
Access bool // Should be: HasAccess
|
||||
}
|
||||
|
||||
type BackupConfig struct {
|
||||
Enabled bool // Should be: IsEnabled
|
||||
Compress bool // Should be: ShouldCompress
|
||||
Retry bool // Should be: CanRetry
|
||||
}
|
||||
|
||||
// Variables
|
||||
inProgress := true // Should be: isInProgress
|
||||
completed := false // Should be: wasCompleted
|
||||
permission := true // Should be: hasPermission
|
||||
```
|
||||
|
||||
#### Common Boolean Prefixes:
|
||||
|
||||
- **is** - current state (IsActive, IsValid, IsEnabled)
|
||||
- **has** - possession or presence (HasAccess, HasPermission, HasError)
|
||||
- **was** - past state (WasCompleted, WasNotified, WasDeleted)
|
||||
- **should** - intention or recommendation (ShouldRetry, ShouldCompress)
|
||||
- **can** - capability or permission (CanRetry, CanDelete, CanEdit)
|
||||
- **will** - future state (WillExpire, WillRetry)
|
||||
|
||||
---
|
||||
|
||||
### Comments
|
||||
|
||||
#### Guidelines
|
||||
@@ -427,6 +549,134 @@ func GetOrderRepository() *repositories.OrderRepository {
|
||||
}
|
||||
```
|
||||
|
||||
#### SetupDependencies() Pattern
|
||||
|
||||
**All `SetupDependencies()` functions must use sync.Once to ensure idempotent execution.**
|
||||
|
||||
This pattern allows `SetupDependencies()` to be safely called multiple times (especially in tests) while ensuring the actual setup logic executes only once.
|
||||
|
||||
**Implementation Pattern:**
|
||||
|
||||
```go
|
||||
package feature
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
// Initialize dependencies here
|
||||
someService.SetDependency(otherService)
|
||||
anotherService.AddListener(listener)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Why This Pattern:**
|
||||
|
||||
- **Tests can call multiple times**: Test setup often calls `SetupDependencies()` multiple times without issues
|
||||
- **Thread-safe**: Works correctly with concurrent calls (nanoseconds or seconds apart)
|
||||
- **Idempotent**: Subsequent calls are safe and only log a warning
|
||||
- **No panics**: Does not break tests or production code on multiple calls
|
||||
|
||||
**Key Points:**
|
||||
|
||||
1. Check `isSetup.Load()` **before** calling `Do()` to detect previous executions
|
||||
2. Set `isSetup.Store(true)` **inside** the `Do()` closure after setup completes
|
||||
3. Log a warning if already set up (helps identify unnecessary duplicate calls)
|
||||
4. All setup logic must be inside the `Do()` closure
|
||||
|
||||
---
|
||||
|
||||
### Background Services
|
||||
|
||||
**All background service `Run()` methods must panic if called multiple times to prevent corrupted states.**
|
||||
|
||||
Background services run infinite loops and must never be started twice on the same instance. Multiple calls indicate a serious bug that would cause duplicate goroutines, resource leaks, and data corruption.
|
||||
|
||||
**Implementation Pattern:**
|
||||
|
||||
```go
|
||||
package feature
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type BackgroundService struct {
|
||||
// ... existing fields ...
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (s *BackgroundService) Run(ctx context.Context) {
|
||||
wasAlreadyRun := s.hasRun.Load()
|
||||
|
||||
s.runOnce.Do(func() {
|
||||
s.hasRun.Store(true)
|
||||
|
||||
// Existing infinite loop logic
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.doWork()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Why Panic Instead of Warning:**
|
||||
|
||||
- **Prevents corruption**: Multiple `Run()` calls would create duplicate goroutines consuming resources
|
||||
- **Fails fast**: Catches critical bugs immediately in tests and production
|
||||
- **Clear indication**: Panic clearly indicates a serious programming error
|
||||
- **Applies everywhere**: Same protection in tests and production
|
||||
|
||||
**When This Applies:**
|
||||
|
||||
- All background services with infinite loops
|
||||
- Registry services (BackupNodesRegistry, RestoreNodesRegistry)
|
||||
- Scheduler services (BackupsScheduler, RestoresScheduler)
|
||||
- Worker nodes (BackuperNode, RestorerNode)
|
||||
- Cleanup services (AuditLogBackgroundService, DownloadTokenBackgroundService)
|
||||
|
||||
**Key Points:**
|
||||
|
||||
1. Check `hasRun.Load()` **before** calling `Do()` to detect previous executions
|
||||
2. Set `hasRun.Store(true)` **inside** the `Do()` closure before starting work
|
||||
3. **Always panic** if already run (never just log warning)
|
||||
4. All run logic must be inside the `Do()` closure
|
||||
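5. This pattern is **thread-safe** for any timing (concurrent or sequential calls); note that `sync.Once.Do()` blocks a second caller until the first call's closure returns, so a duplicate `Run()` call panics only after the already-running loop exits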
|
||||
|
||||
---
|
||||
|
||||
### Migrations
|
||||
|
||||
@@ -6,6 +6,9 @@ DEV_DB_PASSWORD=Q1234567
|
||||
ENV_MODE=development
|
||||
# logging
|
||||
SHOW_DB_INSTALLATION_VERIFICATION_LOGS=true
|
||||
# tests
|
||||
TEST_LOCALHOST=localhost
|
||||
IS_SKIP_EXTERNAL_RESOURCES_TESTS=false
|
||||
# db
|
||||
DATABASE_DSN=host=dev-db user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
|
||||
DATABASE_URL=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable
|
||||
|
||||
@@ -25,10 +25,10 @@ import (
|
||||
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/restores"
|
||||
"databasus-backend/internal/features/restores/restoring"
|
||||
"databasus-backend/internal/features/storages"
|
||||
system_healthcheck "databasus-backend/internal/features/system/healthcheck"
|
||||
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
task_registry "databasus-backend/internal/features/tasks/registry"
|
||||
users_controllers "databasus-backend/internal/features/users/controllers"
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
@@ -273,7 +273,7 @@ func runBackgroundTasks(log *slog.Logger) {
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "restore background service", func() {
|
||||
restores.GetRestoreBackgroundService().Run(ctx)
|
||||
restoring.GetRestoresScheduler().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "healthcheck attempt background service", func() {
|
||||
@@ -288,21 +288,29 @@ func runBackgroundTasks(log *slog.Logger) {
|
||||
backups_download.GetDownloadTokenBackgroundService().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "task nodes registry background service", func() {
|
||||
task_registry.GetTaskNodesRegistry().Run(ctx)
|
||||
go runWithPanicLogging(log, "backup nodes registry background service", func() {
|
||||
backuping.GetBackupNodesRegistry().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "restore nodes registry background service", func() {
|
||||
restoring.GetRestoreNodesRegistry().Run(ctx)
|
||||
})
|
||||
} else {
|
||||
log.Info("Skipping primary node tasks as not primary node")
|
||||
}
|
||||
|
||||
if config.GetEnv().IsBackupNode {
|
||||
if config.GetEnv().IsProcessingNode {
|
||||
log.Info("Starting backup node background tasks...")
|
||||
|
||||
go runWithPanicLogging(log, "backup node", func() {
|
||||
backuping.GetBackuperNode().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "restore node", func() {
|
||||
restoring.GetRestorerNode().Run(ctx)
|
||||
})
|
||||
} else {
|
||||
log.Info("Skipping backup node tasks as not backup node")
|
||||
log.Info("Skipping backup/restore node tasks as not backup node")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/ilyakaznacheev/cleanenv"
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
@@ -30,12 +29,14 @@ type EnvVariables struct {
|
||||
MariadbInstallDir string `env:"MARIADB_INSTALL_DIR"`
|
||||
MongodbInstallDir string `env:"MONGODB_INSTALL_DIR"`
|
||||
|
||||
ShowDbInstallationVerificationLogs bool `env:"SHOW_DB_INSTALLATION_VERIFICATION_LOGS"`
|
||||
TestLocalhost string `env:"TEST_LOCALHOST"`
|
||||
|
||||
ShowDbInstallationVerificationLogs bool `env:"SHOW_DB_INSTALLATION_VERIFICATION_LOGS"`
|
||||
IsSkipExternalResourcesTests bool `env:"IS_SKIP_EXTERNAL_RESOURCES_TESTS"`
|
||||
|
||||
NodeID string
|
||||
IsManyNodesMode bool `env:"IS_MANY_NODES_MODE"`
|
||||
IsPrimaryNode bool `env:"IS_PRIMARY_NODE"`
|
||||
IsBackupNode bool `env:"IS_BACKUP_NODE"`
|
||||
IsProcessingNode bool `env:"IS_PROCESSING_NODE"`
|
||||
NodeNetworkThroughputMBs int `env:"NODE_NETWORK_THROUGHPUT_MBPS"`
|
||||
|
||||
DataFolder string
|
||||
@@ -176,6 +177,11 @@ func loadEnvVariables() {
|
||||
env.ShowDbInstallationVerificationLogs = true
|
||||
}
|
||||
|
||||
// Set default value for IsSkipExternalResourcesTests if not defined
|
||||
if os.Getenv("IS_SKIP_EXTERNAL_RESOURCES_TESTS") == "" {
|
||||
env.IsSkipExternalResourcesTests = false
|
||||
}
|
||||
|
||||
for _, arg := range os.Args {
|
||||
if strings.Contains(arg, "test") {
|
||||
env.IsTesting = true
|
||||
@@ -230,14 +236,17 @@ func loadEnvVariables() {
|
||||
env.ShowDbInstallationVerificationLogs,
|
||||
)
|
||||
|
||||
env.NodeID = uuid.New().String()
|
||||
if env.NodeNetworkThroughputMBs == 0 {
|
||||
env.NodeNetworkThroughputMBs = 125 // 1 Gbit/s
|
||||
}
|
||||
|
||||
if !env.IsManyNodesMode {
|
||||
env.IsPrimaryNode = true
|
||||
env.IsBackupNode = true
|
||||
env.IsProcessingNode = true
|
||||
}
|
||||
|
||||
if env.TestLocalhost == "" {
|
||||
env.TestLocalhost = "localhost"
|
||||
}
|
||||
|
||||
// Valkey
|
||||
|
||||
@@ -2,34 +2,50 @@ package audit_logs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type AuditLogBackgroundService struct {
|
||||
auditLogService *AuditLogService
|
||||
logger *slog.Logger
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (s *AuditLogBackgroundService) Run(ctx context.Context) {
|
||||
s.logger.Info("Starting audit log cleanup background service")
|
||||
wasAlreadyRun := s.hasRun.Load()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
s.runOnce.Do(func() {
|
||||
s.hasRun.Store(true)
|
||||
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
s.logger.Info("Starting audit log cleanup background service")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.cleanOldAuditLogs(); err != nil {
|
||||
s.logger.Error("Failed to clean old audit logs", "error", err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.cleanOldAuditLogs(); err != nil {
|
||||
s.logger.Error("Failed to clean old audit logs", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
@@ -14,8 +17,10 @@ var auditLogController = &AuditLogController{
|
||||
auditLogService,
|
||||
}
|
||||
var auditLogBackgroundService = &AuditLogBackgroundService{
|
||||
auditLogService,
|
||||
logger.GetLogger(),
|
||||
auditLogService: auditLogService,
|
||||
logger: logger.GetLogger(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
|
||||
func GetAuditLogService() *AuditLogService {
|
||||
@@ -30,8 +35,23 @@ func GetAuditLogBackgroundService() *AuditLogBackgroundService {
|
||||
return auditLogBackgroundService
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
users_services.GetUserService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
users_services.GetUserService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,23 +2,25 @@ package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/storages"
|
||||
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
task_registry "databasus-backend/internal/features/tasks/registry"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -35,70 +37,85 @@ type BackuperNode struct {
|
||||
storageService *storages.StorageService
|
||||
notificationSender backups_core.NotificationSender
|
||||
backupCancelManager *tasks_cancellation.TaskCancelManager
|
||||
tasksRegistry *task_registry.TaskNodesRegistry
|
||||
backupNodesRegistry *BackupNodesRegistry
|
||||
logger *slog.Logger
|
||||
createBackupUseCase backups_core.CreateBackupUsecase
|
||||
nodeID uuid.UUID
|
||||
|
||||
lastHeartbeat time.Time
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (n *BackuperNode) Run(ctx context.Context) {
|
||||
n.lastHeartbeat = time.Now().UTC()
|
||||
wasAlreadyRun := n.hasRun.Load()
|
||||
|
||||
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
|
||||
n.runOnce.Do(func() {
|
||||
n.hasRun.Store(true)
|
||||
|
||||
backupNode := task_registry.TaskNode{
|
||||
ID: n.nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
}
|
||||
n.lastHeartbeat = time.Now().UTC()
|
||||
|
||||
if err := n.tasksRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
|
||||
n.logger.Error("Failed to register node in registry", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
|
||||
|
||||
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
|
||||
n.MakeBackup(backupID, isCallNotifier)
|
||||
if err := n.tasksRegistry.PublishTaskCompletion(n.nodeID.String(), backupID); err != nil {
|
||||
n.logger.Error(
|
||||
"Failed to publish backup completion",
|
||||
"error",
|
||||
err,
|
||||
"backupID",
|
||||
backupID,
|
||||
)
|
||||
backupNode := BackupNode{
|
||||
ID: n.nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: time.Now().UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
if err := n.tasksRegistry.SubscribeNodeForTasksAssignment(n.nodeID.String(), backupHandler); err != nil {
|
||||
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := n.tasksRegistry.UnsubscribeNodeForTasksAssignments(); err != nil {
|
||||
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
|
||||
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
|
||||
n.logger.Error("Failed to register node in registry", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
|
||||
ticker := time.NewTicker(heartbeatTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
|
||||
|
||||
if err := n.tasksRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
|
||||
n.logger.Error("Failed to unregister node from registry", "error", err)
|
||||
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
|
||||
n.MakeBackup(backupID, isCallNotifier)
|
||||
if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil {
|
||||
n.logger.Error(
|
||||
"Failed to publish backup completion",
|
||||
"error",
|
||||
err,
|
||||
"backupID",
|
||||
backupID,
|
||||
)
|
||||
}
|
||||
|
||||
return
|
||||
case <-ticker.C:
|
||||
n.sendHeartbeat(&backupNode)
|
||||
}
|
||||
|
||||
err := n.backupNodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID, backupHandler)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := n.backupNodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
|
||||
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ticker := time.NewTicker(heartbeatTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
|
||||
|
||||
if err := n.backupNodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
|
||||
n.logger.Error("Failed to unregister node from registry", "error", err)
|
||||
}
|
||||
|
||||
return
|
||||
case <-ticker.C:
|
||||
n.sendHeartbeat(&backupNode)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", n))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -357,9 +374,9 @@ func (n *BackuperNode) SendBackupNotification(
|
||||
}
|
||||
}
|
||||
|
||||
func (n *BackuperNode) sendHeartbeat(backupNode *task_registry.TaskNode) {
|
||||
func (n *BackuperNode) sendHeartbeat(backupNode *BackupNode) {
|
||||
n.lastHeartbeat = time.Now().UTC()
|
||||
if err := n.tasksRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *backupNode); err != nil {
|
||||
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *backupNode); err != nil {
|
||||
n.logger.Error("Failed to send heartbeat", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/usecases"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
@@ -9,57 +14,60 @@ import (
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
task_registry "databasus-backend/internal/features/tasks/registry"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var backupRepository = &backups_core.BackupRepository{}
|
||||
|
||||
var taskCancelManager = tasks_cancellation.GetTaskCancelManager()
|
||||
|
||||
var nodesRegistry = task_registry.GetTaskNodesRegistry()
|
||||
var backupNodesRegistry = &BackupNodesRegistry{
|
||||
client: cache_utils.GetValkeyClient(),
|
||||
logger: logger.GetLogger(),
|
||||
timeout: cache_utils.DefaultCacheTimeout,
|
||||
pubsubBackups: cache_utils.NewPubSubManager(),
|
||||
pubsubCompletions: cache_utils.NewPubSubManager(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
|
||||
func getNodeID() uuid.UUID {
|
||||
nodeIDStr := config.GetEnv().NodeID
|
||||
nodeID, err := uuid.Parse(nodeIDStr)
|
||||
if err != nil {
|
||||
logger.GetLogger().Error("Failed to parse node ID from config", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
return nodeID
|
||||
return uuid.New()
|
||||
}
|
||||
|
||||
var backuperNode = &BackuperNode{
|
||||
databases.GetDatabaseService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
notifiers.GetNotifierService(),
|
||||
taskCancelManager,
|
||||
nodesRegistry,
|
||||
logger.GetLogger(),
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
getNodeID(),
|
||||
time.Time{},
|
||||
databaseService: databases.GetDatabaseService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
workspaceService: workspaces_services.GetWorkspaceService(),
|
||||
backupRepository: backupRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
notificationSender: notifiers.GetNotifierService(),
|
||||
backupCancelManager: taskCancelManager,
|
||||
backupNodesRegistry: backupNodesRegistry,
|
||||
logger: logger.GetLogger(),
|
||||
createBackupUseCase: usecases.GetCreateBackupUsecase(),
|
||||
nodeID: getNodeID(),
|
||||
lastHeartbeat: time.Time{},
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
|
||||
var backupsScheduler = &BackupsScheduler{
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
taskCancelManager,
|
||||
nodesRegistry,
|
||||
time.Now().UTC(),
|
||||
logger.GetLogger(),
|
||||
make(map[uuid.UUID]BackupToNodeRelation),
|
||||
backuperNode,
|
||||
backupRepository: backupRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
taskCancelManager: taskCancelManager,
|
||||
backupNodesRegistry: backupNodesRegistry,
|
||||
lastBackupTime: time.Now().UTC(),
|
||||
logger: logger.GetLogger(),
|
||||
backupToNodeRelations: make(map[uuid.UUID]BackupToNodeRelation),
|
||||
backuperNode: backuperNode,
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
|
||||
func GetBackupsScheduler() *BackupsScheduler {
|
||||
@@ -69,3 +77,7 @@ func GetBackupsScheduler() *BackupsScheduler {
|
||||
func GetBackuperNode() *BackuperNode {
|
||||
return backuperNode
|
||||
}
|
||||
|
||||
func GetBackupNodesRegistry() *BackupNodesRegistry {
|
||||
return backupNodesRegistry
|
||||
}
|
||||
|
||||
@@ -1,8 +1,34 @@
|
||||
package backuping
|
||||
|
||||
import "github.com/google/uuid"
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type BackupToNodeRelation struct {
|
||||
NodeID uuid.UUID `json:"nodeId"`
|
||||
BackupsIDs []uuid.UUID `json:"backupsIds"`
|
||||
}
|
||||
|
||||
type BackupNode struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
ThroughputMBs int `json:"throughputMBs"`
|
||||
LastHeartbeat time.Time `json:"lastHeartbeat"`
|
||||
}
|
||||
|
||||
type BackupNodeStats struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
ActiveBackups int `json:"activeBackups"`
|
||||
}
|
||||
|
||||
type BackupSubmitMessage struct {
|
||||
NodeID uuid.UUID `json:"nodeId"`
|
||||
BackupID uuid.UUID `json:"backupId"`
|
||||
IsCallNotifier bool `json:"isCallNotifier"`
|
||||
}
|
||||
|
||||
type BackupCompletionMessage struct {
|
||||
NodeID uuid.UUID `json:"nodeId"`
|
||||
BackupID uuid.UUID `json:"backupId"`
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package task_registry
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
@@ -15,64 +17,73 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
nodeInfoKeyPrefix = "node:"
|
||||
nodeInfoKeySuffix = ":info"
|
||||
nodeActiveTasksPrefix = "node:"
|
||||
nodeActiveTasksSuffix = ":active_tasks"
|
||||
taskSubmitChannel = "task:submit"
|
||||
taskCompletionChannel = "task:completion"
|
||||
nodeInfoKeyPrefix = "backup:node:"
|
||||
nodeInfoKeySuffix = ":info"
|
||||
nodeActiveBackupsPrefix = "backup:node:"
|
||||
nodeActiveBackupsSuffix = ":active_backups"
|
||||
backupSubmitChannel = "backup:submit"
|
||||
backupCompletionChannel = "backup:completion"
|
||||
|
||||
deadNodeThreshold = 2 * time.Minute
|
||||
cleanupTickerInterval = 1 * time.Second
|
||||
)
|
||||
|
||||
// TaskNodesRegistry helps to sync tasks scheduler (backuping or restoring)
|
||||
// and task nodes which are used for network-intensive tasks processing
|
||||
// BackupNodesRegistry helps to sync backups scheduler and backup nodes.
|
||||
//
|
||||
// Features:
|
||||
// - Track node availability and load level
|
||||
// - Assign from scheduler to node tasks needed to be processed
|
||||
// - Notify scheduler from node about task completion
|
||||
// - Assign from scheduler to node backups needed to be processed
|
||||
// - Notify scheduler from node about backup completion
|
||||
//
|
||||
// Important things to remember:
|
||||
// - Node can contain different tasks types so when task is assigned
|
||||
// or node's tasks cleaned - should be performed DB check in DB
|
||||
// that task with this ID exists for this task type at all
|
||||
// - Nodes without heathbeat for more than 2 minutes are not included
|
||||
// - Nodes without heartbeat for more than 2 minutes are not included
|
||||
// in available nodes list and stats
|
||||
//
|
||||
// Cleanup dead nodes performed on 2 levels:
|
||||
// - List and stats functions do not return dead nodes
|
||||
// - Periodically dead nodes are cleaned up in cache (to not
|
||||
// accumulate too many dead nodes in cache)
|
||||
type TaskNodesRegistry struct {
|
||||
type BackupNodesRegistry struct {
|
||||
client valkey.Client
|
||||
logger *slog.Logger
|
||||
timeout time.Duration
|
||||
pubsubTasks *cache_utils.PubSubManager
|
||||
pubsubBackups *cache_utils.PubSubManager
|
||||
pubsubCompletions *cache_utils.PubSubManager
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (r *TaskNodesRegistry) Run(ctx context.Context) {
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
|
||||
}
|
||||
func (r *BackupNodesRegistry) Run(ctx context.Context) {
|
||||
wasAlreadyRun := r.hasRun.Load()
|
||||
|
||||
ticker := time.NewTicker(cleanupTickerInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes", "error", err)
|
||||
r.runOnce.Do(func() {
|
||||
r.hasRun.Store(true)
|
||||
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(cleanupTickerInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", r))
|
||||
}
|
||||
}
|
||||
|
||||
func (r *TaskNodesRegistry) GetAvailableNodes() ([]TaskNode, error) {
|
||||
func (r *BackupNodesRegistry) GetAvailableNodes() ([]BackupNode, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
@@ -104,7 +115,7 @@ func (r *TaskNodesRegistry) GetAvailableNodes() ([]TaskNode, error) {
|
||||
}
|
||||
|
||||
if len(allKeys) == 0 {
|
||||
return []TaskNode{}, nil
|
||||
return []BackupNode{}, nil
|
||||
}
|
||||
|
||||
keyDataMap, err := r.pipelineGetKeys(allKeys)
|
||||
@@ -113,14 +124,15 @@ func (r *TaskNodesRegistry) GetAvailableNodes() ([]TaskNode, error) {
|
||||
}
|
||||
|
||||
threshold := time.Now().UTC().Add(-deadNodeThreshold)
|
||||
var nodes []TaskNode
|
||||
var nodes []BackupNode
|
||||
|
||||
for key, data := range keyDataMap {
|
||||
// Skip if the key doesn't exist (data is empty)
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var node TaskNode
|
||||
var node BackupNode
|
||||
if err := json.Unmarshal(data, &node); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal node data", "key", key, "error", err)
|
||||
continue
|
||||
@@ -141,13 +153,13 @@ func (r *TaskNodesRegistry) GetAvailableNodes() ([]TaskNode, error) {
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
func (r *TaskNodesRegistry) GetNodesStats() ([]TaskNodeStats, error) {
|
||||
func (r *BackupNodesRegistry) GetBackupNodesStats() ([]BackupNodeStats, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
var allKeys []string
|
||||
cursor := uint64(0)
|
||||
pattern := nodeActiveTasksPrefix + "*" + nodeActiveTasksSuffix
|
||||
pattern := nodeActiveBackupsPrefix + "*" + nodeActiveBackupsSuffix
|
||||
|
||||
for {
|
||||
result := r.client.Do(
|
||||
@@ -156,7 +168,7 @@ func (r *TaskNodesRegistry) GetNodesStats() ([]TaskNodeStats, error) {
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return nil, fmt.Errorf("failed to scan active tasks keys: %w", result.Error())
|
||||
return nil, fmt.Errorf("failed to scan active backups keys: %w", result.Error())
|
||||
}
|
||||
|
||||
scanResult, err := result.AsScanEntry()
|
||||
@@ -173,18 +185,18 @@ func (r *TaskNodesRegistry) GetNodesStats() ([]TaskNodeStats, error) {
|
||||
}
|
||||
|
||||
if len(allKeys) == 0 {
|
||||
return []TaskNodeStats{}, nil
|
||||
return []BackupNodeStats{}, nil
|
||||
}
|
||||
|
||||
keyDataMap, err := r.pipelineGetKeys(allKeys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pipeline get active tasks keys: %w", err)
|
||||
return nil, fmt.Errorf("failed to pipeline get active backups keys: %w", err)
|
||||
}
|
||||
|
||||
var nodeInfoKeys []string
|
||||
nodeIDToStatsKey := make(map[string]string)
|
||||
for key := range keyDataMap {
|
||||
nodeID := r.extractNodeIDFromKey(key, nodeActiveTasksPrefix, nodeActiveTasksSuffix)
|
||||
nodeID := r.extractNodeIDFromKey(key, nodeActiveBackupsPrefix, nodeActiveBackupsSuffix)
|
||||
nodeIDStr := nodeID.String()
|
||||
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeIDStr, nodeInfoKeySuffix)
|
||||
nodeInfoKeys = append(nodeInfoKeys, infoKey)
|
||||
@@ -197,14 +209,14 @@ func (r *TaskNodesRegistry) GetNodesStats() ([]TaskNodeStats, error) {
|
||||
}
|
||||
|
||||
threshold := time.Now().UTC().Add(-deadNodeThreshold)
|
||||
var stats []TaskNodeStats
|
||||
var stats []BackupNodeStats
|
||||
for infoKey, nodeData := range nodeInfoMap {
|
||||
// Skip if the info key doesn't exist (nodeData is empty)
|
||||
if len(nodeData) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var node TaskNode
|
||||
var node BackupNode
|
||||
if err := json.Unmarshal(nodeData, &node); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal node data", "key", infoKey, "error", err)
|
||||
continue
|
||||
@@ -223,13 +235,13 @@ func (r *TaskNodesRegistry) GetNodesStats() ([]TaskNodeStats, error) {
|
||||
tasksData := keyDataMap[statsKey]
|
||||
count, err := r.parseIntFromBytes(tasksData)
|
||||
if err != nil {
|
||||
r.logger.Warn("Failed to parse active tasks count", "key", statsKey, "error", err)
|
||||
r.logger.Warn("Failed to parse active backups count", "key", statsKey, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
stat := TaskNodeStats{
|
||||
ID: node.ID,
|
||||
ActiveTasks: int(count),
|
||||
stat := BackupNodeStats{
|
||||
ID: node.ID,
|
||||
ActiveBackups: int(count),
|
||||
}
|
||||
stats = append(stats, stat)
|
||||
}
|
||||
@@ -237,16 +249,16 @@ func (r *TaskNodesRegistry) GetNodesStats() ([]TaskNodeStats, error) {
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (r *TaskNodesRegistry) IncrementTasksInProgress(nodeID string) error {
|
||||
func (r *BackupNodesRegistry) IncrementBackupsInProgress(nodeID uuid.UUID) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeActiveTasksPrefix, nodeID, nodeActiveTasksSuffix)
|
||||
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID.String(), nodeActiveBackupsSuffix)
|
||||
result := r.client.Do(ctx, r.client.B().Incr().Key(key).Build())
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to increment tasks in progress for node %s: %w",
|
||||
"failed to increment backups in progress for node %s: %w",
|
||||
nodeID,
|
||||
result.Error(),
|
||||
)
|
||||
@@ -255,16 +267,16 @@ func (r *TaskNodesRegistry) IncrementTasksInProgress(nodeID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *TaskNodesRegistry) DecrementTasksInProgress(nodeID string) error {
|
||||
func (r *BackupNodesRegistry) DecrementBackupsInProgress(nodeID uuid.UUID) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeActiveTasksPrefix, nodeID, nodeActiveTasksSuffix)
|
||||
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID.String(), nodeActiveBackupsSuffix)
|
||||
result := r.client.Do(ctx, r.client.B().Decr().Key(key).Build())
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to decrement tasks in progress for node %s: %w",
|
||||
"failed to decrement backups in progress for node %s: %w",
|
||||
nodeID,
|
||||
result.Error(),
|
||||
)
|
||||
@@ -279,13 +291,13 @@ func (r *TaskNodesRegistry) DecrementTasksInProgress(nodeID string) error {
|
||||
setCtx, setCancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
r.client.Do(setCtx, r.client.B().Set().Key(key).Value("0").Build())
|
||||
setCancel()
|
||||
r.logger.Warn("Active tasks counter went below 0, reset to 0", "nodeID", nodeID)
|
||||
r.logger.Warn("Active backups counter went below 0, reset to 0", "nodeID", nodeID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *TaskNodesRegistry) HearthbeatNodeInRegistry(now time.Time, node TaskNode) error {
|
||||
func (r *BackupNodesRegistry) HearthbeatNodeInRegistry(now time.Time, backupNode BackupNode) error {
|
||||
if now.IsZero() {
|
||||
return fmt.Errorf("cannot register node with zero heartbeat timestamp")
|
||||
}
|
||||
@@ -293,36 +305,36 @@ func (r *TaskNodesRegistry) HearthbeatNodeInRegistry(now time.Time, node TaskNod
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
node.LastHeartbeat = now
|
||||
backupNode.LastHeartbeat = now
|
||||
|
||||
data, err := json.Marshal(node)
|
||||
data, err := json.Marshal(backupNode)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal node: %w", err)
|
||||
return fmt.Errorf("failed to marshal backup node: %w", err)
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node.ID.String(), nodeInfoKeySuffix)
|
||||
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, backupNode.ID.String(), nodeInfoKeySuffix)
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Set().Key(key).Value(string(data)).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf("failed to register node %s: %w", node.ID, result.Error())
|
||||
return fmt.Errorf("failed to register node %s: %w", backupNode.ID, result.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *TaskNodesRegistry) UnregisterNodeFromRegistry(node TaskNode) error {
|
||||
func (r *BackupNodesRegistry) UnregisterNodeFromRegistry(backupNode BackupNode) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node.ID.String(), nodeInfoKeySuffix)
|
||||
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, backupNode.ID.String(), nodeInfoKeySuffix)
|
||||
counterKey := fmt.Sprintf(
|
||||
"%s%s%s",
|
||||
nodeActiveTasksPrefix,
|
||||
node.ID.String(),
|
||||
nodeActiveTasksSuffix,
|
||||
nodeActiveBackupsPrefix,
|
||||
backupNode.ID.String(),
|
||||
nodeActiveBackupsSuffix,
|
||||
)
|
||||
|
||||
result := r.client.Do(
|
||||
@@ -331,49 +343,49 @@ func (r *TaskNodesRegistry) UnregisterNodeFromRegistry(node TaskNode) error {
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf("failed to unregister node %s: %w", node.ID, result.Error())
|
||||
return fmt.Errorf("failed to unregister node %s: %w", backupNode.ID, result.Error())
|
||||
}
|
||||
|
||||
r.logger.Info("Unregistered node from registry", "nodeID", node.ID)
|
||||
r.logger.Info("Unregistered node from registry", "nodeID", backupNode.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *TaskNodesRegistry) AssignTaskToNode(
|
||||
targetNodeID string,
|
||||
taskID uuid.UUID,
|
||||
func (r *BackupNodesRegistry) AssignBackupToNode(
|
||||
targetNodeID uuid.UUID,
|
||||
backupID uuid.UUID,
|
||||
isCallNotifier bool,
|
||||
) error {
|
||||
ctx := context.Background()
|
||||
|
||||
message := TaskSubmitMessage{
|
||||
message := BackupSubmitMessage{
|
||||
NodeID: targetNodeID,
|
||||
TaskID: taskID.String(),
|
||||
BackupID: backupID,
|
||||
IsCallNotifier: isCallNotifier,
|
||||
}
|
||||
|
||||
messageJSON, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal task submit message: %w", err)
|
||||
return fmt.Errorf("failed to marshal backup submit message: %w", err)
|
||||
}
|
||||
|
||||
err = r.pubsubTasks.Publish(ctx, taskSubmitChannel, string(messageJSON))
|
||||
err = r.pubsubBackups.Publish(ctx, backupSubmitChannel, string(messageJSON))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to publish task submit message: %w", err)
|
||||
return fmt.Errorf("failed to publish backup submit message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *TaskNodesRegistry) SubscribeNodeForTasksAssignment(
|
||||
nodeID string,
|
||||
handler func(taskID uuid.UUID, isCallNotifier bool),
|
||||
func (r *BackupNodesRegistry) SubscribeNodeForBackupsAssignment(
|
||||
nodeID uuid.UUID,
|
||||
handler func(backupID uuid.UUID, isCallNotifier bool),
|
||||
) error {
|
||||
ctx := context.Background()
|
||||
|
||||
wrappedHandler := func(message string) {
|
||||
var msg TaskSubmitMessage
|
||||
var msg BackupSubmitMessage
|
||||
if err := json.Unmarshal([]byte(message), &msg); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal task submit message", "error", err)
|
||||
r.logger.Warn("Failed to unmarshal backup submit message", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -381,108 +393,84 @@ func (r *TaskNodesRegistry) SubscribeNodeForTasksAssignment(
|
||||
return
|
||||
}
|
||||
|
||||
taskID, err := uuid.Parse(msg.TaskID)
|
||||
if err != nil {
|
||||
r.logger.Warn(
|
||||
"Failed to parse task ID from message",
|
||||
"taskId",
|
||||
msg.TaskID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
handler(taskID, msg.IsCallNotifier)
|
||||
handler(msg.BackupID, msg.IsCallNotifier)
|
||||
}
|
||||
|
||||
err := r.pubsubTasks.Subscribe(ctx, taskSubmitChannel, wrappedHandler)
|
||||
err := r.pubsubBackups.Subscribe(ctx, backupSubmitChannel, wrappedHandler)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to subscribe to task submit channel: %w", err)
|
||||
return fmt.Errorf("failed to subscribe to backup submit channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Subscribed to task submit channel", "nodeID", nodeID)
|
||||
r.logger.Info("Subscribed to backup submit channel", "nodeID", nodeID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *TaskNodesRegistry) UnsubscribeNodeForTasksAssignments() error {
|
||||
err := r.pubsubTasks.Close()
|
||||
func (r *BackupNodesRegistry) UnsubscribeNodeForBackupsAssignments() error {
|
||||
err := r.pubsubBackups.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unsubscribe from task submit channel: %w", err)
|
||||
return fmt.Errorf("failed to unsubscribe from backup submit channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Unsubscribed from task submit channel")
|
||||
r.logger.Info("Unsubscribed from backup submit channel")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *TaskNodesRegistry) PublishTaskCompletion(nodeID string, taskID uuid.UUID) error {
|
||||
func (r *BackupNodesRegistry) PublishBackupCompletion(nodeID uuid.UUID, backupID uuid.UUID) error {
|
||||
ctx := context.Background()
|
||||
|
||||
message := TaskCompletionMessage{
|
||||
NodeID: nodeID,
|
||||
TaskID: taskID.String(),
|
||||
message := BackupCompletionMessage{
|
||||
NodeID: nodeID,
|
||||
BackupID: backupID,
|
||||
}
|
||||
|
||||
messageJSON, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal task completion message: %w", err)
|
||||
return fmt.Errorf("failed to marshal backup completion message: %w", err)
|
||||
}
|
||||
|
||||
err = r.pubsubCompletions.Publish(ctx, taskCompletionChannel, string(messageJSON))
|
||||
err = r.pubsubCompletions.Publish(ctx, backupCompletionChannel, string(messageJSON))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to publish task completion message: %w", err)
|
||||
return fmt.Errorf("failed to publish backup completion message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *TaskNodesRegistry) SubscribeForTasksCompletions(
|
||||
handler func(nodeID string, taskID uuid.UUID),
|
||||
func (r *BackupNodesRegistry) SubscribeForBackupsCompletions(
|
||||
handler func(nodeID uuid.UUID, backupID uuid.UUID),
|
||||
) error {
|
||||
ctx := context.Background()
|
||||
|
||||
wrappedHandler := func(message string) {
|
||||
var msg TaskCompletionMessage
|
||||
var msg BackupCompletionMessage
|
||||
if err := json.Unmarshal([]byte(message), &msg); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal task completion message", "error", err)
|
||||
r.logger.Warn("Failed to unmarshal backup completion message", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
taskID, err := uuid.Parse(msg.TaskID)
|
||||
if err != nil {
|
||||
r.logger.Warn(
|
||||
"Failed to parse task ID from completion message",
|
||||
"taskId",
|
||||
msg.TaskID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
handler(msg.NodeID, taskID)
|
||||
handler(msg.NodeID, msg.BackupID)
|
||||
}
|
||||
|
||||
err := r.pubsubCompletions.Subscribe(ctx, taskCompletionChannel, wrappedHandler)
|
||||
err := r.pubsubCompletions.Subscribe(ctx, backupCompletionChannel, wrappedHandler)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to subscribe to task completion channel: %w", err)
|
||||
return fmt.Errorf("failed to subscribe to backup completion channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Subscribed to task completion channel")
|
||||
r.logger.Info("Subscribed to backup completion channel")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *TaskNodesRegistry) UnsubscribeForTasksCompletions() error {
|
||||
func (r *BackupNodesRegistry) UnsubscribeForBackupsCompletions() error {
|
||||
err := r.pubsubCompletions.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unsubscribe from task completion channel: %w", err)
|
||||
return fmt.Errorf("failed to unsubscribe from backup completion channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Unsubscribed from task completion channel")
|
||||
r.logger.Info("Unsubscribed from backup completion channel")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *TaskNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uuid.UUID {
|
||||
func (r *BackupNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uuid.UUID {
|
||||
nodeIDStr := strings.TrimPrefix(key, prefix)
|
||||
nodeIDStr = strings.TrimSuffix(nodeIDStr, suffix)
|
||||
|
||||
@@ -495,7 +483,7 @@ func (r *TaskNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uui
|
||||
return nodeID
|
||||
}
|
||||
|
||||
func (r *TaskNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, error) {
|
||||
func (r *BackupNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, error) {
|
||||
if len(keys) == 0 {
|
||||
return make(map[string][]byte), nil
|
||||
}
|
||||
@@ -529,7 +517,7 @@ func (r *TaskNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, e
|
||||
return keyDataMap, nil
|
||||
}
|
||||
|
||||
func (r *TaskNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
|
||||
func (r *BackupNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
|
||||
str := string(data)
|
||||
var count int64
|
||||
_, err := fmt.Sscanf(str, "%d", &count)
|
||||
@@ -539,7 +527,7 @@ func (r *TaskNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *TaskNodesRegistry) cleanupDeadNodes() error {
|
||||
func (r *BackupNodesRegistry) cleanupDeadNodes() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
@@ -583,13 +571,12 @@ func (r *TaskNodesRegistry) cleanupDeadNodes() error {
|
||||
var deadNodeKeys []string
|
||||
|
||||
for key, data := range keyDataMap {
|
||||
|
||||
// Skip if the key doesn't exist (data is empty)
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var node TaskNode
|
||||
var node BackupNode
|
||||
if err := json.Unmarshal(data, &node); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal node data during cleanup", "key", key, "error", err)
|
||||
continue
|
||||
@@ -603,7 +590,12 @@ func (r *TaskNodesRegistry) cleanupDeadNodes() error {
|
||||
if node.LastHeartbeat.Before(threshold) {
|
||||
nodeID := node.ID.String()
|
||||
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeID, nodeInfoKeySuffix)
|
||||
statsKey := fmt.Sprintf("%s%s%s", nodeActiveTasksPrefix, nodeID, nodeActiveTasksSuffix)
|
||||
statsKey := fmt.Sprintf(
|
||||
"%s%s%s",
|
||||
nodeActiveBackupsPrefix,
|
||||
nodeID,
|
||||
nodeActiveBackupsSuffix,
|
||||
)
|
||||
|
||||
deadNodeKeys = append(deadNodeKeys, infoKey, statsKey)
|
||||
r.logger.Info(
|
||||
1134
backend/internal/features/backups/backups/backuping/registry_test.go
Normal file
1134
backend/internal/features/backups/backups/backuping/registry_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -2,19 +2,21 @@ package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/storages"
|
||||
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
task_registry "databasus-backend/internal/features/tasks/registry"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/period"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -28,64 +30,79 @@ type BackupsScheduler struct {
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
storageService *storages.StorageService
|
||||
taskCancelManager *task_cancellation.TaskCancelManager
|
||||
tasksRegistry *task_registry.TaskNodesRegistry
|
||||
backupNodesRegistry *BackupNodesRegistry
|
||||
|
||||
lastBackupTime time.Time
|
||||
logger *slog.Logger
|
||||
|
||||
backupToNodeRelations map[uuid.UUID]BackupToNodeRelation
|
||||
backuperNode *BackuperNode
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) Run(ctx context.Context) {
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
wasAlreadyRun := s.hasRun.Load()
|
||||
|
||||
if config.GetEnv().IsManyNodesMode {
|
||||
// wait other nodes to start
|
||||
time.Sleep(schedulerStartupDelay)
|
||||
}
|
||||
s.runOnce.Do(func() {
|
||||
s.hasRun.Store(true)
|
||||
|
||||
if err := s.failBackupsInProgress(); err != nil {
|
||||
s.logger.Error("Failed to fail backups in progress", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
|
||||
if err := s.tasksRegistry.SubscribeForTasksCompletions(s.onBackupCompleted); err != nil {
|
||||
s.logger.Error("Failed to subscribe to backup completions", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := s.tasksRegistry.UnsubscribeForTasksCompletions(); err != nil {
|
||||
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
|
||||
if config.GetEnv().IsManyNodesMode {
|
||||
// wait other nodes to start
|
||||
time.Sleep(schedulerStartupDelay)
|
||||
}
|
||||
}()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if err := s.failBackupsInProgress(); err != nil {
|
||||
s.logger.Error("Failed to fail backups in progress", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(schedulerTickerInterval)
|
||||
defer ticker.Stop()
|
||||
err := s.backupNodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to subscribe to backup completions", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
defer func() {
|
||||
if err := s.backupNodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
|
||||
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.cleanOldBackups(); err != nil {
|
||||
s.logger.Error("Failed to clean old backups", "error", err)
|
||||
}
|
||||
|
||||
if err := s.checkDeadNodesAndFailBackups(); err != nil {
|
||||
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
|
||||
}
|
||||
|
||||
if err := s.runPendingBackups(); err != nil {
|
||||
s.logger.Error("Failed to run pending backups", "error", err)
|
||||
}
|
||||
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(schedulerTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.cleanOldBackups(); err != nil {
|
||||
s.logger.Error("Failed to clean old backups", "error", err)
|
||||
}
|
||||
|
||||
if err := s.checkDeadNodesAndFailBackups(); err != nil {
|
||||
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
|
||||
}
|
||||
|
||||
if err := s.runPendingBackups(); err != nil {
|
||||
s.logger.Error("Failed to run pending backups", "error", err)
|
||||
}
|
||||
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -180,7 +197,7 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.tasksRegistry.IncrementTasksInProgress(leastBusyNodeID.String()); err != nil {
|
||||
if err := s.backupNodesRegistry.IncrementBackupsInProgress(*leastBusyNodeID); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to increment backups in progress",
|
||||
"nodeId",
|
||||
@@ -193,7 +210,7 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.tasksRegistry.AssignTaskToNode(leastBusyNodeID.String(), backup.ID, isCallNotifier); err != nil {
|
||||
if err := s.backupNodesRegistry.AssignBackupToNode(*leastBusyNodeID, backup.ID, isCallNotifier); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to submit backup",
|
||||
"nodeId",
|
||||
@@ -203,7 +220,7 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
if decrementErr := s.tasksRegistry.DecrementTasksInProgress(leastBusyNodeID.String()); decrementErr != nil {
|
||||
if decrementErr := s.backupNodesRegistry.DecrementBackupsInProgress(*leastBusyNodeID); decrementErr != nil {
|
||||
s.logger.Error(
|
||||
"Failed to decrement backups in progress after submit failure",
|
||||
"nodeId",
|
||||
@@ -398,7 +415,7 @@ func (s *BackupsScheduler) runPendingBackups() error {
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
|
||||
nodes, err := s.tasksRegistry.GetAvailableNodes()
|
||||
nodes, err := s.backupNodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get available nodes: %w", err)
|
||||
}
|
||||
@@ -407,17 +424,17 @@ func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
|
||||
return nil, fmt.Errorf("no nodes available")
|
||||
}
|
||||
|
||||
stats, err := s.tasksRegistry.GetNodesStats()
|
||||
stats, err := s.backupNodesRegistry.GetBackupNodesStats()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get backup nodes stats: %w", err)
|
||||
}
|
||||
|
||||
statsMap := make(map[uuid.UUID]int)
|
||||
for _, stat := range stats {
|
||||
statsMap[stat.ID] = stat.ActiveTasks
|
||||
statsMap[stat.ID] = stat.ActiveBackups
|
||||
}
|
||||
|
||||
var bestNode *task_registry.TaskNode
|
||||
var bestNode *BackupNode
|
||||
var bestScore float64 = -1
|
||||
|
||||
for i := range nodes {
|
||||
@@ -445,21 +462,9 @@ func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
|
||||
return &bestNode.ID, nil
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUID) {
|
||||
nodeID, err := uuid.Parse(nodeIDStr)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to parse node ID from completion message",
|
||||
"nodeId",
|
||||
nodeIDStr,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) onBackupCompleted(nodeID uuid.UUID, backupID uuid.UUID) {
|
||||
// Verify this task is actually a backup (registry contains multiple task types)
|
||||
_, err = s.backupRepository.FindByID(backupID)
|
||||
_, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
// Not a backup task, ignore it
|
||||
return
|
||||
@@ -505,7 +510,7 @@ func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUI
|
||||
s.backupToNodeRelations[nodeID] = relation
|
||||
}
|
||||
|
||||
if err := s.tasksRegistry.DecrementTasksInProgress(nodeIDStr); err != nil {
|
||||
if err := s.backupNodesRegistry.DecrementBackupsInProgress(nodeID); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to decrement backups in progress",
|
||||
"nodeId",
|
||||
@@ -519,7 +524,7 @@ func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUI
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) checkDeadNodesAndFailBackups() error {
|
||||
nodes, err := s.tasksRegistry.GetAvailableNodes()
|
||||
nodes, err := s.backupNodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get available nodes: %w", err)
|
||||
}
|
||||
@@ -575,7 +580,7 @@ func (s *BackupsScheduler) checkDeadNodesAndFailBackups() error {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := s.tasksRegistry.DecrementTasksInProgress(nodeID.String()); err != nil {
|
||||
if err := s.backupNodesRegistry.DecrementBackupsInProgress(nodeID); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to decrement backups in progress for dead node",
|
||||
"nodeId",
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"databasus-backend/internal/features/intervals"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
task_registry "databasus-backend/internal/features/tasks/registry"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
@@ -466,7 +465,7 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
|
||||
|
||||
// Clean up mock node
|
||||
if mockNodeID != uuid.Nil {
|
||||
nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: mockNodeID})
|
||||
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: mockNodeID})
|
||||
}
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
@@ -502,12 +501,12 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
|
||||
assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status)
|
||||
|
||||
// Verify Valkey counter was incremented when backup was assigned
|
||||
stats, err := nodesRegistry.GetNodesStats()
|
||||
stats, err := backupNodesRegistry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
foundStat := false
|
||||
for _, stat := range stats {
|
||||
if stat.ID == mockNodeID {
|
||||
assert.Equal(t, 1, stat.ActiveTasks)
|
||||
assert.Equal(t, 1, stat.ActiveBackups)
|
||||
foundStat = true
|
||||
break
|
||||
}
|
||||
@@ -532,11 +531,11 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
|
||||
assert.Contains(t, *backups[0].FailMessage, "node unavailability")
|
||||
|
||||
// Verify Valkey counter was decremented after backup failed
|
||||
stats, err = nodesRegistry.GetNodesStats()
|
||||
stats, err = backupNodesRegistry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
for _, stat := range stats {
|
||||
if stat.ID == mockNodeID {
|
||||
assert.Equal(t, 0, stat.ActiveTasks)
|
||||
assert.Equal(t, 0, stat.ActiveBackups)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -569,7 +568,7 @@ func Test_OnBackupCompleted_WhenTaskIsNotBackup_SkipsProcessing(t *testing.T) {
|
||||
|
||||
// Clean up mock node
|
||||
if mockNodeID != uuid.Nil {
|
||||
nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: mockNodeID})
|
||||
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: mockNodeID})
|
||||
}
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
@@ -605,12 +604,12 @@ func Test_OnBackupCompleted_WhenTaskIsNotBackup_SkipsProcessing(t *testing.T) {
|
||||
assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status)
|
||||
|
||||
// Get initial state of the registry
|
||||
initialStats, err := nodesRegistry.GetNodesStats()
|
||||
initialStats, err := backupNodesRegistry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
var initialActiveTasks int
|
||||
for _, stat := range initialStats {
|
||||
if stat.ID == mockNodeID {
|
||||
initialActiveTasks = stat.ActiveTasks
|
||||
initialActiveTasks = stat.ActiveBackups
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -618,16 +617,16 @@ func Test_OnBackupCompleted_WhenTaskIsNotBackup_SkipsProcessing(t *testing.T) {
|
||||
|
||||
// Call onBackupCompleted with a random UUID (not a backup ID)
|
||||
nonBackupTaskID := uuid.New()
|
||||
GetBackupsScheduler().onBackupCompleted(mockNodeID.String(), nonBackupTaskID)
|
||||
GetBackupsScheduler().onBackupCompleted(mockNodeID, nonBackupTaskID)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify: Active tasks counter should remain the same (not decremented)
|
||||
stats, err := nodesRegistry.GetNodesStats()
|
||||
stats, err := backupNodesRegistry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
for _, stat := range stats {
|
||||
if stat.ID == mockNodeID {
|
||||
assert.Equal(t, initialActiveTasks, stat.ActiveTasks,
|
||||
assert.Equal(t, initialActiveTasks, stat.ActiveBackups,
|
||||
"Active tasks should not change for non-backup task")
|
||||
}
|
||||
}
|
||||
@@ -658,9 +657,9 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
|
||||
|
||||
defer func() {
|
||||
// Clean up all mock nodes
|
||||
nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node1ID})
|
||||
nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node2ID})
|
||||
nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node3ID})
|
||||
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node1ID})
|
||||
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node2ID})
|
||||
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node3ID})
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
@@ -672,17 +671,17 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
for range 5 {
|
||||
err = nodesRegistry.IncrementTasksInProgress(node1ID.String())
|
||||
err = backupNodesRegistry.IncrementBackupsInProgress(node1ID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
for range 2 {
|
||||
err = nodesRegistry.IncrementTasksInProgress(node2ID.String())
|
||||
err = backupNodesRegistry.IncrementBackupsInProgress(node2ID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
for range 8 {
|
||||
err = nodesRegistry.IncrementTasksInProgress(node3ID.String())
|
||||
err = backupNodesRegistry.IncrementBackupsInProgress(node3ID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -701,8 +700,8 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
|
||||
|
||||
defer func() {
|
||||
// Clean up all mock nodes
|
||||
nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node100MBsID})
|
||||
nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node50MBsID})
|
||||
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node100MBsID})
|
||||
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node50MBsID})
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
@@ -712,11 +711,11 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
for range 10 {
|
||||
err = nodesRegistry.IncrementTasksInProgress(node100MBsID.String())
|
||||
err = backupNodesRegistry.IncrementBackupsInProgress(node100MBsID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
err = nodesRegistry.IncrementTasksInProgress(node50MBsID.String())
|
||||
err = backupNodesRegistry.IncrementBackupsInProgress(node50MBsID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
leastBusyNodeID, err := GetBackupsScheduler().calculateLeastBusyNode()
|
||||
@@ -836,7 +835,8 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
// Start scheduler so it can handle task completions
|
||||
schedulerCancel := StartSchedulerForTest(t)
|
||||
scheduler := CreateTestScheduler()
|
||||
schedulerCancel := StartSchedulerForTest(t, scheduler)
|
||||
defer schedulerCancel()
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
@@ -880,19 +880,19 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Get initial active task count
|
||||
stats, err := nodesRegistry.GetNodesStats()
|
||||
stats, err := backupNodesRegistry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
var initialActiveTasks int
|
||||
for _, stat := range stats {
|
||||
if stat.ID == backuperNode.nodeID {
|
||||
initialActiveTasks = stat.ActiveTasks
|
||||
initialActiveTasks = stat.ActiveBackups
|
||||
break
|
||||
}
|
||||
}
|
||||
t.Logf("Initial active tasks: %d", initialActiveTasks)
|
||||
|
||||
// Start backup
|
||||
GetBackupsScheduler().StartBackup(database.ID, false)
|
||||
scheduler.StartBackup(database.ID, false)
|
||||
|
||||
// Wait for backup to complete
|
||||
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
|
||||
@@ -913,12 +913,12 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T
|
||||
assert.True(t, decreased, "Active task count should have decreased after backup completion")
|
||||
|
||||
// Verify final active task count equals initial count
|
||||
finalStats, err := nodesRegistry.GetNodesStats()
|
||||
finalStats, err := backupNodesRegistry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
for _, stat := range finalStats {
|
||||
if stat.ID == backuperNode.nodeID {
|
||||
t.Logf("Final active tasks: %d", stat.ActiveTasks)
|
||||
assert.Equal(t, initialActiveTasks, stat.ActiveTasks,
|
||||
t.Logf("Final active tasks: %d", stat.ActiveBackups)
|
||||
assert.Equal(t, initialActiveTasks, stat.ActiveBackups,
|
||||
"Active task count should return to initial value after backup completion")
|
||||
break
|
||||
}
|
||||
@@ -931,7 +931,8 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
// Start scheduler so it can handle task completions
|
||||
schedulerCancel := StartSchedulerForTest(t)
|
||||
scheduler := CreateTestScheduler()
|
||||
schedulerCancel := StartSchedulerForTest(t, scheduler)
|
||||
defer schedulerCancel()
|
||||
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
@@ -982,19 +983,19 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Get initial active task count
|
||||
stats, err := nodesRegistry.GetNodesStats()
|
||||
stats, err := backupNodesRegistry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
var initialActiveTasks int
|
||||
for _, stat := range stats {
|
||||
if stat.ID == backuperNode.nodeID {
|
||||
initialActiveTasks = stat.ActiveTasks
|
||||
initialActiveTasks = stat.ActiveBackups
|
||||
break
|
||||
}
|
||||
}
|
||||
t.Logf("Initial active tasks: %d", initialActiveTasks)
|
||||
|
||||
// Start backup
|
||||
GetBackupsScheduler().StartBackup(database.ID, false)
|
||||
scheduler.StartBackup(database.ID, false)
|
||||
|
||||
// Wait for backup to fail
|
||||
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
|
||||
@@ -1019,12 +1020,12 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
|
||||
assert.True(t, decreased, "Active task count should have decreased after backup failure")
|
||||
|
||||
// Verify final active task count equals initial count
|
||||
finalStats, err := nodesRegistry.GetNodesStats()
|
||||
finalStats, err := backupNodesRegistry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
for _, stat := range finalStats {
|
||||
if stat.ID == backuperNode.nodeID {
|
||||
t.Logf("Final active tasks: %d", stat.ActiveTasks)
|
||||
assert.Equal(t, initialActiveTasks, stat.ActiveTasks,
|
||||
t.Logf("Final active tasks: %d", stat.ActiveBackups)
|
||||
assert.Equal(t, initialActiveTasks, stat.ActiveBackups,
|
||||
"Active task count should return to initial value after backup failure")
|
||||
break
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package backuping
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -12,7 +14,6 @@ import (
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
task_registry "databasus-backend/internal/features/tasks/registry"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
@@ -36,19 +37,37 @@ func CreateTestRouter() *gin.Engine {
|
||||
|
||||
func CreateTestBackuperNode() *BackuperNode {
|
||||
return &BackuperNode{
|
||||
databases.GetDatabaseService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
databaseService: databases.GetDatabaseService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
workspaceService: workspaces_services.GetWorkspaceService(),
|
||||
backupRepository: backupRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
notificationSender: notifiers.GetNotifierService(),
|
||||
backupCancelManager: taskCancelManager,
|
||||
backupNodesRegistry: backupNodesRegistry,
|
||||
logger: logger.GetLogger(),
|
||||
createBackupUseCase: usecases.GetCreateBackupUsecase(),
|
||||
nodeID: uuid.New(),
|
||||
lastHeartbeat: time.Time{},
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func CreateTestScheduler() *BackupsScheduler {
|
||||
return &BackupsScheduler{
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
notifiers.GetNotifierService(),
|
||||
taskCancelManager,
|
||||
nodesRegistry,
|
||||
backupNodesRegistry,
|
||||
time.Now().UTC(),
|
||||
logger.GetLogger(),
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
uuid.New(),
|
||||
time.Time{},
|
||||
make(map[uuid.UUID]BackupToNodeRelation),
|
||||
CreateTestBackuperNode(),
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,7 +133,7 @@ func StartBackuperNodeForTest(t *testing.T, backuperNode *BackuperNode) context.
|
||||
// Poll registry for node presence instead of fixed sleep
|
||||
deadline := time.Now().UTC().Add(5 * time.Second)
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
nodes, err := nodesRegistry.GetAvailableNodes()
|
||||
nodes, err := backupNodesRegistry.GetAvailableNodes()
|
||||
if err == nil {
|
||||
for _, node := range nodes {
|
||||
if node.ID == backuperNode.nodeID {
|
||||
@@ -142,13 +161,13 @@ func StartBackuperNodeForTest(t *testing.T, backuperNode *BackuperNode) context.
|
||||
// StartSchedulerForTest starts the BackupsScheduler in a goroutine for testing.
|
||||
// The scheduler subscribes to task completions and manages backup lifecycle.
|
||||
// Returns a context cancel function that should be deferred to stop the scheduler.
|
||||
func StartSchedulerForTest(t *testing.T) context.CancelFunc {
|
||||
func StartSchedulerForTest(t *testing.T, scheduler *BackupsScheduler) context.CancelFunc {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
GetBackupsScheduler().Run(ctx)
|
||||
scheduler.Run(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -175,7 +194,7 @@ func StopBackuperNodeForTest(t *testing.T, cancel context.CancelFunc, backuperNo
|
||||
// Wait for node to unregister from registry
|
||||
deadline := time.Now().UTC().Add(2 * time.Second)
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
nodes, err := nodesRegistry.GetAvailableNodes()
|
||||
nodes, err := backupNodesRegistry.GetAvailableNodes()
|
||||
if err == nil {
|
||||
found := false
|
||||
for _, node := range nodes {
|
||||
@@ -196,13 +215,13 @@ func StopBackuperNodeForTest(t *testing.T, cancel context.CancelFunc, backuperNo
|
||||
}
|
||||
|
||||
func CreateMockNodeInRegistry(nodeID uuid.UUID, throughputMBs int, lastHeartbeat time.Time) error {
|
||||
backupNode := task_registry.TaskNode{
|
||||
backupNode := BackupNode{
|
||||
ID: nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: lastHeartbeat,
|
||||
}
|
||||
|
||||
return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
|
||||
return backupNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
|
||||
}
|
||||
|
||||
func UpdateNodeHeartbeatDirectly(
|
||||
@@ -210,17 +229,17 @@ func UpdateNodeHeartbeatDirectly(
|
||||
throughputMBs int,
|
||||
lastHeartbeat time.Time,
|
||||
) error {
|
||||
backupNode := task_registry.TaskNode{
|
||||
backupNode := BackupNode{
|
||||
ID: nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: lastHeartbeat,
|
||||
}
|
||||
|
||||
return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
|
||||
return backupNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
|
||||
}
|
||||
|
||||
func GetNodeFromRegistry(nodeID uuid.UUID) (*task_registry.TaskNode, error) {
|
||||
nodes, err := nodesRegistry.GetAvailableNodes()
|
||||
func GetNodeFromRegistry(nodeID uuid.UUID) (*BackupNode, error) {
|
||||
nodes, err := backupNodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -246,7 +265,7 @@ func WaitForActiveTasksDecrease(
|
||||
deadline := time.Now().UTC().Add(timeout)
|
||||
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
stats, err := nodesRegistry.GetNodesStats()
|
||||
stats, err := backupNodesRegistry.GetBackupNodesStats()
|
||||
if err != nil {
|
||||
t.Logf("WaitForActiveTasksDecrease: error getting node stats: %v", err)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
@@ -257,14 +276,14 @@ func WaitForActiveTasksDecrease(
|
||||
if stat.ID == nodeID {
|
||||
t.Logf(
|
||||
"WaitForActiveTasksDecrease: current active tasks = %d (initial = %d)",
|
||||
stat.ActiveTasks,
|
||||
stat.ActiveBackups,
|
||||
initialCount,
|
||||
)
|
||||
if stat.ActiveTasks < initialCount {
|
||||
if stat.ActiveBackups < initialCount {
|
||||
t.Logf(
|
||||
"WaitForActiveTasksDecrease: active tasks decreased from %d to %d",
|
||||
initialCount,
|
||||
stat.ActiveTasks,
|
||||
stat.ActiveBackups,
|
||||
)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -1156,7 +1156,7 @@ func createTestDatabase(
|
||||
Type: databases.DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
@@ -52,11 +55,26 @@ func GetBackupController() *BackupController {
|
||||
return backupController
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
backups_config.
|
||||
GetBackupConfigService().
|
||||
SetDatabaseStorageChangeListener(backupService)
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
databases.GetDatabaseService().AddDbRemoveListener(backupService)
|
||||
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
|
||||
func SetupDependencies() {
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
backups_config.
|
||||
GetBackupConfigService().
|
||||
SetDatabaseStorageChangeListener(backupService)
|
||||
|
||||
databases.GetDatabaseService().AddDbRemoveListener(backupService)
|
||||
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,33 +2,49 @@ package backups_download
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type DownloadTokenBackgroundService struct {
|
||||
downloadTokenService *DownloadTokenService
|
||||
logger *slog.Logger
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (s *DownloadTokenBackgroundService) Run(ctx context.Context) {
|
||||
s.logger.Info("Starting download token cleanup background service")
|
||||
wasAlreadyRun := s.hasRun.Load()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
s.runOnce.Do(func() {
|
||||
s.hasRun.Store(true)
|
||||
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
s.logger.Info("Starting download token cleanup background service")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
|
||||
s.logger.Error("Failed to clean expired download tokens", "error", err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
|
||||
s.logger.Error("Failed to clean expired download tokens", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/logger"
|
||||
@@ -30,8 +33,10 @@ func init() {
|
||||
}
|
||||
|
||||
downloadTokenBackgroundService = &DownloadTokenBackgroundService{
|
||||
downloadTokenService,
|
||||
logger.GetLogger(),
|
||||
downloadTokenService: downloadTokenService,
|
||||
logger: logger.GetLogger(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -75,3 +75,23 @@ func WaitForBackupCompletion(
|
||||
|
||||
t.Logf("WaitForBackupCompletion: timeout waiting for backup to complete")
|
||||
}
|
||||
|
||||
// CreateTestBackup creates a simple test backup record for testing purposes
|
||||
func CreateTestBackup(databaseID, storageID uuid.UUID) *backups_core.Backup {
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: databaseID,
|
||||
StorageID: storageID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10.5,
|
||||
BackupDurationMs: 1000,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
repo := &backups_core.BackupRepository{}
|
||||
if err := repo.Save(backup); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return backup
|
||||
}
|
||||
|
||||
@@ -1462,7 +1462,7 @@ func createTestDatabaseViaAPI(
|
||||
Type: databases.DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
package backups_config
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var backupConfigRepository = &BackupConfigRepository{}
|
||||
@@ -28,6 +32,21 @@ func GetBackupConfigService() *BackupConfigService {
|
||||
return backupConfigService
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService)
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ func getTestPostgresConfig() *postgresql.PostgresqlDatabase {
|
||||
testDbName := "testdb"
|
||||
return &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
@@ -67,7 +67,7 @@ func getTestMariadbConfig() *mariadb.MariadbDatabase {
|
||||
testDbName := "testdb"
|
||||
return &mariadb.MariadbDatabase{
|
||||
Version: tools.MariadbVersion1011,
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
@@ -88,7 +88,7 @@ func getTestMongodbConfig() *mongodb.MongodbDatabase {
|
||||
|
||||
return &mongodb.MongodbDatabase{
|
||||
Version: tools.MongodbVersion7,
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "root",
|
||||
Password: "rootpassword",
|
||||
@@ -829,7 +829,7 @@ func createTestDatabaseViaAPI(
|
||||
Type: DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
|
||||
@@ -515,9 +515,13 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
|
||||
hasProcess := false
|
||||
hasAllPrivileges := false
|
||||
|
||||
// Escape underscores to match MariaDB's grant output format
|
||||
// MariaDB escapes _ as \_ in SHOW GRANTS output
|
||||
// Pattern matches either literal _ or escaped \_
|
||||
escapedDbName := strings.ReplaceAll(regexp.QuoteMeta(database), "_", `(_|\\_)`)
|
||||
dbPatternStr := fmt.Sprintf(
|
||||
`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
|
||||
regexp.QuoteMeta(database),
|
||||
escapedDbName,
|
||||
)
|
||||
dbPattern := regexp.MustCompile(dbPatternStr)
|
||||
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)
|
||||
|
||||
@@ -694,6 +694,115 @@ func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_TestConnection_DatabaseWithUnderscoresAndAllPrivileges_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MariadbVersion
|
||||
port string
|
||||
}{
|
||||
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
|
||||
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
|
||||
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
|
||||
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
|
||||
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
|
||||
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
|
||||
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
|
||||
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
|
||||
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
|
||||
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
|
||||
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMariadbContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
underscoreDbName := "test_all_db"
|
||||
|
||||
_, err := container.DB.Exec(
|
||||
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(
|
||||
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
|
||||
)
|
||||
}()
|
||||
|
||||
underscoreDSN := fmt.Sprintf(
|
||||
"%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
container.Username,
|
||||
container.Password,
|
||||
container.Host,
|
||||
container.Port,
|
||||
underscoreDbName,
|
||||
)
|
||||
underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
|
||||
assert.NoError(t, err)
|
||||
defer underscoreDB.Close()
|
||||
|
||||
_, err = underscoreDB.Exec(`
|
||||
CREATE TABLE all_priv_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(`INSERT INTO all_priv_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
allPrivUsername := fmt.Sprintf("allpriv%s", uuid.New().String()[:8])
|
||||
allPrivPassword := "allprivpass123"
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
allPrivUsername,
|
||||
allPrivPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"GRANT ALL PRIVILEGES ON `%s`.* TO '%s'@'%%'",
|
||||
underscoreDbName,
|
||||
allPrivUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer dropUserSafe(underscoreDB, allPrivUsername)
|
||||
|
||||
mariadbModel := &MariadbDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: allPrivUsername,
|
||||
Password: allPrivPassword,
|
||||
Database: &underscoreDbName,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mariadbModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, mariadbModel.Privileges)
|
||||
assert.Contains(t, mariadbModel.Privileges, "SELECT")
|
||||
assert.Contains(t, mariadbModel.Privileges, "SHOW VIEW")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type MariadbContainer struct {
|
||||
Host string
|
||||
Port int
|
||||
@@ -714,7 +823,7 @@ func connectToMariadbContainer(
|
||||
}
|
||||
|
||||
dbName := "testdb"
|
||||
host := "127.0.0.1"
|
||||
host := config.GetEnv().TestLocalhost
|
||||
username := "root"
|
||||
password := "rootpassword"
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -397,7 +398,7 @@ func connectToMongodbContainer(
|
||||
}
|
||||
|
||||
dbName := "testdb"
|
||||
host := "127.0.0.1"
|
||||
host := config.GetEnv().TestLocalhost
|
||||
username := "root"
|
||||
password := "rootpassword"
|
||||
authDatabase := "admin"
|
||||
@@ -406,11 +407,18 @@ func connectToMongodbContainer(
|
||||
assert.NoError(t, err)
|
||||
|
||||
uri := fmt.Sprintf(
|
||||
"mongodb://%s:%s@%s:%d/%s?authSource=%s",
|
||||
username, password, host, portInt, dbName, authDatabase,
|
||||
"mongodb://%s:%s@%s:%d/%s?authSource=%s&serverSelectionTimeoutMS=5000&connectTimeoutMS=5000",
|
||||
username,
|
||||
password,
|
||||
host,
|
||||
portInt,
|
||||
dbName,
|
||||
authDatabase,
|
||||
)
|
||||
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
clientOptions := options.Client().ApplyURI(uri)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
if err != nil {
|
||||
|
||||
@@ -489,9 +489,13 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
|
||||
hasProcess := false
|
||||
hasAllPrivileges := false
|
||||
|
||||
// Escape underscores to match MySQL's grant output format
|
||||
// MySQL escapes _ as \_ in SHOW GRANTS output
|
||||
// Pattern matches either literal _ or escaped \_
|
||||
escapedDbName := strings.ReplaceAll(regexp.QuoteMeta(database), "_", `(_|\\_)`)
|
||||
dbPatternStr := fmt.Sprintf(
|
||||
`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
|
||||
regexp.QuoteMeta(database),
|
||||
escapedDbName,
|
||||
)
|
||||
dbPattern := regexp.MustCompile(dbPatternStr)
|
||||
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)
|
||||
|
||||
@@ -674,6 +674,112 @@ func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_TestConnection_DatabaseWithUnderscoresAndAllPrivileges_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MysqlVersion
|
||||
port string
|
||||
}{
|
||||
{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
|
||||
{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
|
||||
{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
|
||||
{"MySQL 9", tools.MysqlVersion9, env.TestMysql90Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMysqlContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
underscoreDbName := "test_all_db"
|
||||
|
||||
_, err := container.DB.Exec(
|
||||
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(
|
||||
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
|
||||
)
|
||||
}()
|
||||
|
||||
underscoreDSN := fmt.Sprintf(
|
||||
"%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
container.Username,
|
||||
container.Password,
|
||||
container.Host,
|
||||
container.Port,
|
||||
underscoreDbName,
|
||||
)
|
||||
underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
|
||||
assert.NoError(t, err)
|
||||
defer underscoreDB.Close()
|
||||
|
||||
_, err = underscoreDB.Exec(`
|
||||
CREATE TABLE all_priv_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(`INSERT INTO all_priv_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
allPrivUsername := fmt.Sprintf("allpriv_%s", uuid.New().String()[:8])
|
||||
allPrivPassword := "allprivpass123"
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
allPrivUsername,
|
||||
allPrivPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"GRANT ALL PRIVILEGES ON `%s`.* TO '%s'@'%%'",
|
||||
underscoreDbName,
|
||||
allPrivUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(
|
||||
fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", allPrivUsername),
|
||||
)
|
||||
}()
|
||||
|
||||
mysqlModel := &MysqlDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: allPrivUsername,
|
||||
Password: allPrivPassword,
|
||||
Database: &underscoreDbName,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mysqlModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, mysqlModel.Privileges)
|
||||
assert.Contains(t, mysqlModel.Privileges, "SELECT")
|
||||
assert.Contains(t, mysqlModel.Privileges, "SHOW VIEW")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type MysqlContainer struct {
|
||||
Host string
|
||||
Port int
|
||||
@@ -694,7 +800,7 @@ func connectToMysqlContainer(
|
||||
}
|
||||
|
||||
dbName := "testdb"
|
||||
host := "127.0.0.1"
|
||||
host := config.GetEnv().TestLocalhost
|
||||
username := "root"
|
||||
password := "rootpassword"
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -394,10 +395,13 @@ func (p *PostgresqlDatabase) IsUserReadOnly(
|
||||
//
|
||||
// This method performs the following operations atomically in a single transaction:
|
||||
// 1. Creates a PostgreSQL user with a UUID-based password
|
||||
// 2. Grants CONNECT privilege on the database
|
||||
// 3. Grants USAGE on all non-system schemas
|
||||
// 4. Grants SELECT on all existing tables and sequences
|
||||
// 5. Sets default privileges for future tables and sequences
|
||||
// 2. Revokes CREATE privilege on public schema from PUBLIC role
|
||||
// 3. Grants CONNECT privilege on the database
|
||||
// 4. Discovers all user-created schemas
|
||||
// 5. Grants USAGE on all non-system schemas
|
||||
// 6. Grants SELECT on all existing tables and sequences
|
||||
// 7. Sets default privileges for future tables and sequences
|
||||
// 8. Verifies user creation before committing
|
||||
//
|
||||
// Security features:
|
||||
// - Username format: "databasus-{8-char-uuid}" for uniqueness
|
||||
@@ -487,33 +491,56 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
return "", "", fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
// Step 1.5: Revoke CREATE privilege from PUBLIC role on public schema
|
||||
// Step 2: Check if public schema exists and revoke CREATE privilege if it does
|
||||
// This is necessary because all PostgreSQL users inherit CREATE privilege on the
|
||||
// public schema through the PUBLIC role. This is a one-time operation that affects
|
||||
// the entire database, making it more secure by default.
|
||||
// Note: This only affects the public schema; other schemas are unaffected.
|
||||
_, err = tx.Exec(ctx, `REVOKE CREATE ON SCHEMA public FROM PUBLIC`)
|
||||
if err != nil {
|
||||
logger.Error("Failed to revoke CREATE on public from PUBLIC", "error", err)
|
||||
if !strings.Contains(err.Error(), "schema \"public\" does not exist") &&
|
||||
!strings.Contains(err.Error(), "permission denied") {
|
||||
return "", "", fmt.Errorf("failed to revoke CREATE from PUBLIC: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Now revoke from the specific user as well (belt and suspenders)
|
||||
_, err = tx.Exec(ctx, fmt.Sprintf(`REVOKE CREATE ON SCHEMA public FROM "%s"`, baseUsername))
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
"Failed to revoke CREATE on public schema from user",
|
||||
"error",
|
||||
err,
|
||||
"username",
|
||||
baseUsername,
|
||||
var publicSchemaExists bool
|
||||
err = tx.QueryRow(ctx, `
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM information_schema.schemata
|
||||
WHERE schema_name = 'public'
|
||||
)
|
||||
`).Scan(&publicSchemaExists)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to check if public schema exists: %w", err)
|
||||
}
|
||||
|
||||
// Step 2: Grant database connection privilege and revoke TEMP
|
||||
if publicSchemaExists {
|
||||
// Revoke CREATE from PUBLIC role (affects all users)
|
||||
_, err = tx.Exec(ctx, `REVOKE CREATE ON SCHEMA public FROM PUBLIC`)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "permission denied") {
|
||||
logger.Warn(
|
||||
"Failed to revoke CREATE on public from PUBLIC (permission denied)",
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
} else {
|
||||
return "", "", fmt.Errorf("failed to revoke CREATE from PUBLIC on existing public schema: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Now revoke from the specific user as well (belt and suspenders)
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(`REVOKE CREATE ON SCHEMA public FROM "%s"`, baseUsername),
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn(
|
||||
"Failed to revoke CREATE on public schema from user",
|
||||
"error",
|
||||
err,
|
||||
"username",
|
||||
baseUsername,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
logger.Info("Public schema does not exist, skipping CREATE privilege revocation")
|
||||
}
|
||||
|
||||
// Step 3: Grant database connection privilege and revoke TEMP
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(`GRANT CONNECT ON DATABASE "%s" TO "%s"`, *p.Database, baseUsername),
|
||||
@@ -537,7 +564,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
logger.Warn("Failed to revoke TEMP privilege", "error", err, "username", baseUsername)
|
||||
}
|
||||
|
||||
// Step 3: Discover all user-created schemas
|
||||
// Step 4: Discover all user-created schemas
|
||||
rows, err := tx.Query(ctx, `
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
@@ -562,7 +589,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
return "", "", fmt.Errorf("error iterating schemas: %w", err)
|
||||
}
|
||||
|
||||
// Step 4: Grant USAGE on each schema and explicitly prevent CREATE
|
||||
// Step 5: Grant USAGE on each schema and explicitly prevent CREATE
|
||||
for _, schema := range schemas {
|
||||
// Revoke CREATE specifically (handles inheritance from PUBLIC role)
|
||||
_, err = tx.Exec(
|
||||
@@ -591,7 +618,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
}
|
||||
}
|
||||
|
||||
// Step 5: Grant SELECT on ALL existing tables and sequences
|
||||
// Step 6: Grant SELECT on ALL existing tables and sequences
|
||||
grantSelectSQL := fmt.Sprintf(`
|
||||
DO $$
|
||||
DECLARE
|
||||
@@ -613,7 +640,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
return "", "", fmt.Errorf("failed to grant select on tables: %w", err)
|
||||
}
|
||||
|
||||
// Step 6: Set default privileges for FUTURE tables and sequences
|
||||
// Step 7: Set default privileges for FUTURE tables and sequences
|
||||
defaultPrivilegesSQL := fmt.Sprintf(`
|
||||
DO $$
|
||||
DECLARE
|
||||
@@ -635,7 +662,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
return "", "", fmt.Errorf("failed to set default privileges: %w", err)
|
||||
}
|
||||
|
||||
// Step 7: Verify user creation before committing
|
||||
// Step 8: Verify user creation before committing
|
||||
var verifyUsername string
|
||||
err = tx.QueryRow(ctx, fmt.Sprintf(`SELECT rolname FROM pg_roles WHERE rolname = '%s'`, baseUsername)).
|
||||
Scan(&verifyUsername)
|
||||
@@ -851,7 +878,15 @@ func checkBackupPermissions(
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot check SELECT privileges: %w", err)
|
||||
// If the user doesn't have USAGE on the schema, has_table_privilege will fail
|
||||
// with "permission denied for schema". This means they definitely don't have
|
||||
// SELECT privileges, so treat this as missing permissions rather than an error.
|
||||
var pgErr *pgconn.PgError
|
||||
if errors.As(err, &pgErr) && pgErr.Code == "42501" { // insufficient_privilege
|
||||
selectableTableCount = 0
|
||||
} else {
|
||||
return fmt.Errorf("cannot check SELECT privileges: %w", err)
|
||||
}
|
||||
}
|
||||
if selectableTableCount == 0 {
|
||||
missingPrivileges = append(missingPrivileges, "SELECT on tables")
|
||||
|
||||
@@ -599,6 +599,10 @@ func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
|
||||
if config.GetEnv().IsSkipExternalResourcesTests {
|
||||
t.Skip("Skipping Supabase test: IS_SKIP_EXTERNAL_RESOURCES_TESTS is true")
|
||||
}
|
||||
|
||||
env := config.GetEnv()
|
||||
|
||||
if env.TestSupabaseHost == "" {
|
||||
@@ -705,6 +709,344 @@ func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_WithPublicSchema_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
port string
|
||||
}{
|
||||
{"PostgreSQL 12", "12", env.TestPostgres12Port},
|
||||
{"PostgreSQL 13", "13", env.TestPostgres13Port},
|
||||
{"PostgreSQL 14", "14", env.TestPostgres14Port},
|
||||
{"PostgreSQL 15", "15", env.TestPostgres15Port},
|
||||
{"PostgreSQL 16", "16", env.TestPostgres16Port},
|
||||
{"PostgreSQL 17", "17", env.TestPostgres17Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToPostgresContainer(t, tc.port)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`
|
||||
DROP TABLE IF EXISTS public_schema_test CASCADE;
|
||||
CREATE TABLE public_schema_test (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO public_schema_test (data) VALUES ('test1'), ('test2');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, username)
|
||||
assert.NotEmpty(t, password)
|
||||
assert.True(t, strings.HasPrefix(username, "databasus-"))
|
||||
|
||||
readOnlyModel := &PostgresqlDatabase{
|
||||
Version: pgModel.Version,
|
||||
Host: pgModel.Host,
|
||||
Port: pgModel.Port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: pgModel.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
|
||||
ctx,
|
||||
logger,
|
||||
nil,
|
||||
uuid.New(),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, isReadOnly, "User should be read-only")
|
||||
assert.Empty(t, privileges, "Read-only user should have no write privileges")
|
||||
|
||||
readOnlyDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host,
|
||||
container.Port,
|
||||
username,
|
||||
password,
|
||||
container.Database,
|
||||
)
|
||||
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var count int
|
||||
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM public_schema_test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
_, err = readOnlyConn.Exec(
|
||||
"INSERT INTO public_schema_test (data) VALUES ('should-fail')",
|
||||
)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("CREATE TABLE public.hack_table (id INT)")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to drop owned objects: %v", err)
|
||||
}
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_WithoutPublicSchema_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
port string
|
||||
}{
|
||||
{"PostgreSQL 12", "12", env.TestPostgres12Port},
|
||||
{"PostgreSQL 13", "13", env.TestPostgres13Port},
|
||||
{"PostgreSQL 14", "14", env.TestPostgres14Port},
|
||||
{"PostgreSQL 15", "15", env.TestPostgres15Port},
|
||||
{"PostgreSQL 16", "16", env.TestPostgres16Port},
|
||||
{"PostgreSQL 17", "17", env.TestPostgres17Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToPostgresContainer(t, tc.port)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`
|
||||
DROP SCHEMA IF EXISTS public CASCADE;
|
||||
DROP SCHEMA IF EXISTS app_schema CASCADE;
|
||||
DROP SCHEMA IF EXISTS data_schema CASCADE;
|
||||
CREATE SCHEMA app_schema;
|
||||
CREATE SCHEMA data_schema;
|
||||
CREATE TABLE app_schema.users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
username TEXT NOT NULL
|
||||
);
|
||||
CREATE TABLE data_schema.records (
|
||||
id SERIAL PRIMARY KEY,
|
||||
info TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO app_schema.users (username) VALUES ('user1'), ('user2');
|
||||
INSERT INTO data_schema.records (info) VALUES ('record1'), ('record2');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err, "CreateReadOnlyUser should succeed without public schema")
|
||||
assert.NotEmpty(t, username)
|
||||
assert.NotEmpty(t, password)
|
||||
assert.True(t, strings.HasPrefix(username, "databasus-"))
|
||||
|
||||
readOnlyModel := &PostgresqlDatabase{
|
||||
Version: pgModel.Version,
|
||||
Host: pgModel.Host,
|
||||
Port: pgModel.Port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: pgModel.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
|
||||
ctx,
|
||||
logger,
|
||||
nil,
|
||||
uuid.New(),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, isReadOnly, "User should be read-only")
|
||||
assert.Empty(t, privileges, "Read-only user should have no write privileges")
|
||||
|
||||
readOnlyDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host,
|
||||
container.Port,
|
||||
username,
|
||||
password,
|
||||
container.Database,
|
||||
)
|
||||
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var userCount int
|
||||
err = readOnlyConn.Get(&userCount, "SELECT COUNT(*) FROM app_schema.users")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, userCount)
|
||||
|
||||
var recordCount int
|
||||
err = readOnlyConn.Get(&recordCount, "SELECT COUNT(*) FROM data_schema.records")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, recordCount)
|
||||
|
||||
_, err = readOnlyConn.Exec(
|
||||
"INSERT INTO app_schema.users (username) VALUES ('should-fail')",
|
||||
)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("CREATE TABLE app_schema.hack_table (id INT)")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("CREATE TABLE data_schema.hack_table (id INT)")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to drop owned objects: %v", err)
|
||||
}
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`
|
||||
DROP SCHEMA IF EXISTS app_schema CASCADE;
|
||||
DROP SCHEMA IF EXISTS data_schema CASCADE;
|
||||
CREATE SCHEMA IF NOT EXISTS public;
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_PublicSchemaExistsButNoPermissions_ReturnsError(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
port string
|
||||
}{
|
||||
{"PostgreSQL 12", "12", env.TestPostgres12Port},
|
||||
{"PostgreSQL 13", "13", env.TestPostgres13Port},
|
||||
{"PostgreSQL 14", "14", env.TestPostgres14Port},
|
||||
{"PostgreSQL 15", "15", env.TestPostgres15Port},
|
||||
{"PostgreSQL 16", "16", env.TestPostgres16Port},
|
||||
{"PostgreSQL 17", "17", env.TestPostgres17Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToPostgresContainer(t, tc.port)
|
||||
defer container.DB.Close()
|
||||
|
||||
limitedAdminUsername := fmt.Sprintf("limited_admin_%s", uuid.New().String()[:8])
|
||||
limitedAdminPassword := "limited_password_123"
|
||||
|
||||
_, err := container.DB.Exec(`
|
||||
CREATE SCHEMA IF NOT EXISTS public;
|
||||
DROP TABLE IF EXISTS public.permission_test_table CASCADE;
|
||||
CREATE TABLE public.permission_test_table (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO public.permission_test_table (data) VALUES ('test1');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`GRANT CREATE ON SCHEMA public TO PUBLIC`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN CREATEROLE`,
|
||||
limitedAdminUsername,
|
||||
limitedAdminPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
|
||||
container.Database,
|
||||
limitedAdminUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(
|
||||
fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, limitedAdminUsername),
|
||||
)
|
||||
_, _ = container.DB.Exec(
|
||||
fmt.Sprintf(`DROP USER IF EXISTS "%s"`, limitedAdminUsername),
|
||||
)
|
||||
_, _ = container.DB.Exec(`REVOKE CREATE ON SCHEMA public FROM PUBLIC`)
|
||||
}()
|
||||
|
||||
limitedAdminDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host,
|
||||
container.Port,
|
||||
limitedAdminUsername,
|
||||
limitedAdminPassword,
|
||||
container.Database,
|
||||
)
|
||||
limitedAdminConn, err := sqlx.Connect("postgres", limitedAdminDSN)
|
||||
assert.NoError(t, err)
|
||||
defer limitedAdminConn.Close()
|
||||
|
||||
pgModel := &PostgresqlDatabase{
|
||||
Version: tools.GetPostgresqlVersionEnum(tc.version),
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: limitedAdminUsername,
|
||||
Password: limitedAdminPassword,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.Error(
|
||||
t,
|
||||
err,
|
||||
"CreateReadOnlyUser should fail when admin lacks permissions to secure public schema",
|
||||
)
|
||||
if err != nil {
|
||||
errorMsg := err.Error()
|
||||
hasExpectedError := strings.Contains(
|
||||
errorMsg,
|
||||
"failed to revoke CREATE from PUBLIC on existing public schema",
|
||||
) ||
|
||||
strings.Contains(errorMsg, "permission denied for schema public") ||
|
||||
strings.Contains(errorMsg, "failed to grant")
|
||||
assert.True(
|
||||
t,
|
||||
hasExpectedError,
|
||||
"Error should indicate permission issues with public schema, got: %s",
|
||||
errorMsg,
|
||||
)
|
||||
}
|
||||
assert.Empty(t, username)
|
||||
assert.Empty(t, password)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Validate_WhenLocalhostAndDatabasus_ReturnsError(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -981,7 +1323,7 @@ func connectToPostgresContainer(t *testing.T, port string) *PostgresContainer {
|
||||
dbName := "testdb"
|
||||
password := "testpassword"
|
||||
username := "testuser"
|
||||
host := "localhost"
|
||||
host := config.GetEnv().TestLocalhost
|
||||
|
||||
portInt, err := strconv.Atoi(port)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package databases
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
@@ -37,7 +40,22 @@ func GetDatabaseController() *DatabaseController {
|
||||
return databaseController
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService)
|
||||
notifiers.GetNotifierService().SetNotifierDatabaseCounter(databaseService)
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService)
|
||||
notifiers.GetNotifierService().SetNotifierDatabaseCounter(databaseService)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ func GetTestPostgresConfig() *postgresql.PostgresqlDatabase {
|
||||
testDbName := "testdb"
|
||||
return &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
@@ -48,7 +48,7 @@ func GetTestMariadbConfig() *mariadb.MariadbDatabase {
|
||||
testDbName := "testdb"
|
||||
return &mariadb.MariadbDatabase{
|
||||
Version: tools.MariadbVersion1011,
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
@@ -69,7 +69,7 @@ func GetTestMongodbConfig() *mongodb.MongodbDatabase {
|
||||
|
||||
return &mongodb.MongodbDatabase{
|
||||
Version: tools.MongodbVersion7,
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "root",
|
||||
Password: "rootpassword",
|
||||
|
||||
@@ -2,30 +2,47 @@ package healthcheck_attempt
|
||||
|
||||
import (
|
||||
"context"
|
||||
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
|
||||
)
|
||||
|
||||
type HealthcheckAttemptBackgroundService struct {
|
||||
healthcheckConfigService *healthcheck_config.HealthcheckConfigService
|
||||
checkDatabaseHealthUseCase *CheckDatabaseHealthUseCase
|
||||
logger *slog.Logger
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (s *HealthcheckAttemptBackgroundService) Run(ctx context.Context) {
|
||||
// first healthcheck immediately
|
||||
s.checkDatabases()
|
||||
wasAlreadyRun := s.hasRun.Load()
|
||||
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.checkDatabases()
|
||||
s.runOnce.Do(func() {
|
||||
s.hasRun.Store(true)
|
||||
|
||||
// first healthcheck immediately
|
||||
s.checkDatabases()
|
||||
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.checkDatabases()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package healthcheck_attempt
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"databasus-backend/internal/features/databases"
|
||||
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
@@ -22,9 +25,11 @@ var checkDatabaseHealthUseCase = &CheckDatabaseHealthUseCase{
|
||||
}
|
||||
|
||||
var healthcheckAttemptBackgroundService = &HealthcheckAttemptBackgroundService{
|
||||
healthcheck_config.GetHealthcheckConfigService(),
|
||||
checkDatabaseHealthUseCase,
|
||||
logger.GetLogger(),
|
||||
healthcheckConfigService: healthcheck_config.GetHealthcheckConfigService(),
|
||||
checkDatabaseHealthUseCase: checkDatabaseHealthUseCase,
|
||||
logger: logger.GetLogger(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
var healthcheckAttemptController = &HealthcheckAttemptController{
|
||||
healthcheckAttemptService,
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package healthcheck_config
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/databases"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
@@ -27,8 +30,23 @@ func GetHealthcheckConfigController() *HealthcheckConfigController {
|
||||
return healthcheckConfigController
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
databases.
|
||||
GetDatabaseService().
|
||||
AddDbCreationListener(healthcheckConfigService)
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
databases.
|
||||
GetDatabaseService().
|
||||
AddDbCreationListener(healthcheckConfigService)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package notifiers
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
@@ -32,6 +35,22 @@ func GetNotifierService() *NotifierService {
|
||||
func GetNotifierRepository() *NotifierRepository {
|
||||
return notifierRepository
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService)
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"mime"
|
||||
"net"
|
||||
"net/smtp"
|
||||
"time"
|
||||
@@ -115,16 +116,34 @@ func (e *EmailNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeRFC2047 prepares s for use as an email header value.
//
// Input that is plain ASCII without special characters comes back unchanged; anything
// else is wrapped in a UTF-8 Q-encoded word (=?UTF-8?q?...?=). The resulting header is
// ASCII-only, so it can be sent to SMTP servers that do not support SMTPUTF8.
func encodeRFC2047(s string) string {
	const headerCharset = "UTF-8"

	wordEncoder := mime.QEncoding

	return wordEncoder.Encode(headerCharset, s)
}
|
||||
|
||||
func (e *EmailNotifier) buildEmailContent(heading, message, from string) []byte {
|
||||
subject := fmt.Sprintf("Subject: %s\r\n", heading)
|
||||
mime := fmt.Sprintf(
|
||||
// Encode Subject header using RFC 2047 to avoid SMTPUTF8 requirement
|
||||
// This ensures compatibility with SMTP servers that don't support SMTPUTF8
|
||||
encodedSubject := encodeRFC2047(heading)
|
||||
subject := fmt.Sprintf("Subject: %s\r\n", encodedSubject)
|
||||
|
||||
mimeHeaders := fmt.Sprintf(
|
||||
"MIME-version: 1.0;\nContent-Type: %s; charset=\"%s\";\n\n",
|
||||
MIMETypeHTML,
|
||||
MIMECharsetUTF8,
|
||||
)
|
||||
fromHeader := fmt.Sprintf("From: %s\r\n", from)
|
||||
|
||||
// Encode From header display name if it contains non-ASCII
|
||||
encodedFrom := encodeRFC2047(from)
|
||||
fromHeader := fmt.Sprintf("From: %s\r\n", encodedFrom)
|
||||
|
||||
toHeader := fmt.Sprintf("To: %s\r\n", e.TargetEmail)
|
||||
return []byte(fromHeader + toHeader + subject + mime + message)
|
||||
|
||||
return []byte(fromHeader + toHeader + subject + mimeHeaders + message)
|
||||
}
|
||||
|
||||
func (e *EmailNotifier) sendImplicitTLS(
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
package restores
|
||||
|
||||
import (
|
||||
"context"
|
||||
"databasus-backend/internal/features/restores/enums"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
type RestoreBackgroundService struct {
|
||||
restoreRepository *RestoreRepository
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *RestoreBackgroundService) Run(ctx context.Context) {
|
||||
if err := s.failRestoresInProgress(); err != nil {
|
||||
s.logger.Error("Failed to fail restores in progress", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RestoreBackgroundService) failRestoresInProgress() error {
|
||||
restoresInProgress, err := s.restoreRepository.FindByStatus(enums.RestoreStatusInProgress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, restore := range restoresInProgress {
|
||||
failMessage := "Restore failed due to application restart"
|
||||
restore.Status = enums.RestoreStatusFailed
|
||||
restore.FailMessage = &failMessage
|
||||
|
||||
if err := s.restoreRepository.Save(restore); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package restores
|
||||
|
||||
import (
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
"net/http"
|
||||
|
||||
@@ -15,6 +16,7 @@ type RestoreController struct {
|
||||
func (c *RestoreController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
router.GET("/restores/:backupId", c.GetRestores)
|
||||
router.POST("/restores/:backupId/restore", c.RestoreBackup)
|
||||
router.POST("/restores/cancel/:restoreId", c.CancelRestore)
|
||||
}
|
||||
|
||||
// GetRestores
|
||||
@@ -23,7 +25,7 @@ func (c *RestoreController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// @Tags restores
|
||||
// @Produce json
|
||||
// @Param backupId path string true "Backup ID"
|
||||
// @Success 200 {array} models.Restore
|
||||
// @Success 200 {array} restores_core.Restore
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Router /restores/{backupId} [get]
|
||||
@@ -71,7 +73,7 @@ func (c *RestoreController) RestoreBackup(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var requestDTO RestoreBackupRequest
|
||||
var requestDTO restores_core.RestoreBackupRequest
|
||||
if err := ctx.ShouldBindJSON(&requestDTO); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -84,3 +86,33 @@ func (c *RestoreController) RestoreBackup(ctx *gin.Context) {
|
||||
|
||||
ctx.JSON(http.StatusOK, gin.H{"message": "restore started successfully"})
|
||||
}
|
||||
|
||||
// CancelRestore
|
||||
// @Summary Cancel an in-progress restore
|
||||
// @Description Cancel a restore that is currently in progress
|
||||
// @Tags restores
|
||||
// @Param restoreId path string true "Restore ID"
|
||||
// @Success 204
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Router /restores/cancel/{restoreId} [post]
|
||||
func (c *RestoreController) CancelRestore(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
restoreID, err := uuid.Parse(ctx.Param("restoreId"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid restore ID"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.restoreService.CancelRestore(user, restoreID); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
env_config "databasus-backend/internal/config"
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
@@ -24,16 +24,19 @@ import (
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/databases/databases/mysql"
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
"databasus-backend/internal/features/restores/models"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/restores/restoring"
|
||||
"databasus-backend/internal/features/storages"
|
||||
local_storage "databasus-backend/internal/features/storages/models/local"
|
||||
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
users_dto "databasus-backend/internal/features/users/dto"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_models "databasus-backend/internal/features/workspaces/models"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
test_utils "databasus-backend/internal/util/testing"
|
||||
"databasus-backend/internal/util/tools"
|
||||
@@ -46,7 +49,7 @@ func Test_GetRestores_WhenUserIsWorkspaceMember_RestoresReturned(t *testing.T) {
|
||||
|
||||
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
|
||||
|
||||
var restores []*models.Restore
|
||||
var restores []*restores_core.Restore
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
@@ -90,7 +93,7 @@ func Test_GetRestores_WhenUserIsGlobalAdmin_RestoresReturned(t *testing.T) {
|
||||
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
|
||||
var restores []*models.Restore
|
||||
var restores []*restores_core.Restore
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
@@ -105,15 +108,19 @@ func Test_GetRestores_WhenUserIsGlobalAdmin_RestoresReturned(t *testing.T) {
|
||||
|
||||
func Test_RestoreBackup_WhenUserIsWorkspaceMember_RestoreInitiated(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
|
||||
_, cleanup := SetupMockRestoreNode(t)
|
||||
defer cleanup()
|
||||
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
|
||||
|
||||
request := RestoreBackupRequest{
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Host: env_config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
@@ -141,10 +148,10 @@ func Test_RestoreBackup_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing
|
||||
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
|
||||
request := RestoreBackupRequest{
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Host: env_config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
@@ -165,15 +172,19 @@ func Test_RestoreBackup_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing
|
||||
|
||||
func Test_RestoreBackup_WithIsExcludeExtensions_FlagPassedCorrectly(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
|
||||
_, cleanup := SetupMockRestoreNode(t)
|
||||
defer cleanup()
|
||||
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
|
||||
|
||||
request := RestoreBackupRequest{
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Host: env_config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
@@ -195,15 +206,19 @@ func Test_RestoreBackup_WithIsExcludeExtensions_FlagPassedCorrectly(t *testing.T
|
||||
|
||||
func Test_RestoreBackup_AuditLogWritten(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
|
||||
_, cleanup := SetupMockRestoreNode(t)
|
||||
defer cleanup()
|
||||
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
|
||||
|
||||
request := RestoreBackupRequest{
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Host: env_config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
@@ -272,18 +287,25 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
|
||||
// Setup mock node for tests that skip disk validation and reach scheduler
|
||||
if !tc.expectDiskValidated {
|
||||
_, cleanup := SetupMockRestoreNode(t)
|
||||
defer cleanup()
|
||||
}
|
||||
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
var backup *backups_core.Backup
|
||||
var request RestoreBackupRequest
|
||||
var request restores_core.RestoreBackupRequest
|
||||
|
||||
if tc.dbType == databases.DatabaseTypePostgres {
|
||||
_, backup = createTestDatabaseWithBackupForRestore(workspace, owner, router)
|
||||
request = RestoreBackupRequest{
|
||||
request = restores_core.RestoreBackupRequest{
|
||||
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Host: env_config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
@@ -310,10 +332,10 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
backup = createTestBackup(mysqlDB, owner)
|
||||
request = RestoreBackupRequest{
|
||||
request = restores_core.RestoreBackupRequest{
|
||||
MysqlDatabase: &mysql.MysqlDatabase{
|
||||
Version: tools.MysqlVersion80,
|
||||
Host: "localhost",
|
||||
Host: env_config.GetEnv().TestLocalhost,
|
||||
Port: 3306,
|
||||
Username: "root",
|
||||
Password: "password",
|
||||
@@ -353,16 +375,187 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func createTestRouter() *gin.Engine {
|
||||
router := workspaces_testing.CreateTestRouter(
|
||||
workspaces_controllers.GetWorkspaceController(),
|
||||
workspaces_controllers.GetMembershipController(),
|
||||
databases.GetDatabaseController(),
|
||||
backups_config.GetBackupConfigController(),
|
||||
backups.GetBackupController(),
|
||||
GetRestoreController(),
|
||||
func Test_CancelRestore_InProgressRestore_SuccessfullyCancelled(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
tasks_cancellation.SetupDependencies()
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := createTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backupRepo := backups_core.BackupRepository{}
|
||||
backups, _ := backupRepo.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepo.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
restoreRepo := restores_core.RestoreRepository{}
|
||||
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
|
||||
for _, restore := range restores {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
restores, _ = restoreRepo.FindByStatus(restores_core.RestoreStatusCanceled)
|
||||
for _, restore := range restores {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
||||
|
||||
mockUsecase := &restoring.MockBlockingRestoreUsecase{
|
||||
StartedChan: make(chan bool, 1),
|
||||
}
|
||||
restorerNode := restoring.CreateTestRestorerNodeWithUsecase(mockUsecase)
|
||||
|
||||
cancelNode := restoring.StartRestorerNodeForTest(t, restorerNode)
|
||||
defer cancelNode()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
restoreRequest := restores_core.RestoreBackupRequest{
|
||||
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: env_config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
},
|
||||
}
|
||||
|
||||
var restoreResponse map[string]interface{}
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
|
||||
"Bearer "+user.Token,
|
||||
restoreRequest,
|
||||
http.StatusOK,
|
||||
&restoreResponse,
|
||||
)
|
||||
return router
|
||||
|
||||
select {
|
||||
case <-mockUsecase.StartedChan:
|
||||
t.Log("Restore started and is blocking")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Restore did not start within timeout")
|
||||
}
|
||||
|
||||
restoreRepo := &restores_core.RestoreRepository{}
|
||||
restores, err := restoreRepo.FindByBackupID(backup.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, len(restores), 0, "At least one restore should exist")
|
||||
|
||||
var restoreID uuid.UUID
|
||||
for _, r := range restores {
|
||||
if r.Status == restores_core.RestoreStatusInProgress {
|
||||
restoreID = r.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.NotEqual(t, uuid.Nil, restoreID, "Should find an in-progress restore")
|
||||
|
||||
resp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/restores/cancel/%s", restoreID.String()),
|
||||
"Bearer "+user.Token,
|
||||
nil,
|
||||
http.StatusNoContent,
|
||||
)
|
||||
|
||||
assert.Equal(t, http.StatusNoContent, resp.StatusCode)
|
||||
|
||||
deadline := time.Now().UTC().Add(3 * time.Second)
|
||||
var restore *restores_core.Restore
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
restore, err = restoreRepo.FindByID(restoreID)
|
||||
assert.NoError(t, err)
|
||||
if restore.Status == restores_core.RestoreStatusCanceled {
|
||||
break
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
assert.Equal(t, restores_core.RestoreStatusCanceled, restore.Status)
|
||||
|
||||
auditLogService := audit_logs.GetAuditLogService()
|
||||
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
|
||||
workspace.ID,
|
||||
&audit_logs.GetAuditLogsRequest{Limit: 100, Offset: 0},
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
foundCancelLog := false
|
||||
for _, log := range auditLogs.AuditLogs {
|
||||
if strings.Contains(log.Message, "Restore cancelled") &&
|
||||
strings.Contains(log.Message, database.Name) {
|
||||
foundCancelLog = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, foundCancelLog, "Cancel audit log should be created")
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_RestoreBackup_WithParallelRestoreInProgress_ReturnsError(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
|
||||
_, cleanup := SetupMockRestoreNode(t)
|
||||
defer cleanup()
|
||||
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
|
||||
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: env_config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
},
|
||||
}
|
||||
|
||||
testResp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusOK,
|
||||
)
|
||||
assert.Contains(t, string(testResp.Body), "restore started successfully")
|
||||
|
||||
testResp2 := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(testResp2.Body), "another restore is already in progress")
|
||||
}
|
||||
|
||||
func createTestRouter() *gin.Engine {
|
||||
return CreateTestRouter()
|
||||
}
|
||||
|
||||
func createTestDatabaseWithBackupForRestore(
|
||||
@@ -433,7 +626,7 @@ func createTestMySQLDatabase(
|
||||
token string,
|
||||
router *gin.Engine,
|
||||
) *databases.Database {
|
||||
env := config.GetEnv()
|
||||
env := env_config.GetEnv()
|
||||
portStr := env.TestMysql80Port
|
||||
if portStr == "" {
|
||||
portStr = "33080"
|
||||
@@ -451,7 +644,7 @@ func createTestMySQLDatabase(
|
||||
Type: databases.DatabaseTypeMysql,
|
||||
Mysql: &mysql.MysqlDatabase{
|
||||
Version: tools.MysqlVersion80,
|
||||
Host: "localhost",
|
||||
Host: env_config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package restores
|
||||
package restores_core
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/features/databases/databases/mariadb"
|
||||
@@ -1,4 +1,4 @@
|
||||
package enums
|
||||
package restores_core
|
||||
|
||||
type RestoreStatus string
|
||||
|
||||
@@ -6,4 +6,5 @@ const (
|
||||
RestoreStatusInProgress RestoreStatus = "IN_PROGRESS"
|
||||
RestoreStatusCompleted RestoreStatus = "COMPLETED"
|
||||
RestoreStatusFailed RestoreStatus = "FAILED"
|
||||
RestoreStatusCanceled RestoreStatus = "CANCELED"
|
||||
)
|
||||
23
backend/internal/features/restores/core/interfaces.go
Normal file
23
backend/internal/features/restores/core/interfaces.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package restores_core
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/storages"
|
||||
)
|
||||
|
||||
type RestoreBackupUsecase interface {
|
||||
Execute(
|
||||
ctx context.Context,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore Restore,
|
||||
originalDB *databases.Database,
|
||||
restoringToDB *databases.Database,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
isExcludeExtensions bool,
|
||||
) error
|
||||
}
|
||||
30
backend/internal/features/restores/core/model.go
Normal file
30
backend/internal/features/restores/core/model.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package restores_core
|
||||
|
||||
import (
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/databases/databases/mariadb"
|
||||
"databasus-backend/internal/features/databases/databases/mongodb"
|
||||
"databasus-backend/internal/features/databases/databases/mysql"
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type Restore struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
|
||||
Status RestoreStatus `json:"status" gorm:"column:status;type:text;not null"`
|
||||
|
||||
BackupID uuid.UUID `json:"backupId" gorm:"column:backup_id;type:uuid;not null"`
|
||||
Backup *backups_core.Backup
|
||||
|
||||
PostgresqlDatabase *postgresql.PostgresqlDatabase `json:"postgresqlDatabase" gorm:"-"`
|
||||
MysqlDatabase *mysql.MysqlDatabase `json:"mysqlDatabase" gorm:"-"`
|
||||
MariadbDatabase *mariadb.MariadbDatabase `json:"mariadbDatabase" gorm:"-"`
|
||||
MongodbDatabase *mongodb.MongodbDatabase `json:"mongodbDatabase" gorm:"-"`
|
||||
|
||||
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`
|
||||
|
||||
RestoreDurationMs int64 `json:"restoreDurationMs" gorm:"column:restore_duration_ms;default:0"`
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at;default:now()"`
|
||||
}
|
||||
91
backend/internal/features/restores/core/repository.go
Normal file
91
backend/internal/features/restores/core/repository.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package restores_core
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/storage"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// RestoreRepository provides persistence operations for Restore records.
// It holds no state; every method obtains the database handle from
// storage.GetDb() at call time.
type RestoreRepository struct{}
|
||||
|
||||
func (r *RestoreRepository) Save(restore *Restore) error {
|
||||
db := storage.GetDb()
|
||||
|
||||
isNew := restore.ID == uuid.Nil
|
||||
if isNew {
|
||||
restore.ID = uuid.New()
|
||||
return db.Create(restore).
|
||||
Omit("Backup", "PostgresqlDatabase", "MysqlDatabase", "MariadbDatabase", "MongodbDatabase").
|
||||
Error
|
||||
}
|
||||
|
||||
return db.Save(restore).
|
||||
Omit("Backup", "PostgresqlDatabase", "MysqlDatabase", "MariadbDatabase", "MongodbDatabase").
|
||||
Error
|
||||
}
|
||||
|
||||
func (r *RestoreRepository) FindByBackupID(backupID uuid.UUID) ([]*Restore, error) {
|
||||
var restores []*Restore
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Backup").
|
||||
Where("backup_id = ?", backupID).
|
||||
Order("created_at DESC").
|
||||
Find(&restores).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return restores, nil
|
||||
}
|
||||
|
||||
func (r *RestoreRepository) FindByID(id uuid.UUID) (*Restore, error) {
|
||||
var restore Restore
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Backup").
|
||||
Where("id = ?", id).
|
||||
First(&restore).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &restore, nil
|
||||
}
|
||||
|
||||
func (r *RestoreRepository) FindByStatus(status RestoreStatus) ([]*Restore, error) {
|
||||
var restores []*Restore
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Backup").
|
||||
Where("status = ?", status).
|
||||
Order("created_at DESC").
|
||||
Find(&restores).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return restores, nil
|
||||
}
|
||||
|
||||
func (r *RestoreRepository) FindInProgressRestoresByDatabaseID(
|
||||
databaseID uuid.UUID,
|
||||
) ([]*Restore, error) {
|
||||
var restores []*Restore
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Backup").
|
||||
Joins("JOIN backups ON backups.id = restores.backup_id").
|
||||
Where("backups.database_id = ? AND restores.status = ?", databaseID, RestoreStatusInProgress).
|
||||
Order("restores.created_at DESC").
|
||||
Find(&restores).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return restores, nil
|
||||
}
|
||||
|
||||
func (r *RestoreRepository) DeleteByID(id uuid.UUID) error {
|
||||
return storage.GetDb().Delete(&Restore{}, "id = ?", id).Error
|
||||
}
|
||||
@@ -1,19 +1,24 @@
|
||||
package restores
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/disk"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/restores/usecases"
|
||||
"databasus-backend/internal/features/storages"
|
||||
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var restoreRepository = &RestoreRepository{}
|
||||
var restoreRepository = &restores_core.RestoreRepository{}
|
||||
var restoreService = &RestoreService{
|
||||
backups.GetBackupService(),
|
||||
restoreRepository,
|
||||
@@ -26,24 +31,31 @@ var restoreService = &RestoreService{
|
||||
audit_logs.GetAuditLogService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
disk.GetDiskService(),
|
||||
tasks_cancellation.GetTaskCancelManager(),
|
||||
}
|
||||
var restoreController = &RestoreController{
|
||||
restoreService,
|
||||
}
|
||||
|
||||
var restoreBackgroundService = &RestoreBackgroundService{
|
||||
restoreRepository,
|
||||
logger.GetLogger(),
|
||||
}
|
||||
|
||||
func GetRestoreController() *RestoreController {
|
||||
return restoreController
|
||||
}
|
||||
|
||||
func GetRestoreBackgroundService() *RestoreBackgroundService {
|
||||
return restoreBackgroundService
|
||||
}
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
backups.GetBackupService().AddBackupRemoveListener(restoreService)
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
backups.GetBackupService().AddBackupRemoveListener(restoreService)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/restores/enums"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type Restore struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
|
||||
Status enums.RestoreStatus `json:"status" gorm:"column:status;type:text;not null"`
|
||||
|
||||
BackupID uuid.UUID `json:"backupId" gorm:"column:backup_id;type:uuid;not null"`
|
||||
Backup *backups_core.Backup
|
||||
|
||||
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`
|
||||
|
||||
RestoreDurationMs int64 `json:"restoreDurationMs" gorm:"column:restore_duration_ms;default:0"`
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at;default:now()"`
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
package restores
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/features/restores/enums"
|
||||
"databasus-backend/internal/features/restores/models"
|
||||
"databasus-backend/internal/storage"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type RestoreRepository struct{}
|
||||
|
||||
func (r *RestoreRepository) Save(restore *models.Restore) error {
|
||||
db := storage.GetDb()
|
||||
|
||||
isNew := restore.ID == uuid.Nil
|
||||
if isNew {
|
||||
restore.ID = uuid.New()
|
||||
return db.Create(restore).
|
||||
Omit("Backup").
|
||||
Error
|
||||
}
|
||||
|
||||
return db.Save(restore).
|
||||
Omit("Backup").
|
||||
Error
|
||||
}
|
||||
|
||||
func (r *RestoreRepository) FindByBackupID(backupID uuid.UUID) ([]*models.Restore, error) {
|
||||
var restores []*models.Restore
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Backup").
|
||||
Where("backup_id = ?", backupID).
|
||||
Order("created_at DESC").
|
||||
Find(&restores).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return restores, nil
|
||||
}
|
||||
|
||||
func (r *RestoreRepository) FindByID(id uuid.UUID) (*models.Restore, error) {
|
||||
var restore models.Restore
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Backup").
|
||||
Where("id = ?", id).
|
||||
First(&restore).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &restore, nil
|
||||
}
|
||||
|
||||
func (r *RestoreRepository) FindByStatus(status enums.RestoreStatus) ([]*models.Restore, error) {
|
||||
var restores []*models.Restore
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Backup").
|
||||
Where("status = ?", status).
|
||||
Order("created_at DESC").
|
||||
Find(&restores).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return restores, nil
|
||||
}
|
||||
|
||||
func (r *RestoreRepository) DeleteByID(id uuid.UUID) error {
|
||||
return storage.GetDb().Delete(&models.Restore{}, "id = ?", id).Error
|
||||
}
|
||||
85
backend/internal/features/restores/restoring/di.go
Normal file
85
backend/internal/features/restores/restoring/di.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package restoring
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/restores/usecases"
|
||||
"databasus-backend/internal/features/storages"
|
||||
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var restoreRepository = &restores_core.RestoreRepository{}
|
||||
|
||||
var restoreNodesRegistry = &RestoreNodesRegistry{
|
||||
client: cache_utils.GetValkeyClient(),
|
||||
logger: logger.GetLogger(),
|
||||
timeout: cache_utils.DefaultCacheTimeout,
|
||||
pubsubRestores: cache_utils.NewPubSubManager(),
|
||||
pubsubCompletions: cache_utils.NewPubSubManager(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
|
||||
var restoreDatabaseCache = cache_utils.NewCacheUtil[RestoreDatabaseCache](
|
||||
cache_utils.GetValkeyClient(),
|
||||
"restore_db:",
|
||||
)
|
||||
|
||||
var restoreCancelManager = tasks_cancellation.GetTaskCancelManager()
|
||||
|
||||
var restorerNode = &RestorerNode{
|
||||
nodeID: uuid.New(),
|
||||
databaseService: databases.GetDatabaseService(),
|
||||
backupService: backups.GetBackupService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
restoreRepository: restoreRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
restoreNodesRegistry: restoreNodesRegistry,
|
||||
logger: logger.GetLogger(),
|
||||
restoreBackupUsecase: usecases.GetRestoreBackupUsecase(),
|
||||
cacheUtil: restoreDatabaseCache,
|
||||
restoreCancelManager: restoreCancelManager,
|
||||
lastHeartbeat: time.Time{},
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
|
||||
var restoresScheduler = &RestoresScheduler{
|
||||
restoreRepository: restoreRepository,
|
||||
backupService: backups.GetBackupService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
restoreNodesRegistry: restoreNodesRegistry,
|
||||
lastCheckTime: time.Now().UTC(),
|
||||
logger: logger.GetLogger(),
|
||||
restoreToNodeRelations: make(map[uuid.UUID]RestoreToNodeRelation),
|
||||
restorerNode: restorerNode,
|
||||
cacheUtil: restoreDatabaseCache,
|
||||
completionSubscriptionID: uuid.Nil,
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
|
||||
func GetRestoresScheduler() *RestoresScheduler {
|
||||
return restoresScheduler
|
||||
}
|
||||
|
||||
func GetRestorerNode() *RestorerNode {
|
||||
return restorerNode
|
||||
}
|
||||
|
||||
func GetRestoreNodesRegistry() *RestoreNodesRegistry {
|
||||
return restoreNodesRegistry
|
||||
}
|
||||
45
backend/internal/features/restores/restoring/dto.go
Normal file
45
backend/internal/features/restores/restoring/dto.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package restoring
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/features/databases/databases/mariadb"
|
||||
"databasus-backend/internal/features/databases/databases/mongodb"
|
||||
"databasus-backend/internal/features/databases/databases/mysql"
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type RestoreDatabaseCache struct {
|
||||
PostgresqlDatabase *postgresql.PostgresqlDatabase `json:"postgresqlDatabase,omitempty"`
|
||||
MysqlDatabase *mysql.MysqlDatabase `json:"mysqlDatabase,omitempty"`
|
||||
MariadbDatabase *mariadb.MariadbDatabase `json:"mariadbDatabase,omitempty"`
|
||||
MongodbDatabase *mongodb.MongodbDatabase `json:"mongodbDatabase,omitempty"`
|
||||
}
|
||||
|
||||
type RestoreToNodeRelation struct {
|
||||
NodeID uuid.UUID `json:"nodeId"`
|
||||
RestoreIDs []uuid.UUID `json:"restoreIds"`
|
||||
}
|
||||
|
||||
type RestoreNode struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
ThroughputMBs int `json:"throughputMBs"`
|
||||
LastHeartbeat time.Time `json:"lastHeartbeat"`
|
||||
}
|
||||
|
||||
type RestoreNodeStats struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
ActiveRestores int `json:"activeRestores"`
|
||||
}
|
||||
|
||||
type RestoreSubmitMessage struct {
|
||||
NodeID uuid.UUID `json:"nodeId"`
|
||||
RestoreID uuid.UUID `json:"restoreId"`
|
||||
IsCallNotifier bool `json:"isCallNotifier"`
|
||||
}
|
||||
|
||||
type RestoreCompletionMessage struct {
|
||||
NodeID uuid.UUID `json:"nodeId"`
|
||||
RestoreID uuid.UUID `json:"restoreId"`
|
||||
}
|
||||
88
backend/internal/features/restores/restoring/mocks.go
Normal file
88
backend/internal/features/restores/restoring/mocks.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package restoring
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/storages"
|
||||
)
|
||||
|
||||
type MockSuccessRestoreUsecase struct{}
|
||||
|
||||
func (uc *MockSuccessRestoreUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore restores_core.Restore,
|
||||
originalDB *databases.Database,
|
||||
restoringToDB *databases.Database,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
isExcludeExtensions bool,
|
||||
) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type MockFailedRestoreUsecase struct{}
|
||||
|
||||
func (uc *MockFailedRestoreUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore restores_core.Restore,
|
||||
originalDB *databases.Database,
|
||||
restoringToDB *databases.Database,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
isExcludeExtensions bool,
|
||||
) error {
|
||||
return errors.New("restore failed")
|
||||
}
|
||||
|
||||
type MockCaptureCredentialsRestoreUsecase struct {
|
||||
CalledChan chan *databases.Database
|
||||
ShouldSucceed bool
|
||||
}
|
||||
|
||||
func (uc *MockCaptureCredentialsRestoreUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore restores_core.Restore,
|
||||
originalDB *databases.Database,
|
||||
restoringToDB *databases.Database,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
isExcludeExtensions bool,
|
||||
) error {
|
||||
uc.CalledChan <- restoringToDB
|
||||
|
||||
if uc.ShouldSucceed {
|
||||
return nil
|
||||
}
|
||||
return errors.New("mock restore failed")
|
||||
}
|
||||
|
||||
type MockBlockingRestoreUsecase struct {
|
||||
StartedChan chan bool
|
||||
}
|
||||
|
||||
func (uc *MockBlockingRestoreUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore restores_core.Restore,
|
||||
originalDB *databases.Database,
|
||||
restoringToDB *databases.Database,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
isExcludeExtensions bool,
|
||||
) error {
|
||||
if uc.StartedChan != nil {
|
||||
uc.StartedChan <- true
|
||||
}
|
||||
|
||||
<-ctx.Done()
|
||||
|
||||
return ctx.Err()
|
||||
}
|
||||
649
backend/internal/features/restores/restoring/registry.go
Normal file
649
backend/internal/features/restores/restoring/registry.go
Normal file
@@ -0,0 +1,649 @@
|
||||
package restoring
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/valkey-io/valkey-go"
|
||||
)
|
||||
|
||||
// Valkey key layout and pub/sub channel names used by the registry.
// Node info lives at            restore:node:<id>:info
// Active-restores counter is at restore:node:<id>:active_restores
const (
	nodeInfoKeyPrefix        = "restore:node:"
	nodeInfoKeySuffix        = ":info"
	nodeActiveRestoresPrefix = "restore:node:"
	nodeActiveRestoresSuffix = ":active_restores"
	restoreSubmitChannel     = "restore:submit"
	restoreCompletionChannel = "restore:completion"
)

// Timing parameters.
const (
	// deadNodeThreshold is the maximum heartbeat age before a node is treated
	// as dead by the listing and stats functions.
	deadNodeThreshold = 2 * time.Minute
	// cleanupTickerInterval is how often Run triggers dead-node cleanup.
	cleanupTickerInterval = 1 * time.Second
)
|
||||
|
||||
// RestoreNodesRegistry helps to sync restores scheduler and restore nodes.
|
||||
//
|
||||
// Features:
|
||||
// - Track node availability and load level
|
||||
// - Assign from scheduler to node restores needed to be processed
|
||||
// - Notify scheduler from node about restore completion
|
||||
//
|
||||
// Important things to remember:
|
||||
// - Nodes without heartbeat for more than 2 minutes are not included
|
||||
// in available nodes list and stats
|
||||
//
|
||||
// Cleanup dead nodes performed on 2 levels:
|
||||
// - List and stats functions do not return dead nodes
|
||||
// - Periodically dead nodes are cleaned up in cache (to not
|
||||
// accumulate too many dead nodes in cache)
|
||||
type RestoreNodesRegistry struct {
|
||||
client valkey.Client
|
||||
logger *slog.Logger
|
||||
timeout time.Duration
|
||||
pubsubRestores *cache_utils.PubSubManager
|
||||
pubsubCompletions *cache_utils.PubSubManager
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) Run(ctx context.Context) {
|
||||
wasAlreadyRun := r.hasRun.Load()
|
||||
|
||||
r.runOnce.Do(func() {
|
||||
r.hasRun.Store(true)
|
||||
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(cleanupTickerInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := r.cleanupDeadNodes(); err != nil {
|
||||
r.logger.Error("Failed to cleanup dead nodes", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", r))
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) GetAvailableNodes() ([]RestoreNode, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
var allKeys []string
|
||||
cursor := uint64(0)
|
||||
pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix
|
||||
|
||||
for {
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return nil, fmt.Errorf("failed to scan node keys: %w", result.Error())
|
||||
}
|
||||
|
||||
scanResult, err := result.AsScanEntry()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse scan result: %w", err)
|
||||
}
|
||||
|
||||
allKeys = append(allKeys, scanResult.Elements...)
|
||||
|
||||
cursor = scanResult.Cursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(allKeys) == 0 {
|
||||
return []RestoreNode{}, nil
|
||||
}
|
||||
|
||||
keyDataMap, err := r.pipelineGetKeys(allKeys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pipeline get node keys: %w", err)
|
||||
}
|
||||
|
||||
threshold := time.Now().UTC().Add(-deadNodeThreshold)
|
||||
var nodes []RestoreNode
|
||||
|
||||
for key, data := range keyDataMap {
|
||||
// Skip if the key doesn't exist (data is empty)
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var node RestoreNode
|
||||
if err := json.Unmarshal(data, &node); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal node data", "key", key, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip nodes with zero/uninitialized heartbeat
|
||||
if node.LastHeartbeat.IsZero() {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.LastHeartbeat.Before(threshold) {
|
||||
continue
|
||||
}
|
||||
|
||||
nodes = append(nodes, node)
|
||||
}
|
||||
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) GetRestoreNodesStats() ([]RestoreNodeStats, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
var allKeys []string
|
||||
cursor := uint64(0)
|
||||
pattern := nodeActiveRestoresPrefix + "*" + nodeActiveRestoresSuffix
|
||||
|
||||
for {
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(100).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return nil, fmt.Errorf("failed to scan active restores keys: %w", result.Error())
|
||||
}
|
||||
|
||||
scanResult, err := result.AsScanEntry()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse scan result: %w", err)
|
||||
}
|
||||
|
||||
allKeys = append(allKeys, scanResult.Elements...)
|
||||
|
||||
cursor = scanResult.Cursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(allKeys) == 0 {
|
||||
return []RestoreNodeStats{}, nil
|
||||
}
|
||||
|
||||
keyDataMap, err := r.pipelineGetKeys(allKeys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pipeline get active restores keys: %w", err)
|
||||
}
|
||||
|
||||
var nodeInfoKeys []string
|
||||
nodeIDToStatsKey := make(map[string]string)
|
||||
for key := range keyDataMap {
|
||||
nodeID := r.extractNodeIDFromKey(key, nodeActiveRestoresPrefix, nodeActiveRestoresSuffix)
|
||||
nodeIDStr := nodeID.String()
|
||||
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeIDStr, nodeInfoKeySuffix)
|
||||
nodeInfoKeys = append(nodeInfoKeys, infoKey)
|
||||
nodeIDToStatsKey[infoKey] = key
|
||||
}
|
||||
|
||||
nodeInfoMap, err := r.pipelineGetKeys(nodeInfoKeys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pipeline get node info keys: %w", err)
|
||||
}
|
||||
|
||||
threshold := time.Now().UTC().Add(-deadNodeThreshold)
|
||||
var stats []RestoreNodeStats
|
||||
for infoKey, nodeData := range nodeInfoMap {
|
||||
// Skip if the info key doesn't exist (nodeData is empty)
|
||||
if len(nodeData) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var node RestoreNode
|
||||
if err := json.Unmarshal(nodeData, &node); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal node data", "key", infoKey, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip nodes with zero/uninitialized heartbeat
|
||||
if node.LastHeartbeat.IsZero() {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.LastHeartbeat.Before(threshold) {
|
||||
continue
|
||||
}
|
||||
|
||||
statsKey := nodeIDToStatsKey[infoKey]
|
||||
tasksData := keyDataMap[statsKey]
|
||||
count, err := r.parseIntFromBytes(tasksData)
|
||||
if err != nil {
|
||||
r.logger.Warn("Failed to parse active restores count", "key", statsKey, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
stat := RestoreNodeStats{
|
||||
ID: node.ID,
|
||||
ActiveRestores: int(count),
|
||||
}
|
||||
stats = append(stats, stat)
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) IncrementRestoresInProgress(nodeID uuid.UUID) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
key := fmt.Sprintf(
|
||||
"%s%s%s",
|
||||
nodeActiveRestoresPrefix,
|
||||
nodeID.String(),
|
||||
nodeActiveRestoresSuffix,
|
||||
)
|
||||
result := r.client.Do(ctx, r.client.B().Incr().Key(key).Build())
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to increment restores in progress for node %s: %w",
|
||||
nodeID,
|
||||
result.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) DecrementRestoresInProgress(nodeID uuid.UUID) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
key := fmt.Sprintf(
|
||||
"%s%s%s",
|
||||
nodeActiveRestoresPrefix,
|
||||
nodeID.String(),
|
||||
nodeActiveRestoresSuffix,
|
||||
)
|
||||
result := r.client.Do(ctx, r.client.B().Decr().Key(key).Build())
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to decrement restores in progress for node %s: %w",
|
||||
nodeID,
|
||||
result.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
newValue, err := result.AsInt64()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse decremented value for node %s: %w", nodeID, err)
|
||||
}
|
||||
|
||||
if newValue < 0 {
|
||||
setCtx, setCancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
r.client.Do(setCtx, r.client.B().Set().Key(key).Value("0").Build())
|
||||
setCancel()
|
||||
r.logger.Warn("Active restores counter went below 0, reset to 0", "nodeID", nodeID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) HearthbeatNodeInRegistry(
|
||||
now time.Time,
|
||||
restoreNode RestoreNode,
|
||||
) error {
|
||||
if now.IsZero() {
|
||||
return fmt.Errorf("cannot register node with zero heartbeat timestamp")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
restoreNode.LastHeartbeat = now
|
||||
|
||||
data, err := json.Marshal(restoreNode)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal restore node: %w", err)
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, restoreNode.ID.String(), nodeInfoKeySuffix)
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Set().Key(key).Value(string(data)).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf("failed to register node %s: %w", restoreNode.ID, result.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) UnregisterNodeFromRegistry(restoreNode RestoreNode) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, restoreNode.ID.String(), nodeInfoKeySuffix)
|
||||
counterKey := fmt.Sprintf(
|
||||
"%s%s%s",
|
||||
nodeActiveRestoresPrefix,
|
||||
restoreNode.ID.String(),
|
||||
nodeActiveRestoresSuffix,
|
||||
)
|
||||
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Del().Key(infoKey, counterKey).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf("failed to unregister node %s: %w", restoreNode.ID, result.Error())
|
||||
}
|
||||
|
||||
r.logger.Info("Unregistered node from registry", "nodeID", restoreNode.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) AssignRestoreToNode(
|
||||
targetNodeID uuid.UUID,
|
||||
restoreID uuid.UUID,
|
||||
isCallNotifier bool,
|
||||
) error {
|
||||
ctx := context.Background()
|
||||
|
||||
message := RestoreSubmitMessage{
|
||||
NodeID: targetNodeID,
|
||||
RestoreID: restoreID,
|
||||
IsCallNotifier: isCallNotifier,
|
||||
}
|
||||
|
||||
messageJSON, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal restore submit message: %w", err)
|
||||
}
|
||||
|
||||
err = r.pubsubRestores.Publish(ctx, restoreSubmitChannel, string(messageJSON))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to publish restore submit message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) SubscribeNodeForRestoresAssignment(
|
||||
nodeID uuid.UUID,
|
||||
handler func(restoreID uuid.UUID, isCallNotifier bool),
|
||||
) error {
|
||||
ctx := context.Background()
|
||||
|
||||
wrappedHandler := func(message string) {
|
||||
var msg RestoreSubmitMessage
|
||||
if err := json.Unmarshal([]byte(message), &msg); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal restore submit message", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if msg.NodeID != nodeID {
|
||||
return
|
||||
}
|
||||
|
||||
handler(msg.RestoreID, msg.IsCallNotifier)
|
||||
}
|
||||
|
||||
err := r.pubsubRestores.Subscribe(ctx, restoreSubmitChannel, wrappedHandler)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to subscribe to restore submit channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Subscribed to restore submit channel", "nodeID", nodeID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) UnsubscribeNodeForRestoresAssignments() error {
|
||||
err := r.pubsubRestores.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unsubscribe from restore submit channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Unsubscribed from restore submit channel")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) PublishRestoreCompletion(
|
||||
nodeID uuid.UUID,
|
||||
restoreID uuid.UUID,
|
||||
) error {
|
||||
ctx := context.Background()
|
||||
|
||||
message := RestoreCompletionMessage{
|
||||
NodeID: nodeID,
|
||||
RestoreID: restoreID,
|
||||
}
|
||||
|
||||
messageJSON, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal restore completion message: %w", err)
|
||||
}
|
||||
|
||||
err = r.pubsubCompletions.Publish(ctx, restoreCompletionChannel, string(messageJSON))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to publish restore completion message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) SubscribeForRestoresCompletions(
|
||||
handler func(nodeID uuid.UUID, restoreID uuid.UUID),
|
||||
) error {
|
||||
ctx := context.Background()
|
||||
|
||||
wrappedHandler := func(message string) {
|
||||
var msg RestoreCompletionMessage
|
||||
if err := json.Unmarshal([]byte(message), &msg); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal restore completion message", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
handler(msg.NodeID, msg.RestoreID)
|
||||
}
|
||||
|
||||
err := r.pubsubCompletions.Subscribe(ctx, restoreCompletionChannel, wrappedHandler)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to subscribe to restore completion channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Subscribed to restore completion channel")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) UnsubscribeForRestoresCompletions() error {
|
||||
err := r.pubsubCompletions.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unsubscribe from restore completion channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Unsubscribed from restore completion channel")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uuid.UUID {
|
||||
nodeIDStr := strings.TrimPrefix(key, prefix)
|
||||
nodeIDStr = strings.TrimSuffix(nodeIDStr, suffix)
|
||||
|
||||
nodeID, err := uuid.Parse(nodeIDStr)
|
||||
if err != nil {
|
||||
r.logger.Warn("Failed to parse node ID from key", "key", key, "error", err)
|
||||
return uuid.Nil
|
||||
}
|
||||
|
||||
return nodeID
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, error) {
|
||||
if len(keys) == 0 {
|
||||
return make(map[string][]byte), nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
commands := make([]valkey.Completed, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
commands = append(commands, r.client.B().Get().Key(key).Build())
|
||||
}
|
||||
|
||||
results := r.client.DoMulti(ctx, commands...)
|
||||
|
||||
keyDataMap := make(map[string][]byte, len(keys))
|
||||
for i, result := range results {
|
||||
if result.Error() != nil {
|
||||
r.logger.Warn("Failed to get key in pipeline", "key", keys[i], "error", result.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
data, err := result.AsBytes()
|
||||
if err != nil {
|
||||
r.logger.Warn("Failed to parse key data in pipeline", "key", keys[i], "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
keyDataMap[keys[i]] = data
|
||||
}
|
||||
|
||||
return keyDataMap, nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
|
||||
str := string(data)
|
||||
var count int64
|
||||
_, err := fmt.Sscanf(str, "%d", &count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to parse integer from bytes: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *RestoreNodesRegistry) cleanupDeadNodes() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
var allKeys []string
|
||||
cursor := uint64(0)
|
||||
pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix
|
||||
|
||||
for {
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf("failed to scan node keys: %w", result.Error())
|
||||
}
|
||||
|
||||
scanResult, err := result.AsScanEntry()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse scan result: %w", err)
|
||||
}
|
||||
|
||||
allKeys = append(allKeys, scanResult.Elements...)
|
||||
|
||||
cursor = scanResult.Cursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(allKeys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
keyDataMap, err := r.pipelineGetKeys(allKeys)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to pipeline get node keys: %w", err)
|
||||
}
|
||||
|
||||
threshold := time.Now().UTC().Add(-deadNodeThreshold)
|
||||
var deadNodeKeys []string
|
||||
|
||||
for key, data := range keyDataMap {
|
||||
// Skip if the key doesn't exist (data is empty)
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var node RestoreNode
|
||||
if err := json.Unmarshal(data, &node); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal node data during cleanup", "key", key, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip nodes with zero/uninitialized heartbeat
|
||||
if node.LastHeartbeat.IsZero() {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.LastHeartbeat.Before(threshold) {
|
||||
nodeID := node.ID.String()
|
||||
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeID, nodeInfoKeySuffix)
|
||||
statsKey := fmt.Sprintf(
|
||||
"%s%s%s",
|
||||
nodeActiveRestoresPrefix,
|
||||
nodeID,
|
||||
nodeActiveRestoresSuffix,
|
||||
)
|
||||
|
||||
deadNodeKeys = append(deadNodeKeys, infoKey, statsKey)
|
||||
r.logger.Info(
|
||||
"Marking node for cleanup",
|
||||
"nodeID", nodeID,
|
||||
"lastHeartbeat", node.LastHeartbeat,
|
||||
"threshold", threshold,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if len(deadNodeKeys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
delCtx, delCancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer delCancel()
|
||||
|
||||
result := r.client.Do(
|
||||
delCtx,
|
||||
r.client.B().Del().Key(deadNodeKeys...).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf("failed to delete dead node keys: %w", result.Error())
|
||||
}
|
||||
|
||||
deletedCount, err := result.AsInt64()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse deleted count: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Cleaned up dead nodes", "deletedKeysCount", deletedCount)
|
||||
return nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
310
backend/internal/features/restores/restoring/restorer.go
Normal file
310
backend/internal/features/restores/restoring/restorer.go
Normal file
@@ -0,0 +1,310 @@
|
||||
package restoring
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/storages"
|
||||
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
)
|
||||
|
||||
const (
	// heartbeatTickerInterval is the period of the ticker that drives
	// sendHeartbeat inside RestorerNode.Run.
	heartbeatTickerInterval = 15 * time.Second

	// restorerHealthcheckThreshold is the maximum age of the last heartbeat
	// for which IsRestorerRunning still reports true.
	restorerHealthcheckThreshold = 5 * time.Minute
)
|
||||
|
||||
type RestorerNode struct {
|
||||
nodeID uuid.UUID
|
||||
|
||||
databaseService *databases.DatabaseService
|
||||
backupService *backups.BackupService
|
||||
fieldEncryptor util_encryption.FieldEncryptor
|
||||
restoreRepository *restores_core.RestoreRepository
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
storageService *storages.StorageService
|
||||
restoreNodesRegistry *RestoreNodesRegistry
|
||||
logger *slog.Logger
|
||||
restoreBackupUsecase restores_core.RestoreBackupUsecase
|
||||
cacheUtil *cache_utils.CacheUtil[RestoreDatabaseCache]
|
||||
restoreCancelManager *tasks_cancellation.TaskCancelManager
|
||||
|
||||
lastHeartbeat time.Time
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (n *RestorerNode) Run(ctx context.Context) {
|
||||
wasAlreadyRun := n.hasRun.Load()
|
||||
|
||||
n.runOnce.Do(func() {
|
||||
n.hasRun.Store(true)
|
||||
|
||||
n.lastHeartbeat = time.Now().UTC()
|
||||
|
||||
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
|
||||
|
||||
restoreNode := RestoreNode{
|
||||
ID: n.nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
}
|
||||
|
||||
if err := n.restoreNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), restoreNode); err != nil {
|
||||
n.logger.Error("Failed to register node in registry", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
restoreHandler := func(restoreID uuid.UUID, isCallNotifier bool) {
|
||||
n.MakeRestore(restoreID)
|
||||
if err := n.restoreNodesRegistry.PublishRestoreCompletion(n.nodeID, restoreID); err != nil {
|
||||
n.logger.Error(
|
||||
"Failed to publish restore completion",
|
||||
"error",
|
||||
err,
|
||||
"restoreID",
|
||||
restoreID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
err := n.restoreNodesRegistry.SubscribeNodeForRestoresAssignment(
|
||||
n.nodeID,
|
||||
restoreHandler,
|
||||
)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to subscribe to restore assignments", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := n.restoreNodesRegistry.UnsubscribeNodeForRestoresAssignments(); err != nil {
|
||||
n.logger.Error("Failed to unsubscribe from restore assignments", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ticker := time.NewTicker(heartbeatTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
n.logger.Info("Restore node started", "nodeID", n.nodeID, "throughput", throughputMBs)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
|
||||
|
||||
if err := n.restoreNodesRegistry.UnregisterNodeFromRegistry(restoreNode); err != nil {
|
||||
n.logger.Error("Failed to unregister node from registry", "error", err)
|
||||
}
|
||||
|
||||
return
|
||||
case <-ticker.C:
|
||||
n.sendHeartbeat(&restoreNode)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", n))
|
||||
}
|
||||
}
|
||||
|
||||
func (n *RestorerNode) IsRestorerRunning() bool {
|
||||
return n.lastHeartbeat.After(time.Now().UTC().Add(-restorerHealthcheckThreshold))
|
||||
}
|
||||
|
||||
func (n *RestorerNode) MakeRestore(restoreID uuid.UUID) {
|
||||
// Get and delete cached DB credentials atomically
|
||||
dbCache := n.cacheUtil.GetAndDelete(restoreID.String())
|
||||
|
||||
if dbCache == nil {
|
||||
// Cache miss - fail immediately
|
||||
restore, err := n.restoreRepository.FindByID(restoreID)
|
||||
if err != nil {
|
||||
n.logger.Error(
|
||||
"Failed to get restore by ID after cache miss",
|
||||
"restoreId",
|
||||
restoreID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
errMsg := "Database credentials expired or missing from cache (most likely due to instance restart)"
|
||||
restore.FailMessage = &errMsg
|
||||
restore.Status = restores_core.RestoreStatusFailed
|
||||
|
||||
if err := n.restoreRepository.Save(restore); err != nil {
|
||||
n.logger.Error("Failed to save restore after cache miss", "error", err)
|
||||
}
|
||||
|
||||
n.logger.Error("Restore failed: cache miss", "restoreId", restoreID)
|
||||
return
|
||||
}
|
||||
|
||||
restore, err := n.restoreRepository.FindByID(restoreID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get restore by ID", "restoreId", restoreID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
backup, err := n.backupService.GetBackup(restore.BackupID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get backup by ID", "backupId", restore.BackupID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
databaseID := backup.DatabaseID
|
||||
|
||||
database, err := n.databaseService.GetDatabaseByID(databaseID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get database by ID", "databaseId", databaseID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
backupConfig, err := n.backupConfigService.GetBackupConfigByDbId(databaseID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if backupConfig.StorageID == nil {
|
||||
n.logger.Error("Backup config storage ID is not defined")
|
||||
return
|
||||
}
|
||||
|
||||
storage, err := n.storageService.GetStorageByID(*backupConfig.StorageID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get storage by ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now().UTC()
|
||||
|
||||
// Create cancellable context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
n.restoreCancelManager.RegisterTask(restore.ID, cancel)
|
||||
defer n.restoreCancelManager.UnregisterTask(restore.ID)
|
||||
|
||||
// Create restoring database from cached credentials
|
||||
restoringToDB := &databases.Database{
|
||||
Type: database.Type,
|
||||
Postgresql: dbCache.PostgresqlDatabase,
|
||||
Mysql: dbCache.MysqlDatabase,
|
||||
Mariadb: dbCache.MariadbDatabase,
|
||||
Mongodb: dbCache.MongodbDatabase,
|
||||
}
|
||||
|
||||
if err := restoringToDB.PopulateDbData(n.logger, n.fieldEncryptor); err != nil {
|
||||
errMsg := fmt.Sprintf("failed to auto-detect database data: %v", err)
|
||||
restore.FailMessage = &errMsg
|
||||
restore.Status = restores_core.RestoreStatusFailed
|
||||
restore.RestoreDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
if err := n.restoreRepository.Save(restore); err != nil {
|
||||
n.logger.Error("Failed to save restore", "error", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
isExcludeExtensions := false
|
||||
if dbCache.PostgresqlDatabase != nil {
|
||||
isExcludeExtensions = dbCache.PostgresqlDatabase.IsExcludeExtensions
|
||||
}
|
||||
|
||||
err = n.restoreBackupUsecase.Execute(
|
||||
ctx,
|
||||
backupConfig,
|
||||
*restore,
|
||||
database,
|
||||
restoringToDB,
|
||||
backup,
|
||||
storage,
|
||||
isExcludeExtensions,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
|
||||
// Check if restore was cancelled
|
||||
isCancelled := strings.Contains(errMsg, "restore cancelled") ||
|
||||
strings.Contains(errMsg, "context canceled") ||
|
||||
errors.Is(err, context.Canceled)
|
||||
isShutdown := strings.Contains(errMsg, "shutdown")
|
||||
|
||||
if isCancelled && !isShutdown {
|
||||
n.logger.Warn("Restore was cancelled by user or system",
|
||||
"restoreId", restore.ID,
|
||||
"isCancelled", isCancelled,
|
||||
"isShutdown", isShutdown,
|
||||
)
|
||||
|
||||
restore.Status = restores_core.RestoreStatusCanceled
|
||||
restore.RestoreDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
if err := n.restoreRepository.Save(restore); err != nil {
|
||||
n.logger.Error("Failed to save cancelled restore", "error", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
n.logger.Error("Restore execution failed",
|
||||
"restoreId", restore.ID,
|
||||
"backupId", backup.ID,
|
||||
"databaseId", databaseID,
|
||||
"databaseType", database.Type,
|
||||
"storageId", storage.ID,
|
||||
"storageType", storage.Type,
|
||||
"error", err,
|
||||
"errorMessage", errMsg,
|
||||
)
|
||||
|
||||
restore.FailMessage = &errMsg
|
||||
restore.Status = restores_core.RestoreStatusFailed
|
||||
restore.RestoreDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
if err := n.restoreRepository.Save(restore); err != nil {
|
||||
n.logger.Error("Failed to save restore", "error", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
restore.Status = restores_core.RestoreStatusCompleted
|
||||
restore.RestoreDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
if err := n.restoreRepository.Save(restore); err != nil {
|
||||
n.logger.Error("Failed to save restore", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
n.logger.Info(
|
||||
"Restore completed successfully",
|
||||
"restoreId", restore.ID,
|
||||
"backupId", backup.ID,
|
||||
"durationMs", restore.RestoreDurationMs,
|
||||
)
|
||||
}
|
||||
|
||||
func (n *RestorerNode) sendHeartbeat(restoreNode *RestoreNode) {
|
||||
n.lastHeartbeat = time.Now().UTC()
|
||||
if err := n.restoreNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *restoreNode); err != nil {
|
||||
n.logger.Error("Failed to send heartbeat", "error", err)
|
||||
}
|
||||
}
|
||||
164
backend/internal/features/restores/restoring/restorer_test.go
Normal file
164
backend/internal/features/restores/restoring/restorer_test.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package restoring
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
)
|
||||
|
||||
func Test_MakeRestore_WhenCacheMissed_RestoreFails(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
|
||||
defer func() {
|
||||
backupRepo := backups_core.BackupRepository{}
|
||||
backupsList, _ := backupRepo.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backupsList {
|
||||
backupRepo.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
restoreRepo := restores_core.RestoreRepository{}
|
||||
restoresInProgress, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
|
||||
for _, restore := range restoresInProgress {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
restoresFailed, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusFailed)
|
||||
for _, restore := range restoresFailed {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
||||
|
||||
// Create restore but DON'T cache DB credentials
|
||||
// Also don't set embedded DB fields to avoid schema issues
|
||||
restore := &restores_core.Restore{
|
||||
BackupID: backup.ID,
|
||||
Status: restores_core.RestoreStatusInProgress,
|
||||
}
|
||||
err := restoreRepository.Save(restore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create restorer and execute restore (should fail due to cache miss)
|
||||
restorerNode := CreateTestRestorerNode()
|
||||
restorerNode.MakeRestore(restore.ID)
|
||||
|
||||
// Verify restore failed with appropriate error message
|
||||
updatedRestore, err := restoreRepository.FindByID(restore.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, restores_core.RestoreStatusFailed, updatedRestore.Status)
|
||||
assert.NotNil(t, updatedRestore.FailMessage)
|
||||
assert.Contains(
|
||||
t,
|
||||
*updatedRestore.FailMessage,
|
||||
"Database credentials expired or missing from cache",
|
||||
)
|
||||
}
|
||||
|
||||
func Test_MakeRestore_WhenTaskStarts_CacheDeletedImmediately(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
|
||||
defer func() {
|
||||
backupRepo := backups_core.BackupRepository{}
|
||||
backupsList, _ := backupRepo.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backupsList {
|
||||
backupRepo.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
restoreRepo := restores_core.RestoreRepository{}
|
||||
restoresInProgress, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
|
||||
for _, restore := range restoresInProgress {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
restoresFailed, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusFailed)
|
||||
for _, restore := range restoresFailed {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
restoresCompleted, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusCompleted)
|
||||
for _, restore := range restoresCompleted {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
||||
|
||||
// Create restore with cached DB credentials
|
||||
// Don't set embedded DB fields in the restore model itself
|
||||
restore := &restores_core.Restore{
|
||||
BackupID: backup.ID,
|
||||
Status: restores_core.RestoreStatusInProgress,
|
||||
}
|
||||
err := restoreRepository.Save(restore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Cache DB credentials separately
|
||||
dbCache := &RestoreDatabaseCache{
|
||||
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "test",
|
||||
Password: "test",
|
||||
Database: stringPtr("testdb"),
|
||||
Version: "16",
|
||||
},
|
||||
}
|
||||
restoreDatabaseCache.SetWithExpiration(restore.ID.String(), dbCache, 1*time.Hour)
|
||||
|
||||
// Verify cache exists before restore starts
|
||||
cachedDB := restoreDatabaseCache.Get(restore.ID.String())
|
||||
assert.NotNil(t, cachedDB, "Cache should exist before restore starts")
|
||||
|
||||
// Start restore (this will call GetAndDelete)
|
||||
restorerNode := CreateTestRestorerNode()
|
||||
restorerNode.MakeRestore(restore.ID)
|
||||
|
||||
// Verify cache was deleted immediately
|
||||
cachedDBAfter := restoreDatabaseCache.Get(restore.ID.String())
|
||||
assert.Nil(t, cachedDBAfter, "Cache should be deleted immediately when task starts")
|
||||
}
|
||||
410
backend/internal/features/restores/restoring/scheduler.go
Normal file
410
backend/internal/features/restores/restoring/scheduler.go
Normal file
@@ -0,0 +1,410 @@
|
||||
package restoring
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/storages"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
)
|
||||
|
||||
const (
	// schedulerStartupDelay is how long Run waits before doing any work when
	// the application runs in many-nodes mode.
	schedulerStartupDelay = 1 * time.Minute

	// schedulerTickerInterval is the period of the dead-node check in Run.
	schedulerTickerInterval = 1 * time.Minute

	// schedulerHealthcheckThreshold is the maximum age of the last check for
	// which IsSchedulerRunning still reports true.
	schedulerHealthcheckThreshold = 5 * time.Minute
)
|
||||
|
||||
type RestoresScheduler struct {
|
||||
restoreRepository *restores_core.RestoreRepository
|
||||
backupService *backups.BackupService
|
||||
storageService *storages.StorageService
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
restoreNodesRegistry *RestoreNodesRegistry
|
||||
lastCheckTime time.Time
|
||||
logger *slog.Logger
|
||||
restoreToNodeRelations map[uuid.UUID]RestoreToNodeRelation
|
||||
restorerNode *RestorerNode
|
||||
cacheUtil *cache_utils.CacheUtil[RestoreDatabaseCache]
|
||||
completionSubscriptionID uuid.UUID
|
||||
|
||||
runOnce sync.Once
|
||||
hasRun atomic.Bool
|
||||
}
|
||||
|
||||
func (s *RestoresScheduler) Run(ctx context.Context) {
|
||||
wasAlreadyRun := s.hasRun.Load()
|
||||
|
||||
s.runOnce.Do(func() {
|
||||
s.hasRun.Store(true)
|
||||
|
||||
s.lastCheckTime = time.Now().UTC()
|
||||
|
||||
if config.GetEnv().IsManyNodesMode {
|
||||
// wait other nodes to start
|
||||
time.Sleep(schedulerStartupDelay)
|
||||
}
|
||||
|
||||
if err := s.failRestoresInProgress(); err != nil {
|
||||
s.logger.Error("Failed to fail restores in progress", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err := s.restoreNodesRegistry.SubscribeForRestoresCompletions(s.onRestoreCompleted)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to subscribe to restore completions", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := s.restoreNodesRegistry.UnsubscribeForRestoresCompletions(); err != nil {
|
||||
s.logger.Error("Failed to unsubscribe from restore completions", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(schedulerTickerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.checkDeadNodesAndFailRestores(); err != nil {
|
||||
s.logger.Error("Failed to check dead nodes and fail restores", "error", err)
|
||||
}
|
||||
|
||||
s.lastCheckTime = time.Now().UTC()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if wasAlreadyRun {
|
||||
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RestoresScheduler) IsSchedulerRunning() bool {
|
||||
return s.lastCheckTime.After(time.Now().UTC().Add(-schedulerHealthcheckThreshold))
|
||||
}
|
||||
|
||||
func (s *RestoresScheduler) failRestoresInProgress() error {
|
||||
restoresInProgress, err := s.restoreRepository.FindByStatus(
|
||||
restores_core.RestoreStatusInProgress,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, restore := range restoresInProgress {
|
||||
failMessage := "Restore failed due to application restart"
|
||||
restore.FailMessage = &failMessage
|
||||
restore.Status = restores_core.RestoreStatusFailed
|
||||
|
||||
if err := s.restoreRepository.Save(restore); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RestoresScheduler) StartRestore(restoreID uuid.UUID, dbCache *RestoreDatabaseCache) error {
|
||||
// If dbCache not provided, try to fetch from DB (for backward compatibility/testing)
|
||||
if dbCache == nil {
|
||||
restore, err := s.restoreRepository.FindByID(restoreID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to find restore by ID",
|
||||
"restoreId",
|
||||
restoreID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// Create cache DTO from restore (may be nil if not in DB)
|
||||
dbCache = &RestoreDatabaseCache{
|
||||
PostgresqlDatabase: restore.PostgresqlDatabase,
|
||||
MysqlDatabase: restore.MysqlDatabase,
|
||||
MariadbDatabase: restore.MariadbDatabase,
|
||||
MongodbDatabase: restore.MongodbDatabase,
|
||||
}
|
||||
}
|
||||
|
||||
// Cache database credentials with 1-hour expiration
|
||||
s.cacheUtil.SetWithExpiration(restoreID.String(), dbCache, 1*time.Hour)
|
||||
|
||||
leastBusyNodeID, err := s.calculateLeastBusyNode()
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to calculate least busy node",
|
||||
"restoreId",
|
||||
restoreID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.restoreNodesRegistry.IncrementRestoresInProgress(*leastBusyNodeID); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to increment restores in progress",
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
"restoreId",
|
||||
restoreID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.restoreNodesRegistry.AssignRestoreToNode(*leastBusyNodeID, restoreID, false); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to submit restore",
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
"restoreId",
|
||||
restoreID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
if decrementErr := s.restoreNodesRegistry.DecrementRestoresInProgress(*leastBusyNodeID); decrementErr != nil {
|
||||
s.logger.Error(
|
||||
"Failed to decrement restores in progress after submit failure",
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
"error",
|
||||
decrementErr,
|
||||
)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if relation, exists := s.restoreToNodeRelations[*leastBusyNodeID]; exists {
|
||||
relation.RestoreIDs = append(relation.RestoreIDs, restoreID)
|
||||
s.restoreToNodeRelations[*leastBusyNodeID] = relation
|
||||
} else {
|
||||
s.restoreToNodeRelations[*leastBusyNodeID] = RestoreToNodeRelation{
|
||||
NodeID: *leastBusyNodeID,
|
||||
RestoreIDs: []uuid.UUID{restoreID},
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Info(
|
||||
"Successfully triggered restore",
|
||||
"restoreId",
|
||||
restoreID,
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RestoresScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
|
||||
nodes, err := s.restoreNodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get available nodes: %w", err)
|
||||
}
|
||||
|
||||
if len(nodes) == 0 {
|
||||
return nil, fmt.Errorf("no nodes available")
|
||||
}
|
||||
|
||||
stats, err := s.restoreNodesRegistry.GetRestoreNodesStats()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get restore nodes stats: %w", err)
|
||||
}
|
||||
|
||||
statsMap := make(map[uuid.UUID]int)
|
||||
for _, stat := range stats {
|
||||
statsMap[stat.ID] = stat.ActiveRestores
|
||||
}
|
||||
|
||||
var bestNode *RestoreNode
|
||||
var bestScore float64 = -1
|
||||
|
||||
for i := range nodes {
|
||||
node := &nodes[i]
|
||||
|
||||
activeRestores := statsMap[node.ID]
|
||||
|
||||
var score float64
|
||||
if node.ThroughputMBs > 0 {
|
||||
score = float64(activeRestores) / float64(node.ThroughputMBs)
|
||||
} else {
|
||||
score = float64(activeRestores) * 1000
|
||||
}
|
||||
|
||||
if bestNode == nil || score < bestScore {
|
||||
bestNode = node
|
||||
bestScore = score
|
||||
}
|
||||
}
|
||||
|
||||
if bestNode == nil {
|
||||
return nil, fmt.Errorf("no suitable nodes available")
|
||||
}
|
||||
|
||||
return &bestNode.ID, nil
|
||||
}
|
||||
|
||||
func (s *RestoresScheduler) onRestoreCompleted(nodeID uuid.UUID, restoreID uuid.UUID) {
|
||||
// Verify this task is actually a restore (registry contains multiple task types)
|
||||
_, err := s.restoreRepository.FindByID(restoreID)
|
||||
if err != nil {
|
||||
// Not a restore task, ignore it
|
||||
return
|
||||
}
|
||||
|
||||
relation, exists := s.restoreToNodeRelations[nodeID]
|
||||
if !exists {
|
||||
s.logger.Warn(
|
||||
"Received completion for unknown node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"restoreId",
|
||||
restoreID,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
newRestoreIDs := make([]uuid.UUID, 0)
|
||||
found := false
|
||||
for _, id := range relation.RestoreIDs {
|
||||
if id == restoreID {
|
||||
found = true
|
||||
continue
|
||||
}
|
||||
newRestoreIDs = append(newRestoreIDs, id)
|
||||
}
|
||||
|
||||
if !found {
|
||||
s.logger.Warn(
|
||||
"Restore not found in node's restore list",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"restoreId",
|
||||
restoreID,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if len(newRestoreIDs) == 0 {
|
||||
delete(s.restoreToNodeRelations, nodeID)
|
||||
} else {
|
||||
relation.RestoreIDs = newRestoreIDs
|
||||
s.restoreToNodeRelations[nodeID] = relation
|
||||
}
|
||||
|
||||
if err := s.restoreNodesRegistry.DecrementRestoresInProgress(nodeID); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to decrement restores in progress",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"restoreId",
|
||||
restoreID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RestoresScheduler) checkDeadNodesAndFailRestores() error {
|
||||
nodes, err := s.restoreNodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get available nodes: %w", err)
|
||||
}
|
||||
|
||||
aliveNodeIDs := make(map[uuid.UUID]bool)
|
||||
for _, node := range nodes {
|
||||
aliveNodeIDs[node.ID] = true
|
||||
}
|
||||
|
||||
for nodeID, relation := range s.restoreToNodeRelations {
|
||||
if aliveNodeIDs[nodeID] {
|
||||
continue
|
||||
}
|
||||
|
||||
s.logger.Warn(
|
||||
"Node is dead, failing its restores",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"restoreCount",
|
||||
len(relation.RestoreIDs),
|
||||
)
|
||||
|
||||
for _, restoreID := range relation.RestoreIDs {
|
||||
restore, err := s.restoreRepository.FindByID(restoreID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to find restore for dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"restoreId",
|
||||
restoreID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
failMessage := "Restore failed due to node unavailability"
|
||||
restore.FailMessage = &failMessage
|
||||
restore.Status = restores_core.RestoreStatusFailed
|
||||
|
||||
if err := s.restoreRepository.Save(restore); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to save failed restore for dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"restoreId",
|
||||
restoreID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := s.restoreNodesRegistry.DecrementRestoresInProgress(nodeID); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to decrement restores in progress for dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"restoreId",
|
||||
restoreID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
s.logger.Info(
|
||||
"Failed restore due to dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"restoreId",
|
||||
restoreID,
|
||||
)
|
||||
}
|
||||
|
||||
delete(s.restoreToNodeRelations, nodeID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
856
backend/internal/features/restores/restoring/scheduler_test.go
Normal file
856
backend/internal/features/restores/restoring/scheduler_test.go
Normal file
@@ -0,0 +1,856 @@
|
||||
package restoring
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_CheckDeadNodesAndFailRestores_NodeDies_FailsRestoreAndCleansUpRegistry(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
var mockNodeID uuid.UUID
|
||||
|
||||
defer func() {
|
||||
backupRepo := backups_core.BackupRepository{}
|
||||
backups, _ := backupRepo.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepo.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
restoreRepo := restores_core.RestoreRepository{}
|
||||
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
|
||||
for _, restore := range restores {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
restores, _ = restoreRepo.FindByStatus(restores_core.RestoreStatusFailed)
|
||||
for _, restore := range restores {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
// Clean up mock node
|
||||
if mockNodeID != uuid.Nil {
|
||||
restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: mockNodeID})
|
||||
}
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
|
||||
// Create a test backup
|
||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
||||
|
||||
var err error
|
||||
// Register mock node without subscribing to restores (simulates node crash after registration)
|
||||
mockNodeID = uuid.New()
|
||||
err = CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create restore and assign to mock node
|
||||
restore := &restores_core.Restore{
|
||||
BackupID: backup.ID,
|
||||
Status: restores_core.RestoreStatusInProgress,
|
||||
}
|
||||
err = restoreRepository.Save(restore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Scheduler assigns restore to mock node
|
||||
err = GetRestoresScheduler().StartRestore(restore.ID, nil)
|
||||
assert.NoError(t, err)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify Valkey counter was incremented when restore was assigned
|
||||
stats, err := restoreNodesRegistry.GetRestoreNodesStats()
|
||||
assert.NoError(t, err)
|
||||
foundStat := false
|
||||
for _, stat := range stats {
|
||||
if stat.ID == mockNodeID {
|
||||
assert.Equal(t, 1, stat.ActiveRestores)
|
||||
foundStat = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, foundStat, "Node stats should be present")
|
||||
|
||||
// Simulate node death by setting heartbeat older than 2-minute threshold
|
||||
oldHeartbeat := time.Now().UTC().Add(-3 * time.Minute)
|
||||
err = UpdateNodeHeartbeatDirectly(mockNodeID, 100, oldHeartbeat)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Trigger dead node detection
|
||||
err = GetRestoresScheduler().checkDeadNodesAndFailRestores()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify restore was failed with appropriate error message
|
||||
failedRestore, err := restoreRepository.FindByID(restore.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, restores_core.RestoreStatusFailed, failedRestore.Status)
|
||||
assert.NotNil(t, failedRestore.FailMessage)
|
||||
assert.Contains(t, *failedRestore.FailMessage, "node unavailability")
|
||||
|
||||
// Verify Valkey counter was decremented after restore failed
|
||||
stats, err = restoreNodesRegistry.GetRestoreNodesStats()
|
||||
assert.NoError(t, err)
|
||||
for _, stat := range stats {
|
||||
if stat.ID == mockNodeID {
|
||||
assert.Equal(t, 0, stat.ActiveRestores)
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_OnRestoreCompleted_TaskIsNotRestore_SkipsProcessing(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
var mockNodeID uuid.UUID
|
||||
|
||||
defer func() {
|
||||
backupRepo := backups_core.BackupRepository{}
|
||||
backups, _ := backupRepo.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepo.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
restoreRepo := restores_core.RestoreRepository{}
|
||||
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
|
||||
for _, restore := range restores {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
// Clean up mock node
|
||||
if mockNodeID != uuid.Nil {
|
||||
restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: mockNodeID})
|
||||
}
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
|
||||
// Create a test backup
|
||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
||||
|
||||
// Register mock node
|
||||
mockNodeID = uuid.New()
|
||||
err := CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create restore and assign to the node
|
||||
restore := &restores_core.Restore{
|
||||
BackupID: backup.ID,
|
||||
Status: restores_core.RestoreStatusInProgress,
|
||||
}
|
||||
err = restoreRepository.Save(restore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = GetRestoresScheduler().StartRestore(restore.ID, nil)
|
||||
assert.NoError(t, err)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Get initial state of the registry
|
||||
initialStats, err := restoreNodesRegistry.GetRestoreNodesStats()
|
||||
assert.NoError(t, err)
|
||||
var initialActiveTasks int
|
||||
for _, stat := range initialStats {
|
||||
if stat.ID == mockNodeID {
|
||||
initialActiveTasks = stat.ActiveRestores
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, initialActiveTasks, "Should have 1 active task")
|
||||
|
||||
// Call onRestoreCompleted with a random UUID (not a restore ID)
|
||||
nonRestoreTaskID := uuid.New()
|
||||
GetRestoresScheduler().onRestoreCompleted(mockNodeID, nonRestoreTaskID)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify: Active tasks counter should remain the same (not decremented)
|
||||
stats, err := restoreNodesRegistry.GetRestoreNodesStats()
|
||||
assert.NoError(t, err)
|
||||
for _, stat := range stats {
|
||||
if stat.ID == mockNodeID {
|
||||
assert.Equal(t, initialActiveTasks, stat.ActiveRestores,
|
||||
"Active tasks should not change for non-restore task")
|
||||
}
|
||||
}
|
||||
|
||||
// Verify: restore should still be in progress (not modified)
|
||||
unchangedRestore, err := restoreRepository.FindByID(restore.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, restores_core.RestoreStatusInProgress, unchangedRestore.Status,
|
||||
"Restore status should not change for non-restore task completion")
|
||||
|
||||
// Verify: restoreToNodeRelations should still contain the node
|
||||
scheduler := GetRestoresScheduler()
|
||||
_, exists := scheduler.restoreToNodeRelations[mockNodeID]
|
||||
assert.True(t, exists, "Node should still be in restoreToNodeRelations")
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
|
||||
t.Run("Nodes with same throughput", func(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
node1ID := uuid.New()
|
||||
node2ID := uuid.New()
|
||||
node3ID := uuid.New()
|
||||
now := time.Now().UTC()
|
||||
|
||||
defer func() {
|
||||
// Clean up all mock nodes
|
||||
restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node1ID})
|
||||
restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node2ID})
|
||||
restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node3ID})
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
err := CreateMockNodeInRegistry(node1ID, 100, now)
|
||||
assert.NoError(t, err)
|
||||
err = CreateMockNodeInRegistry(node2ID, 100, now)
|
||||
assert.NoError(t, err)
|
||||
err = CreateMockNodeInRegistry(node3ID, 100, now)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for range 5 {
|
||||
err = restoreNodesRegistry.IncrementRestoresInProgress(node1ID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
for range 2 {
|
||||
err = restoreNodesRegistry.IncrementRestoresInProgress(node2ID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
for range 8 {
|
||||
err = restoreNodesRegistry.IncrementRestoresInProgress(node3ID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
leastBusyNodeID, err := GetRestoresScheduler().calculateLeastBusyNode()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, leastBusyNodeID)
|
||||
assert.Equal(t, node2ID, *leastBusyNodeID)
|
||||
})
|
||||
|
||||
t.Run("Nodes with different throughput", func(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
node100MBsID := uuid.New()
|
||||
node50MBsID := uuid.New()
|
||||
now := time.Now().UTC()
|
||||
|
||||
defer func() {
|
||||
// Clean up all mock nodes
|
||||
restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node100MBsID})
|
||||
restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node50MBsID})
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
err := CreateMockNodeInRegistry(node100MBsID, 100, now)
|
||||
assert.NoError(t, err)
|
||||
err = CreateMockNodeInRegistry(node50MBsID, 50, now)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for range 10 {
|
||||
err = restoreNodesRegistry.IncrementRestoresInProgress(node100MBsID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
err = restoreNodesRegistry.IncrementRestoresInProgress(node50MBsID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
leastBusyNodeID, err := GetRestoresScheduler().calculateLeastBusyNode()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, leastBusyNodeID)
|
||||
assert.Equal(t, node50MBsID, *leastBusyNodeID)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_FailRestoresInProgress_SchedulerStarts_UpdatesStatus(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backupRepo := backups_core.BackupRepository{}
|
||||
backups, _ := backupRepo.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepo.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
restoreRepo := restores_core.RestoreRepository{}
|
||||
|
||||
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
|
||||
for _, restore := range restores {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
|
||||
restores, _ = restoreRepo.FindByStatus(restores_core.RestoreStatusFailed)
|
||||
for _, restore := range restores {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
|
||||
restores, _ = restoreRepo.FindByStatus(restores_core.RestoreStatusCompleted)
|
||||
for _, restore := range restores {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
|
||||
// Create a test backup
|
||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
||||
|
||||
// Create two in-progress restores that should be failed on scheduler restart
|
||||
restore1 := &restores_core.Restore{
|
||||
BackupID: backup.ID,
|
||||
Status: restores_core.RestoreStatusInProgress,
|
||||
CreatedAt: time.Now().UTC().Add(-30 * time.Minute),
|
||||
}
|
||||
err := restoreRepository.Save(restore1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
restore2 := &restores_core.Restore{
|
||||
BackupID: backup.ID,
|
||||
Status: restores_core.RestoreStatusInProgress,
|
||||
CreatedAt: time.Now().UTC().Add(-15 * time.Minute),
|
||||
}
|
||||
err = restoreRepository.Save(restore2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create a completed restore to verify it's not affected by failRestoresInProgress
|
||||
completedRestore := &restores_core.Restore{
|
||||
BackupID: backup.ID,
|
||||
Status: restores_core.RestoreStatusCompleted,
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
}
|
||||
err = restoreRepository.Save(completedRestore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Trigger the scheduler's failRestoresInProgress logic
|
||||
// This should mark in-progress restores as failed
|
||||
err = GetRestoresScheduler().failRestoresInProgress()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify all restores exist and were processed correctly
|
||||
allRestores1, err := restoreRepository.FindByID(restore1.ID)
|
||||
assert.NoError(t, err)
|
||||
allRestores2, err := restoreRepository.FindByID(restore2.ID)
|
||||
assert.NoError(t, err)
|
||||
allRestores3, err := restoreRepository.FindByID(completedRestore.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var failedCount int
|
||||
var completedCount int
|
||||
|
||||
restoresToCheck := []*restores_core.Restore{allRestores1, allRestores2, allRestores3}
|
||||
for _, restore := range restoresToCheck {
|
||||
switch restore.Status {
|
||||
case restores_core.RestoreStatusFailed:
|
||||
failedCount++
|
||||
// Verify fail message indicates application restart
|
||||
assert.NotNil(t, restore.FailMessage)
|
||||
assert.Equal(t, "Restore failed due to application restart", *restore.FailMessage)
|
||||
case restores_core.RestoreStatusCompleted:
|
||||
completedCount++
|
||||
}
|
||||
}
|
||||
|
||||
// Verify correct number of restores in each state
|
||||
assert.Equal(t, 2, failedCount, "Should have 2 failed restores (originally in progress)")
|
||||
assert.Equal(t, 1, completedCount, "Should have 1 completed restore (unchanged)")
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_StartRestore_RestoreCompletes_DecrementsActiveTaskCount(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
// Start scheduler so it can handle task completions
|
||||
scheduler := CreateTestRestoresScheduler()
|
||||
schedulerCancel := StartSchedulerForTest(t, scheduler)
|
||||
defer schedulerCancel()
|
||||
|
||||
restorerNode := CreateTestRestorerNode()
|
||||
restorerNode.restoreBackupUsecase = &MockSuccessRestoreUsecase{}
|
||||
|
||||
cancel := StartRestorerNodeForTest(t, restorerNode)
|
||||
defer StopRestorerNodeForTest(t, cancel, restorerNode)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backupRepo := backups_core.BackupRepository{}
|
||||
backups, _ := backupRepo.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepo.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
restoreRepo := restores_core.RestoreRepository{}
|
||||
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusCompleted)
|
||||
for _, restore := range restores {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
|
||||
// Create a test backup
|
||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
||||
|
||||
// Get initial active task count
|
||||
stats, err := restoreNodesRegistry.GetRestoreNodesStats()
|
||||
assert.NoError(t, err)
|
||||
var initialActiveTasks int
|
||||
for _, stat := range stats {
|
||||
if stat.ID == restorerNode.nodeID {
|
||||
initialActiveTasks = stat.ActiveRestores
|
||||
break
|
||||
}
|
||||
}
|
||||
t.Logf("Initial active tasks: %d", initialActiveTasks)
|
||||
|
||||
// Create and start restore
|
||||
restore := &restores_core.Restore{
|
||||
BackupID: backup.ID,
|
||||
Status: restores_core.RestoreStatusInProgress,
|
||||
}
|
||||
err = restoreRepository.Save(restore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = scheduler.StartRestore(restore.ID, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Wait for restore to complete
|
||||
WaitForRestoreCompletion(t, restore.ID, 10*time.Second)
|
||||
|
||||
// Verify restore was completed
|
||||
completedRestore, err := restoreRepository.FindByID(restore.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, completedRestore.Status)
|
||||
|
||||
// Wait for active task count to decrease
|
||||
decreased := WaitForActiveTasksDecrease(
|
||||
t,
|
||||
restorerNode.nodeID,
|
||||
initialActiveTasks+1,
|
||||
10*time.Second,
|
||||
)
|
||||
assert.True(t, decreased, "Active task count should have decreased after restore completion")
|
||||
|
||||
// Verify final active task count equals initial count
|
||||
finalStats, err := restoreNodesRegistry.GetRestoreNodesStats()
|
||||
assert.NoError(t, err)
|
||||
for _, stat := range finalStats {
|
||||
if stat.ID == restorerNode.nodeID {
|
||||
t.Logf("Final active tasks: %d", stat.ActiveRestores)
|
||||
assert.Equal(t, initialActiveTasks, stat.ActiveRestores,
|
||||
"Active task count should return to initial value after restore completion")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_StartRestore_RestoreFails_DecrementsActiveTaskCount(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
// Start scheduler so it can handle task completions
|
||||
scheduler := CreateTestRestoresScheduler()
|
||||
schedulerCancel := StartSchedulerForTest(t, scheduler)
|
||||
defer schedulerCancel()
|
||||
|
||||
restorerNode := CreateTestRestorerNode()
|
||||
restorerNode.restoreBackupUsecase = &MockFailedRestoreUsecase{}
|
||||
|
||||
cancel := StartRestorerNodeForTest(t, restorerNode)
|
||||
defer StopRestorerNodeForTest(t, cancel, restorerNode)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backupRepo := backups_core.BackupRepository{}
|
||||
backups, _ := backupRepo.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepo.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
restoreRepo := restores_core.RestoreRepository{}
|
||||
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusFailed)
|
||||
for _, restore := range restores {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
|
||||
// Create a test backup
|
||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
||||
|
||||
// Get initial active task count
|
||||
stats, err := restoreNodesRegistry.GetRestoreNodesStats()
|
||||
assert.NoError(t, err)
|
||||
var initialActiveTasks int
|
||||
for _, stat := range stats {
|
||||
if stat.ID == restorerNode.nodeID {
|
||||
initialActiveTasks = stat.ActiveRestores
|
||||
break
|
||||
}
|
||||
}
|
||||
t.Logf("Initial active tasks: %d", initialActiveTasks)
|
||||
|
||||
// Create and start restore
|
||||
restore := &restores_core.Restore{
|
||||
BackupID: backup.ID,
|
||||
Status: restores_core.RestoreStatusInProgress,
|
||||
}
|
||||
err = restoreRepository.Save(restore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = scheduler.StartRestore(restore.ID, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Wait for restore to fail
|
||||
WaitForRestoreCompletion(t, restore.ID, 10*time.Second)
|
||||
|
||||
// Verify restore failed
|
||||
failedRestore, err := restoreRepository.FindByID(restore.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, restores_core.RestoreStatusFailed, failedRestore.Status)
|
||||
|
||||
// Wait for active task count to decrease
|
||||
decreased := WaitForActiveTasksDecrease(
|
||||
t,
|
||||
restorerNode.nodeID,
|
||||
initialActiveTasks+1,
|
||||
10*time.Second,
|
||||
)
|
||||
assert.True(t, decreased, "Active task count should have decreased after restore failure")
|
||||
|
||||
// Verify final active task count equals initial count
|
||||
finalStats, err := restoreNodesRegistry.GetRestoreNodesStats()
|
||||
assert.NoError(t, err)
|
||||
for _, stat := range finalStats {
|
||||
if stat.ID == restorerNode.nodeID {
|
||||
t.Logf("Final active tasks: %d", stat.ActiveRestores)
|
||||
assert.Equal(t, initialActiveTasks, stat.ActiveRestores,
|
||||
"Active task count should return to initial value after restore failure")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_StartRestore_CredentialsStoredEncryptedInCache(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
var mockNodeID uuid.UUID
|
||||
|
||||
defer func() {
|
||||
backupRepo := backups_core.BackupRepository{}
|
||||
backups, _ := backupRepo.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepo.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
restoreRepo := restores_core.RestoreRepository{}
|
||||
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
|
||||
for _, restore := range restores {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
// Clean up mock node
|
||||
if mockNodeID != uuid.Nil {
|
||||
restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: mockNodeID})
|
||||
}
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
|
||||
// Create a test backup
|
||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
||||
|
||||
// Register mock node so scheduler can assign restore to it
|
||||
mockNodeID = uuid.New()
|
||||
err := CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create restore with plaintext credentials
|
||||
plaintextPassword := "test_password_123"
|
||||
restore := &restores_core.Restore{
|
||||
BackupID: backup.ID,
|
||||
Status: restores_core.RestoreStatusInProgress,
|
||||
}
|
||||
err = restoreRepository.Save(restore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create PostgreSQL database credentials with plaintext password
|
||||
postgresDB := &postgresql.PostgresqlDatabase{
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "testuser",
|
||||
Password: plaintextPassword,
|
||||
Database: stringPtr("testdb"),
|
||||
Version: "16",
|
||||
}
|
||||
|
||||
// Encrypt password using FieldEncryptor (same as production flow)
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
err = postgresDB.EncryptSensitiveFields(database.ID, encryptor)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify password was encrypted (different from plaintext)
|
||||
assert.NotEqual(t, plaintextPassword, postgresDB.Password,
|
||||
"Password should be encrypted, not plaintext")
|
||||
|
||||
// Create cache with encrypted credentials
|
||||
dbCache := &RestoreDatabaseCache{
|
||||
PostgresqlDatabase: postgresDB,
|
||||
}
|
||||
|
||||
// Call StartRestore to cache credentials (do NOT start restore node)
|
||||
err = GetRestoresScheduler().StartRestore(restore.ID, dbCache)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Directly read from cache
|
||||
cachedData := restoreDatabaseCache.Get(restore.ID.String())
|
||||
assert.NotNil(t, cachedData, "Cache entry should exist")
|
||||
assert.NotNil(t, cachedData.PostgresqlDatabase, "PostgreSQL credentials should be cached")
|
||||
|
||||
// Verify password in cache is encrypted (not plaintext)
|
||||
assert.NotEqual(t, plaintextPassword, cachedData.PostgresqlDatabase.Password,
|
||||
"Cached password should be encrypted, not plaintext")
|
||||
assert.Equal(t, postgresDB.Password, cachedData.PostgresqlDatabase.Password,
|
||||
"Cached password should match the encrypted version")
|
||||
|
||||
// Verify other fields are present
|
||||
assert.Equal(t, config.GetEnv().TestLocalhost, cachedData.PostgresqlDatabase.Host)
|
||||
assert.Equal(t, 5432, cachedData.PostgresqlDatabase.Port)
|
||||
assert.Equal(t, "testuser", cachedData.PostgresqlDatabase.Username)
|
||||
assert.Equal(t, "testdb", *cachedData.PostgresqlDatabase.Database)
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_StartRestore_CredentialsRemovedAfterRestoreStarts(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
// Start scheduler so it can handle task assignments
|
||||
scheduler := CreateTestRestoresScheduler()
|
||||
schedulerCancel := StartSchedulerForTest(t, scheduler)
|
||||
defer schedulerCancel()
|
||||
|
||||
// Create mock restorer node with credential capture usecase
|
||||
restorerNode := CreateTestRestorerNode()
|
||||
calledChan := make(chan *databases.Database, 1)
|
||||
restorerNode.restoreBackupUsecase = &MockCaptureCredentialsRestoreUsecase{
|
||||
CalledChan: calledChan,
|
||||
ShouldSucceed: true,
|
||||
}
|
||||
|
||||
cancel := StartRestorerNodeForTest(t, restorerNode)
|
||||
defer StopRestorerNodeForTest(t, cancel, restorerNode)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backupRepo := backups_core.BackupRepository{}
|
||||
backups, _ := backupRepo.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepo.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
restoreRepo := restores_core.RestoreRepository{}
|
||||
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusCompleted)
|
||||
for _, restore := range restores {
|
||||
restoreRepo.DeleteByID(restore.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
|
||||
cache_utils.ClearAllCache()
|
||||
}()
|
||||
|
||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
|
||||
// Create a test backup
|
||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
||||
|
||||
// Create restore with credentials
|
||||
plaintextPassword := "test_password_456"
|
||||
restore := &restores_core.Restore{
|
||||
BackupID: backup.ID,
|
||||
Status: restores_core.RestoreStatusInProgress,
|
||||
}
|
||||
err := restoreRepository.Save(restore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create PostgreSQL database credentials
|
||||
// Database field is nil to avoid PopulateDbData trying to connect
|
||||
postgresDB := &postgresql.PostgresqlDatabase{
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "testuser",
|
||||
Password: plaintextPassword,
|
||||
Database: nil,
|
||||
Version: "16",
|
||||
}
|
||||
|
||||
// Encrypt password (same as production flow)
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
err = postgresDB.EncryptSensitiveFields(database.ID, encryptor)
|
||||
assert.NoError(t, err)
|
||||
|
||||
encryptedPassword := postgresDB.Password
|
||||
|
||||
// Create cache with encrypted credentials
|
||||
dbCache := &RestoreDatabaseCache{
|
||||
PostgresqlDatabase: postgresDB,
|
||||
}
|
||||
|
||||
// Call StartRestore to cache credentials and trigger restore
|
||||
err = scheduler.StartRestore(restore.ID, dbCache)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Wait for mock usecase to be called (with timeout)
|
||||
var capturedDB *databases.Database
|
||||
select {
|
||||
case capturedDB = <-calledChan:
|
||||
t.Log("Mock usecase was called, credentials captured")
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Fatal("Timeout waiting for mock usecase to be called")
|
||||
}
|
||||
|
||||
// Verify cache is empty after restore starts (credentials were deleted)
|
||||
cacheAfterExecution := restoreDatabaseCache.Get(restore.ID.String())
|
||||
assert.Nil(t, cacheAfterExecution, "Cache should be empty after restore execution starts")
|
||||
|
||||
// Verify mock received valid credentials
|
||||
assert.NotNil(t, capturedDB, "Captured database should not be nil")
|
||||
assert.NotNil(t, capturedDB.Postgresql, "PostgreSQL credentials should be provided to usecase")
|
||||
assert.Equal(t, config.GetEnv().TestLocalhost, capturedDB.Postgresql.Host)
|
||||
assert.Equal(t, 5432, capturedDB.Postgresql.Port)
|
||||
assert.Equal(t, "testuser", capturedDB.Postgresql.Username)
|
||||
assert.NotEmpty(t, capturedDB.Postgresql.Password, "Password should be provided to usecase")
|
||||
|
||||
// Note: Password at this point may still be encrypted because PopulateDbData
|
||||
// is called after the mock captures it. The important thing is that credentials
|
||||
// were provided to the usecase despite cache being deleted.
|
||||
t.Logf("Encrypted password in cache: %s", encryptedPassword)
|
||||
t.Logf("Password received by usecase: %s", capturedDB.Postgresql.Password)
|
||||
|
||||
// Wait for restore to complete
|
||||
WaitForRestoreCompletion(t, restore.ID, 10*time.Second)
|
||||
|
||||
// Verify restore was completed
|
||||
completedRestore, err := restoreRepository.FindByID(restore.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, completedRestore.Status)
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
342
backend/internal/features/restores/restoring/testing.go
Normal file
342
backend/internal/features/restores/restoring/testing.go
Normal file
@@ -0,0 +1,342 @@
|
||||
package restoring
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/restores/usecases"
|
||||
"databasus-backend/internal/features/storages"
|
||||
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
func CreateTestRouter() *gin.Engine {
|
||||
router := workspaces_testing.CreateTestRouter(
|
||||
workspaces_controllers.GetWorkspaceController(),
|
||||
workspaces_controllers.GetMembershipController(),
|
||||
databases.GetDatabaseController(),
|
||||
backups_config.GetBackupConfigController(),
|
||||
)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func CreateTestRestorerNode() *RestorerNode {
|
||||
return &RestorerNode{
|
||||
nodeID: uuid.New(),
|
||||
databaseService: databases.GetDatabaseService(),
|
||||
backupService: backups.GetBackupService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
restoreRepository: restoreRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
restoreNodesRegistry: restoreNodesRegistry,
|
||||
logger: logger.GetLogger(),
|
||||
restoreBackupUsecase: usecases.GetRestoreBackupUsecase(),
|
||||
cacheUtil: restoreDatabaseCache,
|
||||
restoreCancelManager: tasks_cancellation.GetTaskCancelManager(),
|
||||
lastHeartbeat: time.Time{},
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func CreateTestRestorerNodeWithUsecase(usecase restores_core.RestoreBackupUsecase) *RestorerNode {
|
||||
return &RestorerNode{
|
||||
nodeID: uuid.New(),
|
||||
databaseService: databases.GetDatabaseService(),
|
||||
backupService: backups.GetBackupService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
restoreRepository: restoreRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
restoreNodesRegistry: restoreNodesRegistry,
|
||||
logger: logger.GetLogger(),
|
||||
restoreBackupUsecase: usecase,
|
||||
cacheUtil: restoreDatabaseCache,
|
||||
restoreCancelManager: tasks_cancellation.GetTaskCancelManager(),
|
||||
lastHeartbeat: time.Time{},
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func CreateTestRestoresScheduler() *RestoresScheduler {
|
||||
return &RestoresScheduler{
|
||||
restoreRepository,
|
||||
backups.GetBackupService(),
|
||||
storages.GetStorageService(),
|
||||
backups_config.GetBackupConfigService(),
|
||||
restoreNodesRegistry,
|
||||
time.Now().UTC(),
|
||||
logger.GetLogger(),
|
||||
make(map[uuid.UUID]RestoreToNodeRelation),
|
||||
restorerNode,
|
||||
restoreDatabaseCache,
|
||||
uuid.Nil,
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
// WaitForRestoreCompletion waits for a restore to be completed (or failed)
|
||||
func WaitForRestoreCompletion(
|
||||
t *testing.T,
|
||||
restoreID uuid.UUID,
|
||||
timeout time.Duration,
|
||||
) {
|
||||
deadline := time.Now().UTC().Add(timeout)
|
||||
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
restore, err := restoreRepository.FindByID(restoreID)
|
||||
if err != nil {
|
||||
t.Logf("WaitForRestoreCompletion: error finding restore: %v", err)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
t.Logf("WaitForRestoreCompletion: restore status: %s", restore.Status)
|
||||
|
||||
if restore.Status == restores_core.RestoreStatusCompleted ||
|
||||
restore.Status == restores_core.RestoreStatusFailed {
|
||||
t.Logf(
|
||||
"WaitForRestoreCompletion: restore finished with status %s",
|
||||
restore.Status,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Logf("WaitForRestoreCompletion: timeout waiting for restore to complete")
|
||||
}
|
||||
|
||||
// StartRestorerNodeForTest starts a RestorerNode in a goroutine for testing.
|
||||
// The node registers itself in the registry and subscribes to restore assignments.
|
||||
// Returns a context cancel function that should be deferred to stop the node.
|
||||
func StartRestorerNodeForTest(t *testing.T, restorerNode *RestorerNode) context.CancelFunc {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
restorerNode.Run(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Poll registry for node presence instead of fixed sleep
|
||||
deadline := time.Now().UTC().Add(5 * time.Second)
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
nodes, err := restoreNodesRegistry.GetAvailableNodes()
|
||||
if err == nil {
|
||||
for _, node := range nodes {
|
||||
if node.ID == restorerNode.nodeID {
|
||||
t.Logf("RestorerNode registered in registry: %s", restorerNode.nodeID)
|
||||
|
||||
return func() {
|
||||
cancel()
|
||||
select {
|
||||
case <-done:
|
||||
t.Log("RestorerNode stopped gracefully")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Log("RestorerNode stop timeout")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Fatalf("RestorerNode failed to register in registry within timeout")
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartSchedulerForTest starts the RestoresScheduler in a goroutine for testing.
|
||||
// The scheduler subscribes to task completions and manages restore lifecycle.
|
||||
// Returns a context cancel function that should be deferred to stop the scheduler.
|
||||
func StartSchedulerForTest(t *testing.T, scheduler *RestoresScheduler) context.CancelFunc {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
scheduler.Run(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Give scheduler time to subscribe to completions
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
t.Log("RestoresScheduler started")
|
||||
|
||||
return func() {
|
||||
cancel()
|
||||
select {
|
||||
case <-done:
|
||||
t.Log("RestoresScheduler stopped gracefully")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Log("RestoresScheduler stop timeout")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StopRestorerNodeForTest stops the RestorerNode by canceling its context.
|
||||
// It waits for the node to unregister from the registry.
|
||||
func StopRestorerNodeForTest(t *testing.T, cancel context.CancelFunc, restorerNode *RestorerNode) {
|
||||
cancel()
|
||||
|
||||
// Wait for node to unregister from registry
|
||||
deadline := time.Now().UTC().Add(2 * time.Second)
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
nodes, err := restoreNodesRegistry.GetAvailableNodes()
|
||||
if err == nil {
|
||||
found := false
|
||||
for _, node := range nodes {
|
||||
if node.ID == restorerNode.nodeID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Logf("RestorerNode unregistered from registry: %s", restorerNode.nodeID)
|
||||
return
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Logf("RestorerNode stop completed for %s", restorerNode.nodeID)
|
||||
}
|
||||
|
||||
func CreateMockNodeInRegistry(nodeID uuid.UUID, throughputMBs int, lastHeartbeat time.Time) error {
|
||||
restoreNode := RestoreNode{
|
||||
ID: nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: lastHeartbeat,
|
||||
}
|
||||
|
||||
return restoreNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, restoreNode)
|
||||
}
|
||||
|
||||
func UpdateNodeHeartbeatDirectly(
|
||||
nodeID uuid.UUID,
|
||||
throughputMBs int,
|
||||
lastHeartbeat time.Time,
|
||||
) error {
|
||||
restoreNode := RestoreNode{
|
||||
ID: nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: lastHeartbeat,
|
||||
}
|
||||
|
||||
return restoreNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, restoreNode)
|
||||
}
|
||||
|
||||
func GetNodeFromRegistry(nodeID uuid.UUID) (*RestoreNode, error) {
|
||||
nodes, err := restoreNodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
if node.ID == nodeID {
|
||||
return &node, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("node not found")
|
||||
}
|
||||
|
||||
// WaitForActiveTasksDecrease waits for the active task count to decrease below the initial count.
|
||||
// It polls the registry every 500ms until the count decreases or the timeout is reached.
|
||||
// Returns true if the count decreased, false if timeout was reached.
|
||||
func WaitForActiveTasksDecrease(
|
||||
t *testing.T,
|
||||
nodeID uuid.UUID,
|
||||
initialCount int,
|
||||
timeout time.Duration,
|
||||
) bool {
|
||||
deadline := time.Now().UTC().Add(timeout)
|
||||
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
stats, err := restoreNodesRegistry.GetRestoreNodesStats()
|
||||
if err != nil {
|
||||
t.Logf("WaitForActiveTasksDecrease: error getting node stats: %v", err)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, stat := range stats {
|
||||
if stat.ID == nodeID {
|
||||
t.Logf(
|
||||
"WaitForActiveTasksDecrease: current active tasks = %d (initial = %d)",
|
||||
stat.ActiveRestores,
|
||||
initialCount,
|
||||
)
|
||||
if stat.ActiveRestores < initialCount {
|
||||
t.Logf(
|
||||
"WaitForActiveTasksDecrease: active tasks decreased from %d to %d",
|
||||
initialCount,
|
||||
stat.ActiveRestores,
|
||||
)
|
||||
return true
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Logf("WaitForActiveTasksDecrease: timeout waiting for active tasks to decrease")
|
||||
return false
|
||||
}
|
||||
|
||||
// CreateTestRestore creates a test restore with the given backup and status
|
||||
func CreateTestRestore(
|
||||
t *testing.T,
|
||||
backup *backups_core.Backup,
|
||||
status restores_core.RestoreStatus,
|
||||
) *restores_core.Restore {
|
||||
restore := &restores_core.Restore{
|
||||
BackupID: backup.ID,
|
||||
Status: status,
|
||||
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: 5432,
|
||||
Username: "test",
|
||||
Password: "test",
|
||||
Database: stringPtr("testdb"),
|
||||
Version: "16",
|
||||
},
|
||||
}
|
||||
|
||||
err := restoreRepository.Save(restore)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test restore: %v", err)
|
||||
}
|
||||
|
||||
return restore
|
||||
}
|
||||
|
||||
// stringPtr returns a pointer to a fresh string holding the value of s.
func stringPtr(s string) *string {
	p := new(string)
	*p = s

	return p
}
|
||||
@@ -7,10 +7,11 @@ import (
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/disk"
|
||||
"databasus-backend/internal/features/restores/enums"
|
||||
"databasus-backend/internal/features/restores/models"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/restores/restoring"
|
||||
"databasus-backend/internal/features/restores/usecases"
|
||||
"databasus-backend/internal/features/storages"
|
||||
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
users_models "databasus-backend/internal/features/users/models"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
@@ -25,7 +26,7 @@ import (
|
||||
|
||||
type RestoreService struct {
|
||||
backupService *backups.BackupService
|
||||
restoreRepository *RestoreRepository
|
||||
restoreRepository *restores_core.RestoreRepository
|
||||
storageService *storages.StorageService
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
restoreBackupUsecase *usecases.RestoreBackupUsecase
|
||||
@@ -35,6 +36,7 @@ type RestoreService struct {
|
||||
auditLogService *audit_logs.AuditLogService
|
||||
fieldEncryptor encryption.FieldEncryptor
|
||||
diskService *disk.DiskService
|
||||
taskCancelManager *tasks_cancellation.TaskCancelManager
|
||||
}
|
||||
|
||||
func (s *RestoreService) OnBeforeBackupRemove(backup *backups_core.Backup) error {
|
||||
@@ -44,7 +46,7 @@ func (s *RestoreService) OnBeforeBackupRemove(backup *backups_core.Backup) error
|
||||
}
|
||||
|
||||
for _, restore := range restores {
|
||||
if restore.Status == enums.RestoreStatusInProgress {
|
||||
if restore.Status == restores_core.RestoreStatusInProgress {
|
||||
return errors.New("restore is in progress, backup cannot be removed")
|
||||
}
|
||||
}
|
||||
@@ -61,7 +63,7 @@ func (s *RestoreService) OnBeforeBackupRemove(backup *backups_core.Backup) error
|
||||
func (s *RestoreService) GetRestores(
|
||||
user *users_models.User,
|
||||
backupID uuid.UUID,
|
||||
) ([]*models.Restore, error) {
|
||||
) ([]*restores_core.Restore, error) {
|
||||
backup, err := s.backupService.GetBackup(backupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -93,7 +95,7 @@ func (s *RestoreService) GetRestores(
|
||||
func (s *RestoreService) RestoreBackupWithAuth(
|
||||
user *users_models.User,
|
||||
backupID uuid.UUID,
|
||||
requestDTO RestoreBackupRequest,
|
||||
requestDTO restores_core.RestoreBackupRequest,
|
||||
) error {
|
||||
backup, err := s.backupService.GetBackup(backupID)
|
||||
if err != nil {
|
||||
@@ -134,11 +136,50 @@ func (s *RestoreService) RestoreBackupWithAuth(
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := s.RestoreBackup(backup, requestDTO); err != nil {
|
||||
s.logger.Error("Failed to restore backup", "error", err)
|
||||
// Validate no parallel restores for the same database
|
||||
if err := s.validateNoParallelRestores(backup.DatabaseID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create restore record with the request configuration
|
||||
restore := restores_core.Restore{
|
||||
ID: uuid.New(),
|
||||
Status: restores_core.RestoreStatusInProgress,
|
||||
BackupID: backup.ID,
|
||||
Backup: backup,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
RestoreDurationMs: 0,
|
||||
FailMessage: nil,
|
||||
PostgresqlDatabase: requestDTO.PostgresqlDatabase,
|
||||
MysqlDatabase: requestDTO.MysqlDatabase,
|
||||
MariadbDatabase: requestDTO.MariadbDatabase,
|
||||
MongodbDatabase: requestDTO.MongodbDatabase,
|
||||
}
|
||||
|
||||
if err := s.restoreRepository.Save(&restore); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Prepare database cache with credentials from the request
|
||||
dbCache := &restoring.RestoreDatabaseCache{
|
||||
PostgresqlDatabase: requestDTO.PostgresqlDatabase,
|
||||
MysqlDatabase: requestDTO.MysqlDatabase,
|
||||
MariadbDatabase: requestDTO.MariadbDatabase,
|
||||
MongodbDatabase: requestDTO.MongodbDatabase,
|
||||
}
|
||||
|
||||
// Trigger restore via scheduler
|
||||
scheduler := restoring.GetRestoresScheduler()
|
||||
if err := scheduler.StartRestore(restore.ID, dbCache); err != nil {
|
||||
// Mark restore as failed if we can't schedule it
|
||||
failMsg := fmt.Sprintf("Failed to schedule restore: %v", err)
|
||||
restore.FailMessage = &failMsg
|
||||
restore.Status = restores_core.RestoreStatusFailed
|
||||
if saveErr := s.restoreRepository.Save(&restore); saveErr != nil {
|
||||
s.logger.Error("Failed to save restore after scheduling error", "error", saveErr)
|
||||
}
|
||||
}()
|
||||
return err
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf(
|
||||
@@ -153,127 +194,9 @@ func (s *RestoreService) RestoreBackupWithAuth(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RestoreService) RestoreBackup(
|
||||
backup *backups_core.Backup,
|
||||
requestDTO RestoreBackupRequest,
|
||||
) error {
|
||||
if backup.Status != backups_core.BackupStatusCompleted {
|
||||
return errors.New("backup is not completed")
|
||||
}
|
||||
|
||||
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch database.Type {
|
||||
case databases.DatabaseTypePostgres:
|
||||
if requestDTO.PostgresqlDatabase == nil {
|
||||
return errors.New("postgresql database is required")
|
||||
}
|
||||
case databases.DatabaseTypeMysql:
|
||||
if requestDTO.MysqlDatabase == nil {
|
||||
return errors.New("mysql database is required")
|
||||
}
|
||||
case databases.DatabaseTypeMariadb:
|
||||
if requestDTO.MariadbDatabase == nil {
|
||||
return errors.New("mariadb database is required")
|
||||
}
|
||||
case databases.DatabaseTypeMongodb:
|
||||
if requestDTO.MongodbDatabase == nil {
|
||||
return errors.New("mongodb database is required")
|
||||
}
|
||||
}
|
||||
|
||||
restore := models.Restore{
|
||||
ID: uuid.New(),
|
||||
Status: enums.RestoreStatusInProgress,
|
||||
|
||||
BackupID: backup.ID,
|
||||
Backup: backup,
|
||||
|
||||
CreatedAt: time.Now().UTC(),
|
||||
RestoreDurationMs: 0,
|
||||
|
||||
FailMessage: nil,
|
||||
}
|
||||
|
||||
// Save the restore first
|
||||
if err := s.restoreRepository.Save(&restore); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Save the restore again to include the postgresql database
|
||||
if err := s.restoreRepository.Save(&restore); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
storage, err := s.storageService.GetStorageByID(backup.StorageID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(
|
||||
database.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
start := time.Now().UTC()
|
||||
|
||||
restoringToDB := &databases.Database{
|
||||
Type: database.Type,
|
||||
Postgresql: requestDTO.PostgresqlDatabase,
|
||||
Mysql: requestDTO.MysqlDatabase,
|
||||
Mariadb: requestDTO.MariadbDatabase,
|
||||
Mongodb: requestDTO.MongodbDatabase,
|
||||
}
|
||||
|
||||
if err := restoringToDB.PopulateDbData(s.logger, s.fieldEncryptor); err != nil {
|
||||
return fmt.Errorf("failed to auto-detect database data: %w", err)
|
||||
}
|
||||
|
||||
isExcludeExtensions := false
|
||||
if requestDTO.PostgresqlDatabase != nil {
|
||||
isExcludeExtensions = requestDTO.PostgresqlDatabase.IsExcludeExtensions
|
||||
}
|
||||
|
||||
err = s.restoreBackupUsecase.Execute(
|
||||
backupConfig,
|
||||
restore,
|
||||
database,
|
||||
restoringToDB,
|
||||
backup,
|
||||
storage,
|
||||
isExcludeExtensions,
|
||||
)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
restore.FailMessage = &errMsg
|
||||
restore.Status = enums.RestoreStatusFailed
|
||||
restore.RestoreDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
if err := s.restoreRepository.Save(&restore); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
restore.Status = enums.RestoreStatusCompleted
|
||||
restore.RestoreDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
if err := s.restoreRepository.Save(&restore); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RestoreService) validateVersionCompatibility(
|
||||
backupDatabase *databases.Database,
|
||||
requestDTO RestoreBackupRequest,
|
||||
requestDTO restores_core.RestoreBackupRequest,
|
||||
) error {
|
||||
// populate version
|
||||
if requestDTO.MariadbDatabase != nil {
|
||||
@@ -372,7 +295,7 @@ func (s *RestoreService) validateVersionCompatibility(
|
||||
|
||||
func (s *RestoreService) validateDiskSpace(
|
||||
backup *backups_core.Backup,
|
||||
requestDTO RestoreBackupRequest,
|
||||
requestDTO restores_core.RestoreBackupRequest,
|
||||
) error {
|
||||
// Only validate disk space for PostgreSQL when file-based restore is needed:
|
||||
// - CPU > 1 (parallel jobs require file)
|
||||
@@ -424,3 +347,71 @@ func (s *RestoreService) validateDiskSpace(
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RestoreService) validateNoParallelRestores(databaseID uuid.UUID) error {
|
||||
inProgressRestores, err := s.restoreRepository.FindInProgressRestoresByDatabaseID(databaseID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check for in-progress restores: %w", err)
|
||||
}
|
||||
|
||||
isInProgress := len(inProgressRestores) > 0
|
||||
if isInProgress {
|
||||
return errors.New(
|
||||
"another restore is already in progress for this database. Please wait for it to complete or cancel it before starting a new restore",
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RestoreService) CancelRestore(
|
||||
user *users_models.User,
|
||||
restoreID uuid.UUID,
|
||||
) error {
|
||||
restore, err := s.restoreRepository.FindByID(restoreID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
backup, err := s.backupService.GetBackup(restore.BackupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if database.WorkspaceID == nil {
|
||||
return errors.New("cannot cancel restore for database without workspace")
|
||||
}
|
||||
|
||||
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !canManage {
|
||||
return errors.New("insufficient permissions to cancel restore for this database")
|
||||
}
|
||||
|
||||
if restore.Status != restores_core.RestoreStatusInProgress {
|
||||
return errors.New("restore is not in progress")
|
||||
}
|
||||
|
||||
if err := s.taskCancelManager.CancelTask(restoreID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf(
|
||||
"Restore cancelled for database: %s (ID: %s)",
|
||||
database.Name,
|
||||
restoreID.String(),
|
||||
),
|
||||
&user.ID,
|
||||
database.WorkspaceID,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
51
backend/internal/features/restores/testing.go
Normal file
51
backend/internal/features/restores/testing.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package restores
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/restores/restoring"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
)
|
||||
|
||||
func CreateTestRouter() *gin.Engine {
|
||||
router := workspaces_testing.CreateTestRouter(
|
||||
workspaces_controllers.GetWorkspaceController(),
|
||||
workspaces_controllers.GetMembershipController(),
|
||||
databases.GetDatabaseController(),
|
||||
backups_config.GetBackupConfigController(),
|
||||
backups.GetBackupController(),
|
||||
GetRestoreController(),
|
||||
)
|
||||
|
||||
v1 := router.Group("/api/v1")
|
||||
backups.GetBackupController().RegisterPublicRoutes(v1)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func SetupMockRestoreNode(t *testing.T) (uuid.UUID, context.CancelFunc) {
|
||||
nodeID := uuid.New()
|
||||
err := restoring.CreateMockNodeInRegistry(
|
||||
nodeID,
|
||||
100,
|
||||
time.Now().UTC(),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create mock node: %v", err)
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
// Node will expire naturally from registry
|
||||
}
|
||||
|
||||
return nodeID, cleanup
|
||||
}
|
||||
@@ -24,7 +24,7 @@ import (
|
||||
"databasus-backend/internal/features/databases"
|
||||
mariadbtypes "databasus-backend/internal/features/databases/databases/mariadb"
|
||||
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
|
||||
"databasus-backend/internal/features/restores/models"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/storages"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/tools"
|
||||
@@ -36,10 +36,11 @@ type RestoreMariadbBackupUsecase struct {
|
||||
}
|
||||
|
||||
func (uc *RestoreMariadbBackupUsecase) Execute(
|
||||
parentCtx context.Context,
|
||||
originalDB *databases.Database,
|
||||
restoringToDB *databases.Database,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore models.Restore,
|
||||
restore restores_core.Restore,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
) error {
|
||||
@@ -79,6 +80,7 @@ func (uc *RestoreMariadbBackupUsecase) Execute(
|
||||
}
|
||||
|
||||
return uc.restoreFromStorage(
|
||||
parentCtx,
|
||||
originalDB,
|
||||
tools.GetMariadbExecutable(
|
||||
tools.MariadbExecutableMariadb,
|
||||
@@ -95,6 +97,7 @@ func (uc *RestoreMariadbBackupUsecase) Execute(
|
||||
}
|
||||
|
||||
func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
|
||||
parentCtx context.Context,
|
||||
database *databases.Database,
|
||||
mariadbBin string,
|
||||
args []string,
|
||||
@@ -103,7 +106,7 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
|
||||
storage *storages.Storage,
|
||||
mdbConfig *mariadbtypes.MariadbDatabase,
|
||||
) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
@@ -113,6 +116,9 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-parentCtx.Done():
|
||||
cancel()
|
||||
return
|
||||
case <-ticker.C:
|
||||
if config.IsShouldShutdown() {
|
||||
cancel()
|
||||
@@ -213,6 +219,15 @@ func (uc *RestoreMariadbBackupUsecase) executeMariadbRestore(
|
||||
waitErr := cmd.Wait()
|
||||
stderrOutput := <-stderrCh
|
||||
|
||||
// Check for cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if errors.Is(ctx.Err(), context.Canceled) {
|
||||
return fmt.Errorf("restore cancelled")
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("restore cancelled due to shutdown")
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
"databasus-backend/internal/features/databases"
|
||||
mongodbtypes "databasus-backend/internal/features/databases/databases/mongodb"
|
||||
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
|
||||
"databasus-backend/internal/features/restores/models"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/storages"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/tools"
|
||||
@@ -36,10 +36,11 @@ type RestoreMongodbBackupUsecase struct {
|
||||
}
|
||||
|
||||
func (uc *RestoreMongodbBackupUsecase) Execute(
|
||||
parentCtx context.Context,
|
||||
originalDB *databases.Database,
|
||||
restoringToDB *databases.Database,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore models.Restore,
|
||||
restore restores_core.Restore,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
) error {
|
||||
@@ -76,6 +77,7 @@ func (uc *RestoreMongodbBackupUsecase) Execute(
|
||||
args := uc.buildMongorestoreArgs(mdb, decryptedPassword, sourceDatabase)
|
||||
|
||||
return uc.restoreFromStorage(
|
||||
parentCtx,
|
||||
tools.GetMongodbExecutable(
|
||||
tools.MongodbExecutableMongorestore,
|
||||
config.GetEnv().EnvMode,
|
||||
@@ -122,12 +124,13 @@ func (uc *RestoreMongodbBackupUsecase) buildMongorestoreArgs(
|
||||
}
|
||||
|
||||
func (uc *RestoreMongodbBackupUsecase) restoreFromStorage(
|
||||
parentCtx context.Context,
|
||||
mongorestoreBin string,
|
||||
args []string,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), restoreTimeout)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, restoreTimeout)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
@@ -137,6 +140,9 @@ func (uc *RestoreMongodbBackupUsecase) restoreFromStorage(
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-parentCtx.Done():
|
||||
cancel()
|
||||
return
|
||||
case <-ticker.C:
|
||||
if config.IsShouldShutdown() {
|
||||
cancel()
|
||||
@@ -218,6 +224,15 @@ func (uc *RestoreMongodbBackupUsecase) executeMongoRestore(
|
||||
waitErr := cmd.Wait()
|
||||
stderrOutput := <-stderrCh
|
||||
|
||||
// Check for cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if errors.Is(ctx.Err(), context.Canceled) {
|
||||
return fmt.Errorf("restore cancelled")
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("restore cancelled due to shutdown")
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ import (
|
||||
"databasus-backend/internal/features/databases"
|
||||
mysqltypes "databasus-backend/internal/features/databases/databases/mysql"
|
||||
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
|
||||
"databasus-backend/internal/features/restores/models"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/storages"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/tools"
|
||||
@@ -36,10 +36,11 @@ type RestoreMysqlBackupUsecase struct {
|
||||
}
|
||||
|
||||
func (uc *RestoreMysqlBackupUsecase) Execute(
|
||||
parentCtx context.Context,
|
||||
originalDB *databases.Database,
|
||||
restoringToDB *databases.Database,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore models.Restore,
|
||||
restore restores_core.Restore,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
) error {
|
||||
@@ -78,6 +79,7 @@ func (uc *RestoreMysqlBackupUsecase) Execute(
|
||||
}
|
||||
|
||||
return uc.restoreFromStorage(
|
||||
parentCtx,
|
||||
originalDB,
|
||||
tools.GetMysqlExecutable(
|
||||
my.Version,
|
||||
@@ -94,6 +96,7 @@ func (uc *RestoreMysqlBackupUsecase) Execute(
|
||||
}
|
||||
|
||||
func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
|
||||
parentCtx context.Context,
|
||||
database *databases.Database,
|
||||
mysqlBin string,
|
||||
args []string,
|
||||
@@ -102,7 +105,7 @@ func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
|
||||
storage *storages.Storage,
|
||||
myConfig *mysqltypes.MysqlDatabase,
|
||||
) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
@@ -112,6 +115,9 @@ func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-parentCtx.Done():
|
||||
cancel()
|
||||
return
|
||||
case <-ticker.C:
|
||||
if config.IsShouldShutdown() {
|
||||
cancel()
|
||||
@@ -204,6 +210,15 @@ func (uc *RestoreMysqlBackupUsecase) executeMysqlRestore(
|
||||
waitErr := cmd.Wait()
|
||||
stderrOutput := <-stderrCh
|
||||
|
||||
// Check for cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if errors.Is(ctx.Err(), context.Canceled) {
|
||||
return fmt.Errorf("restore cancelled")
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("restore cancelled due to shutdown")
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
"databasus-backend/internal/features/databases"
|
||||
pgtypes "databasus-backend/internal/features/databases/databases/postgresql"
|
||||
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
|
||||
"databasus-backend/internal/features/restores/models"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/storages"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/tools"
|
||||
@@ -35,10 +35,11 @@ type RestorePostgresqlBackupUsecase struct {
|
||||
}
|
||||
|
||||
func (uc *RestorePostgresqlBackupUsecase) Execute(
|
||||
parentCtx context.Context,
|
||||
originalDB *databases.Database,
|
||||
restoringToDB *databases.Database,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore models.Restore,
|
||||
restore restores_core.Restore,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
isExcludeExtensions bool,
|
||||
@@ -73,6 +74,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
|
||||
|
||||
// All PostgreSQL backups are now custom format (-Fc)
|
||||
return uc.restoreCustomType(
|
||||
parentCtx,
|
||||
originalDB,
|
||||
pgBin,
|
||||
backup,
|
||||
@@ -84,6 +86,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
|
||||
|
||||
// restoreCustomType restores a backup in custom type (-Fc)
|
||||
func (uc *RestorePostgresqlBackupUsecase) restoreCustomType(
|
||||
parentCtx context.Context,
|
||||
originalDB *databases.Database,
|
||||
pgBin string,
|
||||
backup *backups_core.Backup,
|
||||
@@ -102,15 +105,24 @@ func (uc *RestorePostgresqlBackupUsecase) restoreCustomType(
|
||||
// If excluding extensions, we must use file-based restore (requires TOC file generation)
|
||||
// Also use file-based restore for parallel jobs (multiple CPUs)
|
||||
if isExcludeExtensions || pg.CpuCount > 1 {
|
||||
return uc.restoreViaFile(originalDB, pgBin, backup, storage, pg, isExcludeExtensions)
|
||||
return uc.restoreViaFile(
|
||||
parentCtx,
|
||||
originalDB,
|
||||
pgBin,
|
||||
backup,
|
||||
storage,
|
||||
pg,
|
||||
isExcludeExtensions,
|
||||
)
|
||||
}
|
||||
|
||||
// Single CPU without extension exclusion: stream directly via stdin
|
||||
return uc.restoreViaStdin(originalDB, pgBin, backup, storage, pg)
|
||||
return uc.restoreViaStdin(parentCtx, originalDB, pgBin, backup, storage, pg)
|
||||
}
|
||||
|
||||
// restoreViaStdin streams backup via stdin for single CPU restore
|
||||
func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
|
||||
parentCtx context.Context,
|
||||
originalDB *databases.Database,
|
||||
pgBin string,
|
||||
backup *backups_core.Backup,
|
||||
@@ -133,10 +145,10 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
|
||||
"--no-acl",
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// Monitor for shutdown and cancel context if needed
|
||||
// Monitor for shutdown and parent cancellation
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
@@ -145,6 +157,9 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-parentCtx.Done():
|
||||
cancel()
|
||||
return
|
||||
case <-ticker.C:
|
||||
if config.IsShouldShutdown() {
|
||||
cancel()
|
||||
@@ -296,6 +311,15 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
|
||||
stderrOutput := <-stderrCh
|
||||
copyErr := <-copyErrCh
|
||||
|
||||
// Check for cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if errors.Is(ctx.Err(), context.Canceled) {
|
||||
return fmt.Errorf("restore cancelled")
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
// Check for shutdown before finalizing
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("restore cancelled due to shutdown")
|
||||
@@ -307,6 +331,15 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
|
||||
}
|
||||
|
||||
if waitErr != nil {
|
||||
// Check for cancellation again
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if errors.Is(ctx.Err(), context.Canceled) {
|
||||
return fmt.Errorf("restore cancelled")
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("restore cancelled due to shutdown")
|
||||
}
|
||||
@@ -319,6 +352,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
|
||||
|
||||
// restoreViaFile downloads backup and uses parallel jobs for multi-CPU restore
|
||||
func (uc *RestorePostgresqlBackupUsecase) restoreViaFile(
|
||||
parentCtx context.Context,
|
||||
originalDB *databases.Database,
|
||||
pgBin string,
|
||||
backup *backups_core.Backup,
|
||||
@@ -354,6 +388,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaFile(
|
||||
}
|
||||
|
||||
return uc.restoreFromStorage(
|
||||
parentCtx,
|
||||
originalDB,
|
||||
pgBin,
|
||||
args,
|
||||
@@ -367,6 +402,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaFile(
|
||||
|
||||
// restoreFromStorage restores backup data from storage using pg_restore
|
||||
func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
|
||||
parentCtx context.Context,
|
||||
database *databases.Database,
|
||||
pgBin string,
|
||||
args []string,
|
||||
@@ -386,10 +422,10 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
|
||||
isExcludeExtensions,
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// Monitor for shutdown and cancel context if needed
|
||||
// Monitor for shutdown and parent cancellation
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
@@ -398,6 +434,9 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-parentCtx.Done():
|
||||
cancel()
|
||||
return
|
||||
case <-ticker.C:
|
||||
if config.IsShouldShutdown() {
|
||||
cancel()
|
||||
@@ -624,12 +663,30 @@ func (uc *RestorePostgresqlBackupUsecase) executePgRestore(
|
||||
waitErr := cmd.Wait()
|
||||
stderrOutput := <-stderrCh
|
||||
|
||||
// Check for cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if errors.Is(ctx.Err(), context.Canceled) {
|
||||
return fmt.Errorf("restore cancelled")
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
// Check for shutdown before finalizing
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("restore cancelled due to shutdown")
|
||||
}
|
||||
|
||||
if waitErr != nil {
|
||||
// Check for cancellation again
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if errors.Is(ctx.Err(), context.Canceled) {
|
||||
return fmt.Errorf("restore cancelled")
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("restore cancelled due to shutdown")
|
||||
}
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
package usecases
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/restores/models"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
usecases_mariadb "databasus-backend/internal/features/restores/usecases/mariadb"
|
||||
usecases_mongodb "databasus-backend/internal/features/restores/usecases/mongodb"
|
||||
usecases_mysql "databasus-backend/internal/features/restores/usecases/mysql"
|
||||
@@ -22,8 +23,9 @@ type RestoreBackupUsecase struct {
|
||||
}
|
||||
|
||||
func (uc *RestoreBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore models.Restore,
|
||||
restore restores_core.Restore,
|
||||
originalDB *databases.Database,
|
||||
restoringToDB *databases.Database,
|
||||
backup *backups_core.Backup,
|
||||
@@ -33,6 +35,7 @@ func (uc *RestoreBackupUsecase) Execute(
|
||||
switch originalDB.Type {
|
||||
case databases.DatabaseTypePostgres:
|
||||
return uc.restorePostgresqlBackupUsecase.Execute(
|
||||
ctx,
|
||||
originalDB,
|
||||
restoringToDB,
|
||||
backupConfig,
|
||||
@@ -43,6 +46,7 @@ func (uc *RestoreBackupUsecase) Execute(
|
||||
)
|
||||
case databases.DatabaseTypeMysql:
|
||||
return uc.restoreMysqlBackupUsecase.Execute(
|
||||
ctx,
|
||||
originalDB,
|
||||
restoringToDB,
|
||||
backupConfig,
|
||||
@@ -52,6 +56,7 @@ func (uc *RestoreBackupUsecase) Execute(
|
||||
)
|
||||
case databases.DatabaseTypeMariadb:
|
||||
return uc.restoreMariadbBackupUsecase.Execute(
|
||||
ctx,
|
||||
originalDB,
|
||||
restoringToDB,
|
||||
backupConfig,
|
||||
@@ -61,6 +66,7 @@ func (uc *RestoreBackupUsecase) Execute(
|
||||
)
|
||||
case databases.DatabaseTypeMongodb:
|
||||
return uc.restoreMongodbBackupUsecase.Execute(
|
||||
ctx,
|
||||
originalDB,
|
||||
restoringToDB,
|
||||
backupConfig,
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
azure_blob_storage "databasus-backend/internal/features/storages/models/azure_blob"
|
||||
ftp_storage "databasus-backend/internal/features/storages/models/ftp"
|
||||
@@ -902,6 +903,12 @@ func Test_StorageSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Skip Google Drive tests if external resources tests are disabled
|
||||
if tc.storageType == StorageTypeGoogleDrive &&
|
||||
config.GetEnv().IsSkipExternalResourcesTests {
|
||||
t.Skip("Skipping Google Drive storage test: IS_SKIP_EXTERNAL_RESOURCES_TESTS=true")
|
||||
}
|
||||
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
package storages
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var storageRepository = &StorageRepository{}
|
||||
@@ -27,6 +31,21 @@ func GetStorageController() *StorageController {
|
||||
return storageController
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(storageService)
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(storageService)
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,7 +113,7 @@ func Test_Storage_BasicOperations(t *testing.T) {
|
||||
name: "NASStorage",
|
||||
storage: &nas_storage.NASStorage{
|
||||
StorageID: uuid.New(),
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: nasPort,
|
||||
Share: "backups",
|
||||
Username: "testuser",
|
||||
@@ -147,7 +147,7 @@ func Test_Storage_BasicOperations(t *testing.T) {
|
||||
name: "FTPStorage",
|
||||
storage: &ftp_storage.FTPStorage{
|
||||
StorageID: uuid.New(),
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: ftpPort,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
@@ -159,7 +159,7 @@ func Test_Storage_BasicOperations(t *testing.T) {
|
||||
name: "SFTPStorage",
|
||||
storage: &sftp_storage.SFTPStorage{
|
||||
StorageID: uuid.New(),
|
||||
Host: "localhost",
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: sftpPort,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
@@ -185,7 +185,9 @@ acl = private`, s3Container.accessKey, s3Container.secretKey, s3Container.endpoi
|
||||
|
||||
// Add Google Drive storage test only if environment variables are available
|
||||
env := config.GetEnv()
|
||||
if env.TestGoogleDriveClientID != "" && env.TestGoogleDriveClientSecret != "" &&
|
||||
if env.IsSkipExternalResourcesTests {
|
||||
t.Log("Skipping Google Drive storage test: IS_SKIP_EXTERNAL_RESOURCES_TESTS=true")
|
||||
} else if env.TestGoogleDriveClientID != "" && env.TestGoogleDriveClientSecret != "" &&
|
||||
env.TestGoogleDriveTokenJSON != "" {
|
||||
testCases = append(testCases, struct {
|
||||
name string
|
||||
@@ -297,7 +299,7 @@ func setupS3Container(ctx context.Context) (*S3Container, error) {
|
||||
secretKey := "testpassword"
|
||||
bucketName := "test-bucket"
|
||||
region := "us-east-1"
|
||||
endpoint := fmt.Sprintf("127.0.0.1:%s", env.TestMinioPort)
|
||||
endpoint := fmt.Sprintf("%s:%s", env.TestLocalhost, env.TestMinioPort)
|
||||
|
||||
// Create MinIO client and ensure bucket exists
|
||||
minioClient, err := minio.New(endpoint, &minio.Options{
|
||||
@@ -343,15 +345,21 @@ func setupAzuriteContainer(ctx context.Context) (*AzuriteContainer, error) {
|
||||
accountName := "devstoreaccount1"
|
||||
// this is real testing key for azurite, it's not a real key
|
||||
accountKey := "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="
|
||||
serviceURL := fmt.Sprintf("http://127.0.0.1:%s/%s", env.TestAzuriteBlobPort, accountName)
|
||||
serviceURL := fmt.Sprintf(
|
||||
"http://%s:%s/%s",
|
||||
env.TestLocalhost,
|
||||
env.TestAzuriteBlobPort,
|
||||
accountName,
|
||||
)
|
||||
containerNameKey := "test-container-key"
|
||||
containerNameStr := "test-container-connstr"
|
||||
|
||||
// Build explicit connection string for Azurite
|
||||
connectionString := fmt.Sprintf(
|
||||
"DefaultEndpointsProtocol=http;AccountName=%s;AccountKey=%s;BlobEndpoint=http://127.0.0.1:%s/%s",
|
||||
"DefaultEndpointsProtocol=http;AccountName=%s;AccountKey=%s;BlobEndpoint=http://%s:%s/%s",
|
||||
accountName,
|
||||
accountKey,
|
||||
env.TestLocalhost,
|
||||
env.TestAzuriteBlobPort,
|
||||
accountName,
|
||||
)
|
||||
|
||||
@@ -285,30 +285,30 @@ func (f *FTPStorage) ensureDirectory(conn *ftp.ServerConn, path string) error {
|
||||
}
|
||||
|
||||
parts := strings.Split(path, "/")
|
||||
currentPath := ""
|
||||
|
||||
currentDir, err := conn.CurrentDir()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get current directory: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.ChangeDir(currentDir)
|
||||
}()
|
||||
|
||||
for _, part := range parts {
|
||||
if part == "" || part == "." {
|
||||
continue
|
||||
}
|
||||
|
||||
if currentPath == "" {
|
||||
currentPath = part
|
||||
} else {
|
||||
currentPath = currentPath + "/" + part
|
||||
}
|
||||
|
||||
err := conn.ChangeDir(currentPath)
|
||||
err := conn.ChangeDir(part)
|
||||
if err != nil {
|
||||
err = conn.MakeDir(currentPath)
|
||||
err = conn.MakeDir(part)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create directory '%s': %w", currentPath, err)
|
||||
return fmt.Errorf("failed to create directory '%s': %w", part, err)
|
||||
}
|
||||
err = conn.ChangeDir(part)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to change into directory '%s': %w", part, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = conn.ChangeDirToParent()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to change to parent directory: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ func (s *HealthcheckService) IsHealthy() error {
|
||||
}
|
||||
}
|
||||
|
||||
if config.GetEnv().IsBackupNode {
|
||||
if config.GetEnv().IsProcessingNode {
|
||||
if !s.backuperNode.IsBackuperRunning() {
|
||||
return errors.New("backuper node is not running for more than 5 minutes")
|
||||
}
|
||||
|
||||
@@ -2,9 +2,11 @@ package task_cancellation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/logger"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -20,6 +22,21 @@ func GetTaskCancelManager() *TaskCancelManager {
|
||||
return taskCancelManager
|
||||
}
|
||||
|
||||
var (
|
||||
setupOnce sync.Once
|
||||
isSetup atomic.Bool
|
||||
)
|
||||
|
||||
func SetupDependencies() {
|
||||
taskCancelManager.StartSubscription()
|
||||
wasAlreadySetup := isSetup.Load()
|
||||
|
||||
setupOnce.Do(func() {
|
||||
taskCancelManager.StartSubscription()
|
||||
|
||||
isSetup.Store(true)
|
||||
})
|
||||
|
||||
if wasAlreadySetup {
|
||||
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
package task_registry
|
||||
|
||||
import (
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var taskNodesRegistry = &TaskNodesRegistry{
|
||||
cache_utils.GetValkeyClient(),
|
||||
logger.GetLogger(),
|
||||
cache_utils.DefaultCacheTimeout,
|
||||
cache_utils.NewPubSubManager(),
|
||||
cache_utils.NewPubSubManager(),
|
||||
}
|
||||
|
||||
func GetTaskNodesRegistry() *TaskNodesRegistry {
|
||||
return taskNodesRegistry
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
package task_registry
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type TaskNode struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
ThroughputMBs int `json:"throughputMBs"`
|
||||
LastHeartbeat time.Time `json:"lastHeartbeat"`
|
||||
}
|
||||
|
||||
type TaskNodeStats struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
ActiveTasks int `json:"activeTasks"`
|
||||
}
|
||||
|
||||
type TaskSubmitMessage struct {
|
||||
NodeID string `json:"nodeId"`
|
||||
TaskID string `json:"taskId"`
|
||||
IsCallNotifier bool `json:"isCallNotifier"`
|
||||
}
|
||||
|
||||
type TaskCompletionMessage struct {
|
||||
NodeID string `json:"nodeId"`
|
||||
TaskID string `json:"taskId"`
|
||||
}
|
||||
159
backend/internal/features/test_once_protection.go
Normal file
159
backend/internal/features/test_once_protection.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package features
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/restores"
|
||||
"databasus-backend/internal/features/restores/restoring"
|
||||
"databasus-backend/internal/features/storages"
|
||||
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
)
|
||||
|
||||
// Test_SetupDependencies_CalledTwice_LogsWarning verifies SetupDependencies is idempotent
|
||||
func Test_SetupDependencies_CalledTwice_LogsWarning(t *testing.T) {
|
||||
// Call each SetupDependencies twice - should not panic, only log warnings
|
||||
audit_logs.SetupDependencies()
|
||||
audit_logs.SetupDependencies()
|
||||
|
||||
backups.SetupDependencies()
|
||||
backups.SetupDependencies()
|
||||
|
||||
backups_config.SetupDependencies()
|
||||
backups_config.SetupDependencies()
|
||||
|
||||
databases.SetupDependencies()
|
||||
databases.SetupDependencies()
|
||||
|
||||
healthcheck_config.SetupDependencies()
|
||||
healthcheck_config.SetupDependencies()
|
||||
|
||||
notifiers.SetupDependencies()
|
||||
notifiers.SetupDependencies()
|
||||
|
||||
restores.SetupDependencies()
|
||||
restores.SetupDependencies()
|
||||
|
||||
storages.SetupDependencies()
|
||||
storages.SetupDependencies()
|
||||
|
||||
task_cancellation.SetupDependencies()
|
||||
task_cancellation.SetupDependencies()
|
||||
|
||||
// If we reach here without panic, test passes
|
||||
t.Log("All SetupDependencies calls completed successfully (idempotent)")
|
||||
}
|
||||
|
||||
// Test_SetupDependencies_ConcurrentCalls_Safe verifies thread safety
|
||||
func Test_SetupDependencies_ConcurrentCalls_Safe(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Call SetupDependencies concurrently from 10 goroutines
|
||||
for range 10 {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
audit_logs.SetupDependencies()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
t.Log("Concurrent SetupDependencies calls completed successfully")
|
||||
}
|
||||
|
||||
// Test_BackgroundService_Run_CalledTwice_Panics verifies Run() panics on duplicate calls
|
||||
func Test_BackgroundService_Run_CalledTwice_Panics(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Create a test background service
|
||||
backgroundService := audit_logs.GetAuditLogBackgroundService()
|
||||
|
||||
// Start first Run() in goroutine
|
||||
go func() {
|
||||
backgroundService.Run(ctx)
|
||||
}()
|
||||
|
||||
// Give first call time to initialize
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Second call should panic
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
expectedMsg := "*audit_logs.AuditLogBackgroundService.Run() called multiple times"
|
||||
panicMsg := fmt.Sprintf("%v", r)
|
||||
if panicMsg == expectedMsg {
|
||||
t.Logf("Successfully caught panic: %v", r)
|
||||
} else {
|
||||
t.Errorf("Expected panic message '%s', got '%s'", expectedMsg, panicMsg)
|
||||
}
|
||||
} else {
|
||||
t.Error("Expected panic on second Run() call, but did not panic")
|
||||
}
|
||||
}()
|
||||
|
||||
backgroundService.Run(ctx)
|
||||
}
|
||||
|
||||
// Test_BackupsScheduler_Run_CalledTwice_Panics verifies scheduler panics on duplicate calls
|
||||
func Test_BackupsScheduler_Run_CalledTwice_Panics(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
scheduler := backuping.GetBackupsScheduler()
|
||||
|
||||
// Start first Run() in goroutine
|
||||
go func() {
|
||||
scheduler.Run(ctx)
|
||||
}()
|
||||
|
||||
// Give first call time to initialize
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Second call should panic
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Logf("Successfully caught panic: %v", r)
|
||||
} else {
|
||||
t.Error("Expected panic on second Run() call, but did not panic")
|
||||
}
|
||||
}()
|
||||
|
||||
scheduler.Run(ctx)
|
||||
}
|
||||
|
||||
// Test_RestoresScheduler_Run_CalledTwice_Panics verifies restore scheduler panics on duplicate calls
|
||||
func Test_RestoresScheduler_Run_CalledTwice_Panics(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
|
||||
scheduler := restoring.GetRestoresScheduler()
|
||||
|
||||
// Start first Run() in goroutine
|
||||
go func() {
|
||||
scheduler.Run(ctx)
|
||||
}()
|
||||
|
||||
// Give first call time to initialize
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Second call should panic
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Logf("Successfully caught panic: %v", r)
|
||||
} else {
|
||||
t.Error("Expected panic on second Run() call, but did not panic")
|
||||
}
|
||||
}()
|
||||
|
||||
scheduler.Run(ctx)
|
||||
}
|
||||
@@ -21,9 +21,7 @@ import (
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
mariadbtypes "databasus-backend/internal/features/databases/databases/mariadb"
|
||||
"databasus-backend/internal/features/restores"
|
||||
restores_enums "databasus-backend/internal/features/restores/enums"
|
||||
restores_models "databasus-backend/internal/features/restores/models"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
@@ -213,7 +211,7 @@ func testMariadbBackupRestoreForVersion(
|
||||
)
|
||||
|
||||
restore := waitForMariadbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var tableExists int
|
||||
err = newDB.Get(
|
||||
@@ -311,7 +309,7 @@ func testMariadbBackupRestoreWithEncryptionForVersion(
|
||||
)
|
||||
|
||||
restore := waitForMariadbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var tableExists int
|
||||
err = newDB.Get(
|
||||
@@ -418,7 +416,7 @@ func testMariadbBackupRestoreWithReadOnlyUserForVersion(
|
||||
)
|
||||
|
||||
restore := waitForMariadbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var tableExists int
|
||||
err = newDB.Get(
|
||||
@@ -506,7 +504,7 @@ func createMariadbRestoreViaAPI(
|
||||
version tools.MariadbVersion,
|
||||
token string,
|
||||
) {
|
||||
request := restores.RestoreBackupRequest{
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
MariadbDatabase: &mariadbtypes.MariadbDatabase{
|
||||
Host: host,
|
||||
Port: port,
|
||||
@@ -533,7 +531,7 @@ func waitForMariadbRestoreCompletion(
|
||||
backupID uuid.UUID,
|
||||
token string,
|
||||
timeout time.Duration,
|
||||
) *restores_models.Restore {
|
||||
) *restores_core.Restore {
|
||||
startTime := time.Now()
|
||||
pollInterval := 500 * time.Millisecond
|
||||
|
||||
@@ -542,7 +540,7 @@ func waitForMariadbRestoreCompletion(
|
||||
t.Fatalf("Timeout waiting for MariaDB restore completion after %v", timeout)
|
||||
}
|
||||
|
||||
var restoresList []*restores_models.Restore
|
||||
var restoresList []*restores_core.Restore
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
@@ -553,10 +551,10 @@ func waitForMariadbRestoreCompletion(
|
||||
)
|
||||
|
||||
for _, restore := range restoresList {
|
||||
if restore.Status == restores_enums.RestoreStatusCompleted {
|
||||
if restore.Status == restores_core.RestoreStatusCompleted {
|
||||
return restore
|
||||
}
|
||||
if restore.Status == restores_enums.RestoreStatusFailed {
|
||||
if restore.Status == restores_core.RestoreStatusFailed {
|
||||
failMsg := "unknown error"
|
||||
if restore.FailMessage != nil {
|
||||
failMsg = *restore.FailMessage
|
||||
@@ -607,7 +605,7 @@ func connectToMariadbContainer(
|
||||
dbName := "testdb"
|
||||
password := "rootpassword"
|
||||
username := "root"
|
||||
host := "127.0.0.1"
|
||||
host := config.GetEnv().TestLocalhost
|
||||
|
||||
portInt, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
|
||||
@@ -23,9 +23,7 @@ import (
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
mongodbtypes "databasus-backend/internal/features/databases/databases/mongodb"
|
||||
"databasus-backend/internal/features/restores"
|
||||
restores_enums "databasus-backend/internal/features/restores/enums"
|
||||
restores_models "databasus-backend/internal/features/restores/models"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
@@ -175,7 +173,7 @@ func testMongodbBackupRestoreForVersion(
|
||||
)
|
||||
|
||||
restore := waitForMongodbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
verifyMongodbDataIntegrity(t, container, newDBName)
|
||||
|
||||
@@ -254,7 +252,7 @@ func testMongodbBackupRestoreWithEncryptionForVersion(
|
||||
)
|
||||
|
||||
restore := waitForMongodbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
verifyMongodbDataIntegrity(t, container, newDBName)
|
||||
|
||||
@@ -342,7 +340,7 @@ func testMongodbBackupRestoreWithReadOnlyUserForVersion(
|
||||
)
|
||||
|
||||
restore := waitForMongodbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
verifyMongodbDataIntegrity(t, container, newDBName)
|
||||
|
||||
@@ -431,7 +429,7 @@ func createMongodbRestoreViaAPI(
|
||||
version tools.MongodbVersion,
|
||||
token string,
|
||||
) {
|
||||
request := restores.RestoreBackupRequest{
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
MongodbDatabase: &mongodbtypes.MongodbDatabase{
|
||||
Host: host,
|
||||
Port: port,
|
||||
@@ -461,7 +459,7 @@ func waitForMongodbRestoreCompletion(
|
||||
backupID uuid.UUID,
|
||||
token string,
|
||||
timeout time.Duration,
|
||||
) *restores_models.Restore {
|
||||
) *restores_core.Restore {
|
||||
startTime := time.Now()
|
||||
pollInterval := 500 * time.Millisecond
|
||||
|
||||
@@ -470,7 +468,7 @@ func waitForMongodbRestoreCompletion(
|
||||
t.Fatalf("Timeout waiting for MongoDB restore completion after %v", timeout)
|
||||
}
|
||||
|
||||
var restoresList []*restores_models.Restore
|
||||
var restoresList []*restores_core.Restore
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
@@ -481,10 +479,10 @@ func waitForMongodbRestoreCompletion(
|
||||
)
|
||||
|
||||
for _, restore := range restoresList {
|
||||
if restore.Status == restores_enums.RestoreStatusCompleted {
|
||||
if restore.Status == restores_core.RestoreStatusCompleted {
|
||||
return restore
|
||||
}
|
||||
if restore.Status == restores_enums.RestoreStatusFailed {
|
||||
if restore.Status == restores_core.RestoreStatusFailed {
|
||||
failMsg := "unknown error"
|
||||
if restore.FailMessage != nil {
|
||||
failMsg = *restore.FailMessage
|
||||
@@ -552,7 +550,7 @@ func connectToMongodbContainer(
|
||||
password := "rootpassword"
|
||||
username := "root"
|
||||
authDatabase := "admin"
|
||||
host := "127.0.0.1"
|
||||
host := config.GetEnv().TestLocalhost
|
||||
|
||||
portInt, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
@@ -560,8 +558,13 @@ func connectToMongodbContainer(
|
||||
}
|
||||
|
||||
uri := fmt.Sprintf(
|
||||
"mongodb://%s:%s@%s:%d/%s?authSource=%s",
|
||||
username, password, host, portInt, dbName, authDatabase,
|
||||
"mongodb://%s:%s@%s:%d/%s?authSource=%s&serverSelectionTimeoutMS=5000&connectTimeoutMS=5000",
|
||||
username,
|
||||
password,
|
||||
host,
|
||||
portInt,
|
||||
dbName,
|
||||
authDatabase,
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
|
||||
@@ -21,9 +21,7 @@ import (
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
mysqltypes "databasus-backend/internal/features/databases/databases/mysql"
|
||||
"databasus-backend/internal/features/restores"
|
||||
restores_enums "databasus-backend/internal/features/restores/enums"
|
||||
restores_models "databasus-backend/internal/features/restores/models"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
@@ -188,7 +186,7 @@ func testMysqlBackupRestoreForVersion(t *testing.T, mysqlVersion tools.MysqlVers
|
||||
)
|
||||
|
||||
restore := waitForMysqlRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var tableExists int
|
||||
err = newDB.Get(
|
||||
@@ -286,7 +284,7 @@ func testMysqlBackupRestoreWithEncryptionForVersion(
|
||||
)
|
||||
|
||||
restore := waitForMysqlRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var tableExists int
|
||||
err = newDB.Get(
|
||||
@@ -393,7 +391,7 @@ func testMysqlBackupRestoreWithReadOnlyUserForVersion(
|
||||
)
|
||||
|
||||
restore := waitForMysqlRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var tableExists int
|
||||
err = newDB.Get(
|
||||
@@ -481,7 +479,7 @@ func createMysqlRestoreViaAPI(
|
||||
version tools.MysqlVersion,
|
||||
token string,
|
||||
) {
|
||||
request := restores.RestoreBackupRequest{
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
MysqlDatabase: &mysqltypes.MysqlDatabase{
|
||||
Host: host,
|
||||
Port: port,
|
||||
@@ -508,7 +506,7 @@ func waitForMysqlRestoreCompletion(
|
||||
backupID uuid.UUID,
|
||||
token string,
|
||||
timeout time.Duration,
|
||||
) *restores_models.Restore {
|
||||
) *restores_core.Restore {
|
||||
startTime := time.Now()
|
||||
pollInterval := 500 * time.Millisecond
|
||||
|
||||
@@ -517,7 +515,7 @@ func waitForMysqlRestoreCompletion(
|
||||
t.Fatalf("Timeout waiting for MySQL restore completion after %v", timeout)
|
||||
}
|
||||
|
||||
var restoresList []*restores_models.Restore
|
||||
var restoresList []*restores_core.Restore
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
@@ -528,10 +526,10 @@ func waitForMysqlRestoreCompletion(
|
||||
)
|
||||
|
||||
for _, restore := range restoresList {
|
||||
if restore.Status == restores_enums.RestoreStatusCompleted {
|
||||
if restore.Status == restores_core.RestoreStatusCompleted {
|
||||
return restore
|
||||
}
|
||||
if restore.Status == restores_enums.RestoreStatusFailed {
|
||||
if restore.Status == restores_core.RestoreStatusFailed {
|
||||
failMsg := "unknown error"
|
||||
if restore.FailMessage != nil {
|
||||
failMsg = *restore.FailMessage
|
||||
@@ -579,7 +577,7 @@ func connectToMysqlContainer(version tools.MysqlVersion, port string) (*MysqlCon
|
||||
dbName := "testdb"
|
||||
password := "rootpassword"
|
||||
username := "root"
|
||||
host := "127.0.0.1"
|
||||
host := config.GetEnv().TestLocalhost
|
||||
|
||||
portInt, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
|
||||
@@ -23,8 +23,7 @@ import (
|
||||
"databasus-backend/internal/features/databases"
|
||||
pgtypes "databasus-backend/internal/features/databases/databases/postgresql"
|
||||
"databasus-backend/internal/features/restores"
|
||||
restores_enums "databasus-backend/internal/features/restores/enums"
|
||||
restores_models "databasus-backend/internal/features/restores/models"
|
||||
restores_core "databasus-backend/internal/features/restores/core"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
@@ -125,6 +124,10 @@ func Test_BackupAndRestorePostgresqlWithEncryption_RestoreIsSuccessful(t *testin
|
||||
}
|
||||
|
||||
func Test_BackupAndRestoreSupabase_PublicSchemaOnly_RestoreIsSuccessful(t *testing.T) {
|
||||
if config.GetEnv().IsSkipExternalResourcesTests {
|
||||
t.Skip("Skipping Supabase test: IS_SKIP_EXTERNAL_RESOURCES_TESTS is true")
|
||||
}
|
||||
|
||||
env := config.GetEnv()
|
||||
|
||||
if env.TestSupabaseHost == "" {
|
||||
@@ -212,7 +215,7 @@ func Test_BackupAndRestoreSupabase_PublicSchemaOnly_RestoreIsSuccessful(t *testi
|
||||
)
|
||||
|
||||
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var countAfterRestore int
|
||||
err = supabaseDB.Get(
|
||||
@@ -439,7 +442,7 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string, cp
|
||||
)
|
||||
|
||||
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var tableExists bool
|
||||
err = newDB.Get(
|
||||
@@ -555,7 +558,7 @@ func testSchemaSelectionAllSchemasForVersion(t *testing.T, pgVersion string, por
|
||||
)
|
||||
|
||||
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var publicTableExists bool
|
||||
err = newDB.Get(&publicTableExists, `
|
||||
@@ -689,7 +692,7 @@ func testBackupRestoreWithExcludeExtensionsForVersion(t *testing.T, pgVersion st
|
||||
)
|
||||
|
||||
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
// Verify the table was restored
|
||||
var tableExists bool
|
||||
@@ -829,7 +832,7 @@ func testBackupRestoreWithoutExcludeExtensionsForVersion(
|
||||
)
|
||||
|
||||
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
// Verify the extension was recovered
|
||||
var extensionExists bool
|
||||
@@ -956,7 +959,7 @@ func testBackupRestoreWithReadOnlyUserForVersion(t *testing.T, pgVersion string,
|
||||
)
|
||||
|
||||
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var tableExists bool
|
||||
err = newDB.Get(
|
||||
@@ -1076,7 +1079,7 @@ func testSchemaSelectionOnlySpecifiedSchemasForVersion(
|
||||
)
|
||||
|
||||
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var publicTableExists bool
|
||||
err = newDB.Get(&publicTableExists, `
|
||||
@@ -1190,7 +1193,7 @@ func testBackupRestoreWithEncryptionForVersion(t *testing.T, pgVersion string, p
|
||||
)
|
||||
|
||||
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var tableExists bool
|
||||
err = newDB.Get(
|
||||
@@ -1286,7 +1289,7 @@ func waitForRestoreCompletion(
|
||||
backupID uuid.UUID,
|
||||
token string,
|
||||
timeout time.Duration,
|
||||
) *restores_models.Restore {
|
||||
) *restores_core.Restore {
|
||||
startTime := time.Now()
|
||||
pollInterval := 500 * time.Millisecond
|
||||
|
||||
@@ -1295,7 +1298,7 @@ func waitForRestoreCompletion(
|
||||
t.Fatalf("Timeout waiting for restore completion after %v", timeout)
|
||||
}
|
||||
|
||||
var restores []*restores_models.Restore
|
||||
var restores []*restores_core.Restore
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
@@ -1306,10 +1309,10 @@ func waitForRestoreCompletion(
|
||||
)
|
||||
|
||||
for _, restore := range restores {
|
||||
if restore.Status == restores_enums.RestoreStatusCompleted {
|
||||
if restore.Status == restores_core.RestoreStatusCompleted {
|
||||
return restore
|
||||
}
|
||||
if restore.Status == restores_enums.RestoreStatusFailed {
|
||||
if restore.Status == restores_core.RestoreStatusFailed {
|
||||
failMsg := "unknown error"
|
||||
if restore.FailMessage != nil {
|
||||
failMsg = *restore.FailMessage
|
||||
@@ -1476,7 +1479,7 @@ func createRestoreWithCpuCountViaAPI(
|
||||
cpuCount int,
|
||||
token string,
|
||||
) {
|
||||
request := restores.RestoreBackupRequest{
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
PostgresqlDatabase: &pgtypes.PostgresqlDatabase{
|
||||
Host: host,
|
||||
Port: port,
|
||||
@@ -1509,7 +1512,7 @@ func createRestoreWithOptionsViaAPI(
|
||||
isExcludeExtensions bool,
|
||||
token string,
|
||||
) {
|
||||
request := restores.RestoreBackupRequest{
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
PostgresqlDatabase: &pgtypes.PostgresqlDatabase{
|
||||
Host: host,
|
||||
Port: port,
|
||||
@@ -1647,7 +1650,7 @@ func createSupabaseRestoreViaAPI(
|
||||
database string,
|
||||
token string,
|
||||
) {
|
||||
request := restores.RestoreBackupRequest{
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
PostgresqlDatabase: &pgtypes.PostgresqlDatabase{
|
||||
Host: host,
|
||||
Port: port,
|
||||
@@ -1755,7 +1758,7 @@ func connectToPostgresContainer(version string, port string) (*PostgresContainer
|
||||
dbName := "testdb"
|
||||
password := "testpassword"
|
||||
username := "testuser"
|
||||
host := "localhost"
|
||||
host := config.GetEnv().TestLocalhost
|
||||
|
||||
portInt, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
"databasus-backend/internal/features/restores/restoring"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
)
|
||||
|
||||
@@ -12,11 +13,15 @@ func TestMain(m *testing.M) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
backuperNode := backuping.CreateTestBackuperNode()
|
||||
cancel := backuping.StartBackuperNodeForTest(&testing.T{}, backuperNode)
|
||||
cancelBackup := backuping.StartBackuperNodeForTest(&testing.T{}, backuperNode)
|
||||
|
||||
restorerNode := restoring.CreateTestRestorerNode()
|
||||
cancelRestore := restoring.StartRestorerNodeForTest(&testing.T{}, restorerNode)
|
||||
|
||||
exitCode := m.Run()
|
||||
|
||||
backuping.StopBackuperNodeForTest(&testing.T{}, cancel, backuperNode)
|
||||
backuping.StopBackuperNodeForTest(&testing.T{}, cancelBackup, backuperNode)
|
||||
restoring.StopRestorerNodeForTest(&testing.T{}, cancelRestore, restorerNode)
|
||||
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
42
backend/internal/util/cache/cache_test.go
vendored
42
backend/internal/util/cache/cache_test.go
vendored
@@ -1,7 +1,9 @@
|
||||
package cache_utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
@@ -49,3 +51,43 @@ func Test_ClearAllCache_AfterClear_CacheIsEmpty(t *testing.T) {
|
||||
assert.Nil(t, retrieved, "Key %s should be deleted after clearing", tk.prefix+tk.key)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SetWithExpiration_SetsCorrectTTL(t *testing.T) {
|
||||
client := getCache()
|
||||
|
||||
// Create a cache utility
|
||||
testPrefix := "test:ttl:"
|
||||
cacheUtil := NewCacheUtil[string](client, testPrefix)
|
||||
|
||||
// Set a value with 1-hour expiration
|
||||
testKey := "key1"
|
||||
testValue := "test value"
|
||||
oneHour := 1 * time.Hour
|
||||
|
||||
cacheUtil.SetWithExpiration(testKey, &testValue, oneHour)
|
||||
|
||||
// Verify the value was set
|
||||
retrieved := cacheUtil.Get(testKey)
|
||||
assert.NotNil(t, retrieved, "Value should be stored")
|
||||
assert.Equal(t, testValue, *retrieved, "Retrieved value should match")
|
||||
|
||||
// Check the TTL using Valkey TTL command
|
||||
ctx, cancel := context.WithTimeout(context.Background(), DefaultCacheTimeout)
|
||||
defer cancel()
|
||||
|
||||
fullKey := testPrefix + testKey
|
||||
ttlResult := client.Do(ctx, client.B().Ttl().Key(fullKey).Build())
|
||||
assert.NoError(t, ttlResult.Error(), "TTL command should not error")
|
||||
|
||||
ttlSeconds, err := ttlResult.AsInt64()
|
||||
assert.NoError(t, err, "TTL should be retrievable as int64")
|
||||
|
||||
// TTL should be approximately 1 hour (3600 seconds)
|
||||
// Allow for a small margin (within 10 seconds of 3600)
|
||||
expectedTTL := int64(3600)
|
||||
assert.GreaterOrEqual(t, ttlSeconds, expectedTTL-10, "TTL should be close to 1 hour")
|
||||
assert.LessOrEqual(t, ttlSeconds, expectedTTL, "TTL should not exceed 1 hour")
|
||||
|
||||
// Clean up
|
||||
cacheUtil.Invalidate(testKey)
|
||||
}
|
||||
|
||||
37
backend/internal/util/cache/utils.go
vendored
37
backend/internal/util/cache/utils.go
vendored
@@ -67,6 +67,43 @@ func (c *CacheUtil[T]) Set(key string, item *T) {
|
||||
c.client.Do(ctx, c.client.B().Set().Key(fullKey).Value(string(data)).Ex(c.expiry).Build())
|
||||
}
|
||||
|
||||
func (c *CacheUtil[T]) SetWithExpiration(key string, item *T, expiry time.Duration) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
|
||||
defer cancel()
|
||||
|
||||
data, err := json.Marshal(item)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
fullKey := c.prefix + key
|
||||
c.client.Do(ctx, c.client.B().Set().Key(fullKey).Value(string(data)).Ex(expiry).Build())
|
||||
}
|
||||
|
||||
func (c *CacheUtil[T]) GetAndDelete(key string) *T {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
|
||||
defer cancel()
|
||||
|
||||
fullKey := c.prefix + key
|
||||
result := c.client.Do(ctx, c.client.B().Getdel().Key(fullKey).Build())
|
||||
|
||||
if result.Error() != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := result.AsBytes()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var item T
|
||||
if err := json.Unmarshal(data, &item); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &item
|
||||
}
|
||||
|
||||
func (c *CacheUtil[T]) Invalidate(key string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
|
||||
defer cancel()
|
||||
|
||||
@@ -46,4 +46,8 @@ export const restoreApi = {
|
||||
requestOptions,
|
||||
);
|
||||
},
|
||||
|
||||
/**
 * Sends a POST to the restore-cancel endpoint for the given restore.
 *
 * @param restoreId - identifier of the restore to cancel
 * @returns whatever `apiHelper.fetchPostRaw` resolves to for this request
 */
async cancelRestore(restoreId: string) {
  const cancelUrl = `${getApplicationServer()}/api/v1/restores/cancel/${restoreId}`;

  return apiHelper.fetchPostRaw(cancelUrl);
},
|
||||
};
|
||||
|
||||
@@ -2,4 +2,5 @@ export enum RestoreStatus {
|
||||
IN_PROGRESS = 'IN_PROGRESS',
|
||||
COMPLETED = 'COMPLETED',
|
||||
FAILED = 'FAILED',
|
||||
CANCELED = 'CANCELED',
|
||||
}
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
import { CopyOutlined, ExclamationCircleOutlined, SyncOutlined } from '@ant-design/icons';
|
||||
import { CheckCircleOutlined } from '@ant-design/icons';
|
||||
import {
|
||||
CheckCircleOutlined,
|
||||
CloseCircleOutlined,
|
||||
CopyOutlined,
|
||||
ExclamationCircleOutlined,
|
||||
SyncOutlined,
|
||||
} from '@ant-design/icons';
|
||||
import { App, Button, Modal, Spin, Tooltip } from 'antd';
|
||||
import dayjs from 'dayjs';
|
||||
import { useEffect, useRef, useState } from 'react';
|
||||
@@ -8,6 +13,7 @@ import type { Backup } from '../../../entity/backups';
|
||||
import { type Database, DatabaseType } from '../../../entity/databases';
|
||||
import { type Restore, RestoreStatus, restoreApi } from '../../../entity/restores';
|
||||
import { getUserTimeFormat } from '../../../shared/time';
|
||||
import { ConfirmationComponent } from '../../../shared/ui';
|
||||
import { EditDatabaseSpecificDataComponent } from '../../databases/ui/edit/EditDatabaseSpecificDataComponent';
|
||||
|
||||
interface Props {
|
||||
@@ -70,6 +76,10 @@ export const RestoresComponent = ({ database, backup }: Props) => {
|
||||
|
||||
const [isShowRestore, setIsShowRestore] = useState(false);
|
||||
|
||||
const [cancellingRestoreId, setCancellingRestoreId] = useState<string | undefined>();
|
||||
const [showCancelConfirmation, setShowCancelConfirmation] = useState(false);
|
||||
const [restoreToCancelId, setRestoreToCancelId] = useState<string | undefined>();
|
||||
|
||||
const isReloadInProgress = useRef(false);
|
||||
|
||||
const loadRestores = async () => {
|
||||
@@ -103,6 +113,18 @@ export const RestoresComponent = ({ database, backup }: Props) => {
|
||||
}
|
||||
};
|
||||
|
||||
/**
 * Requests cancellation of a restore and then reloads the restores list.
 *
 * While the request is in flight the restore id is kept in `cancellingRestoreId`;
 * it is always cleared afterwards, whether the request succeeded or failed.
 *
 * @param restoreId - identifier of the restore to cancel
 */
const cancelRestore = async (restoreId: string) => {
  setCancellingRestoreId(restoreId);
  try {
    await restoreApi.cancelRestore(restoreId);
    await loadRestores();
  } catch (e: unknown) {
    // A rejection is not guaranteed to be an Error instance; narrow before reading
    // `.message` so the user never sees "undefined" as the failure reason.
    alert(e instanceof Error ? e.message : String(e));
  } finally {
    setCancellingRestoreId(undefined);
  }
};
|
||||
|
||||
useEffect(() => {
|
||||
setIsLoading(true);
|
||||
loadRestores().finally(() => setIsLoading(false));
|
||||
@@ -190,40 +212,77 @@ export const RestoresComponent = ({ database, backup }: Props) => {
|
||||
|
||||
return (
|
||||
<div key={restore.id} className="mb-1 rounded border border-gray-200 p-3 text-sm">
|
||||
<div className="mb-1 flex">
|
||||
<div className="w-[75px] min-w-[75px]">Status</div>
|
||||
<div className="mb-1 flex items-center justify-between">
|
||||
<div className="flex flex-1">
|
||||
<div className="w-[75px] min-w-[75px]">Status</div>
|
||||
|
||||
{restore.status === RestoreStatus.FAILED && (
|
||||
<Tooltip title="Click to see error details">
|
||||
<div
|
||||
className="flex cursor-pointer items-center text-red-600 underline"
|
||||
onClick={() => setShowingRestoreError(restore)}
|
||||
>
|
||||
<ExclamationCircleOutlined
|
||||
{restore.status === RestoreStatus.FAILED && (
|
||||
<Tooltip title="Click to see error details">
|
||||
<div
|
||||
className="flex cursor-pointer items-center text-red-600 underline"
|
||||
onClick={() => setShowingRestoreError(restore)}
|
||||
>
|
||||
<ExclamationCircleOutlined
|
||||
className="mr-2"
|
||||
style={{ fontSize: 16, color: '#ff0000' }}
|
||||
/>
|
||||
|
||||
<div>Failed</div>
|
||||
</div>
|
||||
</Tooltip>
|
||||
)}
|
||||
|
||||
{restore.status === RestoreStatus.COMPLETED && (
|
||||
<div className="flex items-center">
|
||||
<CheckCircleOutlined
|
||||
className="mr-2"
|
||||
style={{ fontSize: 16, color: '#ff0000' }}
|
||||
style={{ fontSize: 16, color: '#008000' }}
|
||||
/>
|
||||
|
||||
<div>Failed</div>
|
||||
<div>Successful</div>
|
||||
</div>
|
||||
</Tooltip>
|
||||
)}
|
||||
)}
|
||||
|
||||
{restore.status === RestoreStatus.COMPLETED && (
|
||||
<div className="flex items-center">
|
||||
<CheckCircleOutlined
|
||||
className="mr-2"
|
||||
style={{ fontSize: 16, color: '#008000' }}
|
||||
/>
|
||||
{restore.status === RestoreStatus.CANCELED && (
|
||||
<div className="flex items-center text-gray-500">
|
||||
<CloseCircleOutlined
|
||||
className="mr-2"
|
||||
style={{ fontSize: 16, color: '#808080' }}
|
||||
/>
|
||||
|
||||
<div>Successful</div>
|
||||
</div>
|
||||
)}
|
||||
<div>Canceled</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{restore.status === RestoreStatus.IN_PROGRESS && (
|
||||
<div className="flex items-center font-bold text-blue-600">
|
||||
<SyncOutlined spin />
|
||||
<span className="ml-2">In progress</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{restore.status === RestoreStatus.IN_PROGRESS && (
|
||||
<div className="flex items-center font-bold text-blue-600">
|
||||
<SyncOutlined spin />
|
||||
<span className="ml-2">In progress</span>
|
||||
<div className="ml-2">
|
||||
{cancellingRestoreId === restore.id ? (
|
||||
<SyncOutlined spin style={{ fontSize: 16 }} />
|
||||
) : (
|
||||
<Tooltip title="Cancel restore">
|
||||
<CloseCircleOutlined
|
||||
className="cursor-pointer"
|
||||
onClick={() => {
|
||||
if (cancellingRestoreId) return;
|
||||
setRestoreToCancelId(restore.id);
|
||||
setShowCancelConfirmation(true);
|
||||
}}
|
||||
style={{
|
||||
color: '#ff0000',
|
||||
fontSize: 16,
|
||||
opacity: cancellingRestoreId ? 0.2 : 1,
|
||||
}}
|
||||
/>
|
||||
</Tooltip>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
@@ -289,6 +348,25 @@ export const RestoresComponent = ({ database, backup }: Props) => {
|
||||
</div>
|
||||
</Modal>
|
||||
)}
|
||||
|
||||
{showCancelConfirmation && (
|
||||
<ConfirmationComponent
|
||||
onConfirm={() => {
|
||||
setShowCancelConfirmation(false);
|
||||
if (restoreToCancelId) {
|
||||
cancelRestore(restoreToCancelId);
|
||||
}
|
||||
setRestoreToCancelId(undefined);
|
||||
}}
|
||||
onDecline={() => {
|
||||
setShowCancelConfirmation(false);
|
||||
setRestoreToCancelId(undefined);
|
||||
}}
|
||||
description="<strong>⚠️ Warning:</strong> Cancelling this restore will likely leave your database in a corrupted or incomplete state. You will need to recreate the database before attempting another restore.<br/><br/>Are you sure you want to cancel?"
|
||||
actionText="Yes, cancel restore"
|
||||
actionButtonColor="red"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user