mirror of
https://github.com/databasus/databasus.git
synced 2026-04-06 00:32:03 +02:00
Compare commits
36 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4602dc3f88 | ||
|
|
cbbfc5ea8f | ||
|
|
dd1072e230 | ||
|
|
a495e5317a | ||
|
|
7eed647038 | ||
|
|
6973241e25 | ||
|
|
ab181f5b81 | ||
|
|
b60a0cc170 | ||
|
|
f319a497b3 | ||
|
|
bc870b3f8e | ||
|
|
15383c59eb | ||
|
|
d14c223a65 | ||
|
|
2c0a294027 | ||
|
|
5d851d73bd | ||
|
|
699913c251 | ||
|
|
a2e3f30a6d | ||
|
|
80f1174ecd | ||
|
|
a47f8d5e2c | ||
|
|
54b9e67656 | ||
|
|
3782846872 | ||
|
|
245a81897f | ||
|
|
5cbc0773b6 | ||
|
|
997fc01442 | ||
|
|
6d0ae32d0c | ||
|
|
011985d723 | ||
|
|
d677ee61de | ||
|
|
c6b8f6e87a | ||
|
|
2bb5f93d00 | ||
|
|
b91c150300 | ||
|
|
12b119ce40 | ||
|
|
7c6f0ab4ba | ||
|
|
6d2db4b298 | ||
|
|
6397423298 | ||
|
|
3470aae8e3 | ||
|
|
184fbcdb2c | ||
|
|
2d897dd722 |
15
.github/workflows/ci-release.yml
vendored
15
.github/workflows/ci-release.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.24.4"
|
||||
go-version: "1.24.9"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
@@ -134,7 +134,7 @@ jobs:
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.24.4"
|
||||
go-version: "1.24.9"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
@@ -221,6 +221,12 @@ jobs:
|
||||
TEST_MONGODB_60_PORT=27060
|
||||
TEST_MONGODB_70_PORT=27070
|
||||
TEST_MONGODB_82_PORT=27082
|
||||
# Valkey (cache)
|
||||
VALKEY_HOST=localhost
|
||||
VALKEY_PORT=6379
|
||||
VALKEY_USERNAME=
|
||||
VALKEY_PASSWORD=
|
||||
VALKEY_IS_SSL=false
|
||||
EOF
|
||||
|
||||
- name: Start test containers
|
||||
@@ -233,6 +239,11 @@ jobs:
|
||||
# Wait for main dev database
|
||||
timeout 60 bash -c 'until docker exec dev-db pg_isready -h localhost -p 5437 -U postgres; do sleep 2; done'
|
||||
|
||||
# Wait for Valkey (cache)
|
||||
echo "Waiting for Valkey..."
|
||||
timeout 60 bash -c 'until docker exec dev-valkey valkey-cli ping 2>/dev/null | grep -q PONG; do sleep 2; done'
|
||||
echo "Valkey is ready!"
|
||||
|
||||
# Wait for test databases
|
||||
timeout 60 bash -c 'until nc -z localhost 5000; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5001; do sleep 2; done'
|
||||
|
||||
@@ -27,3 +27,10 @@ repos:
|
||||
language: system
|
||||
files: ^backend/.*\.go$
|
||||
pass_filenames: false
|
||||
|
||||
- id: backend-go-mod-tidy
|
||||
name: Backend Go Mod Tidy
|
||||
entry: bash -c "cd backend && go mod tidy"
|
||||
language: system
|
||||
files: ^backend/.*\.go$
|
||||
pass_filenames: false
|
||||
|
||||
38
Dockerfile
38
Dockerfile
@@ -22,7 +22,7 @@ RUN npm run build
|
||||
|
||||
# ========= BUILD BACKEND =========
|
||||
# Backend build stage
|
||||
FROM --platform=$BUILDPLATFORM golang:1.24.4 AS backend-build
|
||||
FROM --platform=$BUILDPLATFORM golang:1.24.9 AS backend-build
|
||||
|
||||
# Make TARGET args available early so tools built here match the final image arch
|
||||
ARG TARGETOS
|
||||
@@ -123,6 +123,15 @@ RUN wget -qO- https://www.postgresql.org/media/keys/ACCC4CF8.asc | apt-key add -
|
||||
apt-get install -y --no-install-recommends postgresql-17 && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Valkey server from debian repository
|
||||
# Valkey is only accessible internally (localhost) - not exposed outside container
|
||||
RUN wget -O /usr/share/keyrings/greensec.github.io-valkey-debian.key https://greensec.github.io/valkey-debian/public.key && \
|
||||
echo "deb [signed-by=/usr/share/keyrings/greensec.github.io-valkey-debian.key] https://greensec.github.io/valkey-debian/repo $(lsb_release -cs) main" \
|
||||
> /etc/apt/sources.list.d/valkey-debian.list && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends valkey && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# ========= Install rclone =========
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends rclone && \
|
||||
@@ -245,7 +254,34 @@ PG_BIN="/usr/lib/postgresql/17/bin"
|
||||
# Ensure proper ownership of data directory
|
||||
echo "Setting up data directory permissions..."
|
||||
mkdir -p /databasus-data/pgdata
|
||||
mkdir -p /databasus-data/temp
|
||||
mkdir -p /databasus-data/backups
|
||||
chown -R postgres:postgres /databasus-data
|
||||
chmod 700 /databasus-data/temp
|
||||
|
||||
# ========= Start Valkey (internal cache) =========
|
||||
echo "Configuring Valkey cache..."
|
||||
cat > /tmp/valkey.conf << 'VALKEY_CONFIG'
|
||||
port 6379
|
||||
bind 127.0.0.1
|
||||
protected-mode yes
|
||||
save ""
|
||||
maxmemory 256mb
|
||||
maxmemory-policy allkeys-lru
|
||||
VALKEY_CONFIG
|
||||
|
||||
echo "Starting Valkey..."
|
||||
valkey-server /tmp/valkey.conf &
|
||||
VALKEY_PID=\$!
|
||||
|
||||
echo "Waiting for Valkey to be ready..."
|
||||
for i in {1..30}; do
|
||||
if valkey-cli ping >/dev/null 2>&1; then
|
||||
echo "Valkey is ready!"
|
||||
break
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
# Initialize PostgreSQL if not already initialized
|
||||
if [ ! -s "/databasus-data/pgdata/PG_VERSION" ]; then
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
<img src="assets/logo.svg" alt="Databasus Logo" width="250"/>
|
||||
|
||||
<h3>Backup tool for PostgreSQL, MySQL and MongoDB</h3>
|
||||
<p>Databasus is a free, open source and self-hosted tool to backup databases. Make backups with different storages (S3, Google Drive, FTP, etc.) and notifications about progress (Slack, Discord, Telegram, etc.). Previously known as Postgresus (see migration guide).</p>
|
||||
<p>Databasus is a free, open source and self-hosted tool to backup databases (with focus on PostgreSQL). Make backups with different storages (S3, Google Drive, FTP, etc.) and notifications about progress (Slack, Discord, Telegram, etc.). Previously known as Postgresus (see migration guide).</p>
|
||||
|
||||
<!-- Badges -->
|
||||
[](https://www.postgresql.org/)
|
||||
|
||||
@@ -11,6 +11,12 @@ DATABASE_URL=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable
|
||||
GOOSE_DRIVER=postgres
|
||||
GOOSE_DBSTRING=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable
|
||||
GOOSE_MIGRATION_DIR=./migrations
|
||||
# valkey
|
||||
VALKEY_HOST=127.0.0.1
|
||||
VALKEY_PORT=6379
|
||||
VALKEY_USERNAME=
|
||||
VALKEY_PASSWORD=
|
||||
VALKEY_IS_SSL=false
|
||||
# testing
|
||||
# to get Google Drive env variables: add storage in UI and copy data from added storage here
|
||||
TEST_GOOGLE_DRIVE_CLIENT_ID=
|
||||
|
||||
@@ -10,4 +10,10 @@ DATABASE_URL=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disab
|
||||
# migrations
|
||||
GOOSE_DRIVER=postgres
|
||||
GOOSE_DBSTRING=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disable
|
||||
GOOSE_MIGRATION_DIR=./migrations
|
||||
GOOSE_MIGRATION_DIR=./migrations
|
||||
# valkey
|
||||
VALKEY_HOST=127.0.0.1
|
||||
VALKEY_PORT=6379
|
||||
VALKEY_USERNAME=
|
||||
VALKEY_PASSWORD=
|
||||
VALKEY_IS_SSL=false
|
||||
3
backend/.gitignore
vendored
3
backend/.gitignore
vendored
@@ -17,4 +17,5 @@ ui/build/*
|
||||
pgdata-for-restore/
|
||||
temp/
|
||||
cmd.exe
|
||||
temp/
|
||||
temp/
|
||||
valkey-data/
|
||||
@@ -2,10 +2,10 @@ run:
|
||||
go run cmd/main.go
|
||||
|
||||
test:
|
||||
go test -p=1 -count=1 -failfast -timeout 10m ./internal/...
|
||||
go test -p=1 -count=1 -failfast -timeout 15m ./internal/...
|
||||
|
||||
lint:
|
||||
golangci-lint fmt && golangci-lint run
|
||||
golangci-lint fmt ./cmd/... ./internal/... && golangci-lint run ./cmd/... ./internal/...
|
||||
|
||||
migration-create:
|
||||
goose create $(name) sql
|
||||
|
||||
@@ -15,6 +15,9 @@ import (
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/disk"
|
||||
@@ -29,6 +32,7 @@ import (
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
env_utils "databasus-backend/internal/util/env"
|
||||
files_utils "databasus-backend/internal/util/files"
|
||||
"databasus-backend/internal/util/logger"
|
||||
@@ -52,7 +56,21 @@ import (
|
||||
func main() {
|
||||
log := logger.GetLogger()
|
||||
|
||||
runMigrations(log)
|
||||
cache_utils.TestCacheConnection()
|
||||
|
||||
if config.GetEnv().IsPrimaryNode {
|
||||
err := cache_utils.ClearAllCache()
|
||||
if err != nil {
|
||||
log.Error("Failed to clear cache", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
if config.GetEnv().IsPrimaryNode {
|
||||
runMigrations(log)
|
||||
} else {
|
||||
log.Info("Skipping migrations (IS_PRIMARY_NODE is false)")
|
||||
}
|
||||
|
||||
// create directories that used for backups and restore
|
||||
err := files_utils.EnsureDirectories([]string{
|
||||
@@ -96,7 +114,9 @@ func main() {
|
||||
enableCors(ginApp)
|
||||
setUpRoutes(ginApp)
|
||||
setUpDependencies()
|
||||
|
||||
runBackgroundTasks(log)
|
||||
|
||||
mountFrontend(ginApp)
|
||||
|
||||
startServerWithGracefulShutdown(log, ginApp)
|
||||
@@ -219,35 +239,64 @@ func setUpDependencies() {
|
||||
notifiers.SetupDependencies()
|
||||
storages.SetupDependencies()
|
||||
backups_config.SetupDependencies()
|
||||
backups_cancellation.SetupDependencies()
|
||||
}
|
||||
|
||||
func runBackgroundTasks(log *slog.Logger) {
|
||||
log.Info("Preparing to run background tasks...")
|
||||
|
||||
err := files_utils.CleanFolder(config.GetEnv().TempFolder)
|
||||
if err != nil {
|
||||
log.Error("Failed to clean temp folder", "error", err)
|
||||
// Create context that will be cancelled on shutdown
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Set up signal handling for graceful shutdown
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-quit
|
||||
log.Info("Shutdown signal received, cancelling all background tasks")
|
||||
cancel()
|
||||
}()
|
||||
|
||||
if config.GetEnv().IsPrimaryNode {
|
||||
log.Info("Starting primary node background tasks...")
|
||||
|
||||
err := files_utils.CleanFolder(config.GetEnv().TempFolder)
|
||||
if err != nil {
|
||||
log.Error("Failed to clean temp folder", "error", err)
|
||||
}
|
||||
|
||||
go runWithPanicLogging(log, "backup background service", func() {
|
||||
backuping.GetBackupsScheduler().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "restore background service", func() {
|
||||
restores.GetRestoreBackgroundService().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "healthcheck attempt background service", func() {
|
||||
healthcheck_attempt.GetHealthcheckAttemptBackgroundService().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "audit log cleanup background service", func() {
|
||||
audit_logs.GetAuditLogBackgroundService().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "download token cleanup background service", func() {
|
||||
backups_download.GetDownloadTokenBackgroundService().Run(ctx)
|
||||
})
|
||||
} else {
|
||||
log.Info("Skipping primary node tasks as not primary node")
|
||||
}
|
||||
|
||||
go runWithPanicLogging(log, "backup background service", func() {
|
||||
backups.GetBackupBackgroundService().Run()
|
||||
})
|
||||
if config.GetEnv().IsBackupNode {
|
||||
log.Info("Starting backup node background tasks...")
|
||||
|
||||
go runWithPanicLogging(log, "restore background service", func() {
|
||||
restores.GetRestoreBackgroundService().Run()
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "healthcheck attempt background service", func() {
|
||||
healthcheck_attempt.GetHealthcheckAttemptBackgroundService().Run()
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "audit log cleanup background service", func() {
|
||||
audit_logs.GetAuditLogBackgroundService().Run()
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "download token cleanup background service", func() {
|
||||
backups.GetDownloadTokenBackgroundService().Run()
|
||||
})
|
||||
go runWithPanicLogging(log, "backup node", func() {
|
||||
backuping.GetBackuperNode().Run(ctx)
|
||||
})
|
||||
} else {
|
||||
log.Info("Skipping backup node tasks as not backup node")
|
||||
}
|
||||
}
|
||||
|
||||
func runWithPanicLogging(log *slog.Logger, serviceName string, fn func()) {
|
||||
@@ -290,16 +339,13 @@ func generateSwaggerDocs(log *slog.Logger) {
|
||||
func runMigrations(log *slog.Logger) {
|
||||
log.Info("Running database migrations...")
|
||||
|
||||
cmd := exec.Command("goose", "up")
|
||||
cmd := exec.Command("goose", "-dir", "./migrations", "up")
|
||||
cmd.Env = append(
|
||||
os.Environ(),
|
||||
"GOOSE_DRIVER=postgres",
|
||||
"GOOSE_DBSTRING="+config.GetEnv().DatabaseDsn,
|
||||
)
|
||||
|
||||
// Set the working directory to where migrations are located
|
||||
cmd.Dir = "./migrations"
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
log.Error("Failed to run migrations", "error", err, "output", string(output))
|
||||
|
||||
@@ -19,6 +19,21 @@ services:
|
||||
command: -p 5437
|
||||
shm_size: 10gb
|
||||
|
||||
# Valkey for caching
|
||||
dev-valkey:
|
||||
image: valkey/valkey:9.0.1-alpine
|
||||
ports:
|
||||
- "${VALKEY_PORT:-6379}:6379"
|
||||
volumes:
|
||||
- ./valkey-data:/data
|
||||
container_name: dev-valkey
|
||||
healthcheck:
|
||||
test: ["CMD", "valkey-cli", "ping"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
start_period: 20s
|
||||
|
||||
# Test MinIO container
|
||||
test-minio:
|
||||
image: minio/minio:latest
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
module databasus-backend
|
||||
|
||||
go 1.24.4
|
||||
go 1.24.9
|
||||
|
||||
require (
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0
|
||||
@@ -25,9 +25,9 @@ require (
|
||||
github.com/swaggo/files v1.0.1
|
||||
github.com/swaggo/gin-swagger v1.6.0
|
||||
github.com/swaggo/swag v1.16.4
|
||||
github.com/valkey-io/valkey-go v1.0.70
|
||||
go.mongodb.org/mongo-driver v1.17.6
|
||||
golang.org/x/crypto v0.46.0
|
||||
golang.org/x/time v0.14.0
|
||||
gorm.io/driver/postgres v1.5.11
|
||||
gorm.io/gorm v1.26.1
|
||||
)
|
||||
@@ -185,6 +185,7 @@ require (
|
||||
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
golang.org/x/term v0.38.0 // indirect
|
||||
golang.org/x/time v0.14.0 // indirect
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
|
||||
gopkg.in/validator.v2 v2.0.1 // indirect
|
||||
moul.io/http2curl/v2 v2.3.0 // indirect
|
||||
@@ -269,7 +270,7 @@ require (
|
||||
go.opentelemetry.io/otel/metric v1.38.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.38.0 // indirect
|
||||
golang.org/x/arch v0.17.0 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/net v0.48.0 // indirect
|
||||
golang.org/x/oauth2 v0.33.0
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
golang.org/x/sys v0.39.0 // indirect
|
||||
|
||||
@@ -539,8 +539,8 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
||||
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
|
||||
github.com/onsi/ginkgo/v2 v2.17.3 h1:oJcvKpIb7/8uLpDDtnQuf18xVnwKp8DTD7DQ6gTd/MU=
|
||||
github.com/onsi/ginkgo/v2 v2.17.3/go.mod h1:nP2DPOQoNsQmsVyv5rDA8JkXQoCs6goXIvr/PRJ1eCc=
|
||||
github.com/onsi/gomega v1.37.0 h1:CdEG8g0S133B4OswTDC/5XPSzE1OeP29QOioj2PID2Y=
|
||||
github.com/onsi/gomega v1.37.0/go.mod h1:8D9+Txp43QWKhM24yyOBEdpkzN8FvJyAwecBgsU4KU0=
|
||||
github.com/onsi/gomega v1.38.3 h1:eTX+W6dobAYfFeGC2PV6RwXRu/MyT+cQguijutvkpSM=
|
||||
github.com/onsi/gomega v1.38.3/go.mod h1:ZCU1pkQcXDO5Sl9/VVEGlDyp+zm0m1cmeG5TOzLgdh4=
|
||||
github.com/oracle/oci-go-sdk/v65 v65.104.0 h1:l9awEvzWvxmYhy/97A0hZ87pa7BncYXmcO/S8+rvgK0=
|
||||
github.com/oracle/oci-go-sdk/v65 v65.104.0/go.mod h1:oB8jFGVc/7/zJ+DbleE8MzGHjhs2ioCz5stRTdZdIcY=
|
||||
github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg=
|
||||
@@ -660,6 +660,8 @@ github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY=
|
||||
github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
|
||||
github.com/unknwon/goconfig v1.0.0 h1:rS7O+CmUdli1T+oDm7fYj1MwqNWtEJfNj+FqcUHML8U=
|
||||
github.com/unknwon/goconfig v1.0.0/go.mod h1:qu2ZQ/wcC/if2u32263HTVC39PeOQRSmidQk3DuDFQ8=
|
||||
github.com/valkey-io/valkey-go v1.0.70 h1:mjYNT8qiazxDAJ0QNQ8twWT/YFOkOoRd40ERV2mB49Y=
|
||||
github.com/valkey-io/valkey-go v1.0.70/go.mod h1:VGhZ6fs68Qrn2+OhH+6waZH27bjpgQOiLyUQyXuYK5k=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||
github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM=
|
||||
@@ -720,6 +722,8 @@ go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
|
||||
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
|
||||
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
|
||||
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/arch v0.17.0 h1:4O3dfLzd+lQewptAHqjewQZQDyEdejz3VwgeYwkZneU=
|
||||
golang.org/x/arch v0.17.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
@@ -818,8 +822,8 @@ golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
|
||||
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/ilyakaznacheev/cleanenv"
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
@@ -29,6 +30,12 @@ type EnvVariables struct {
|
||||
MariadbInstallDir string `env:"MARIADB_INSTALL_DIR"`
|
||||
MongodbInstallDir string `env:"MONGODB_INSTALL_DIR"`
|
||||
|
||||
NodeID string
|
||||
IsManyNodesMode bool `env:"IS_MANY_NODES_MODE"`
|
||||
IsPrimaryNode bool `env:"IS_PRIMARY_NODE"`
|
||||
IsBackupNode bool `env:"IS_BACKUP_NODE"`
|
||||
NodeNetworkThroughputMBs int `env:"NODE_NETWORK_THROUGHPUT_MBPS"`
|
||||
|
||||
DataFolder string
|
||||
TempFolder string
|
||||
SecretKeyPath string
|
||||
@@ -79,6 +86,13 @@ type EnvVariables struct {
|
||||
TestMongodb70Port string `env:"TEST_MONGODB_70_PORT"`
|
||||
TestMongodb82Port string `env:"TEST_MONGODB_82_PORT"`
|
||||
|
||||
// Valkey
|
||||
ValkeyHost string `env:"VALKEY_HOST" required:"true"`
|
||||
ValkeyPort string `env:"VALKEY_PORT" required:"true"`
|
||||
ValkeyUsername string `env:"VALKEY_USERNAME"`
|
||||
ValkeyPassword string `env:"VALKEY_PASSWORD"`
|
||||
ValkeyIsSsl bool `env:"VALKEY_IS_SSL" required:"true"`
|
||||
|
||||
// oauth
|
||||
GitHubClientID string `env:"GITHUB_CLIENT_ID"`
|
||||
GitHubClientSecret string `env:"GITHUB_CLIENT_SECRET"`
|
||||
@@ -189,6 +203,26 @@ func loadEnvVariables() {
|
||||
env.MongodbInstallDir = filepath.Join(backendRoot, "tools", "mongodb")
|
||||
tools.VerifyMongodbInstallation(log, env.EnvMode, env.MongodbInstallDir)
|
||||
|
||||
env.NodeID = uuid.New().String()
|
||||
if env.NodeNetworkThroughputMBs == 0 {
|
||||
env.NodeNetworkThroughputMBs = 125 // 1 Gbit/s
|
||||
}
|
||||
|
||||
if !env.IsManyNodesMode {
|
||||
env.IsPrimaryNode = true
|
||||
env.IsBackupNode = true
|
||||
}
|
||||
|
||||
// Valkey
|
||||
if env.ValkeyHost == "" {
|
||||
log.Error("VALKEY_HOST is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
if env.ValkeyPort == "" {
|
||||
log.Error("VALKEY_PORT is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Store the data and temp folders one level below the root
|
||||
// (projectRoot/databasus-data -> /databasus-data)
|
||||
env.DataFolder = filepath.Join(filepath.Dir(backendRoot), "databasus-data", "backups")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
)
|
||||
@@ -11,23 +11,25 @@ type AuditLogBackgroundService struct {
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *AuditLogBackgroundService) Run() {
|
||||
func (s *AuditLogBackgroundService) Run(ctx context.Context) {
|
||||
s.logger.Info("Starting audit log cleanup background service")
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
if config.IsShouldShutdown() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.cleanOldAuditLogs(); err != nil {
|
||||
s.logger.Error("Failed to clean old audit logs", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.cleanOldAuditLogs(); err != nil {
|
||||
s.logger.Error("Failed to clean old audit logs", "error", err)
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,254 +0,0 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/storages"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/period"
|
||||
"log/slog"
|
||||
"time"
|
||||
)
|
||||
|
||||
type BackupBackgroundService struct {
|
||||
backupService *BackupService
|
||||
backupRepository *BackupRepository
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
storageService *storages.StorageService
|
||||
|
||||
lastBackupTime time.Time
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) Run() {
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
|
||||
if err := s.failBackupsInProgress(); err != nil {
|
||||
s.logger.Error("Failed to fail backups in progress", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
if config.IsShouldShutdown() {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.cleanOldBackups(); err != nil {
|
||||
s.logger.Error("Failed to clean old backups", "error", err)
|
||||
}
|
||||
|
||||
if err := s.runPendingBackups(); err != nil {
|
||||
s.logger.Error("Failed to run pending backups", "error", err)
|
||||
}
|
||||
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
time.Sleep(1 * time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) IsBackupsWorkerRunning() bool {
|
||||
// if last backup time is more than 5 minutes ago, return false
|
||||
return s.lastBackupTime.After(time.Now().UTC().Add(-5 * time.Minute))
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) failBackupsInProgress() error {
|
||||
backupsInProgress, err := s.backupRepository.FindByStatus(BackupStatusInProgress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backup := range backupsInProgress {
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(backup.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
failMessage := "Backup failed due to application restart"
|
||||
backup.FailMessage = &failMessage
|
||||
backup.Status = BackupStatusFailed
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
s.backupService.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupFailed,
|
||||
&failMessage,
|
||||
)
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) cleanOldBackups() error {
|
||||
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
backupStorePeriod := backupConfig.StorePeriod
|
||||
|
||||
if backupStorePeriod == period.PeriodForever {
|
||||
continue
|
||||
}
|
||||
|
||||
storeDuration := backupStorePeriod.ToDuration()
|
||||
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
|
||||
|
||||
oldBackups, err := s.backupRepository.FindBackupsBeforeDate(
|
||||
backupConfig.DatabaseID,
|
||||
dateBeforeBackupsShouldBeDeleted,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to find old backups for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, backup := range oldBackups {
|
||||
storage, err := s.storageService.GetStorageByID(backup.StorageID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to get storage by ID",
|
||||
"storageId",
|
||||
backup.StorageID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
err = storage.DeleteFile(encryptor, backup.ID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
|
||||
}
|
||||
|
||||
if err := s.backupRepository.DeleteByID(backup.ID); err != nil {
|
||||
s.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
s.logger.Info(
|
||||
"Deleted old backup",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) runPendingBackups() error {
|
||||
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
if backupConfig.BackupInterval == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
lastBackup, err := s.backupRepository.FindLastByDatabaseID(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to get last backup for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
var lastBackupTime *time.Time
|
||||
if lastBackup != nil {
|
||||
lastBackupTime = &lastBackup.CreatedAt
|
||||
}
|
||||
|
||||
remainedBackupTryCount := s.GetRemainedBackupTryCount(lastBackup)
|
||||
|
||||
if backupConfig.BackupInterval.ShouldTriggerBackup(time.Now().UTC(), lastBackupTime) ||
|
||||
remainedBackupTryCount > 0 {
|
||||
s.logger.Info(
|
||||
"Triggering scheduled backup",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"intervalType",
|
||||
backupConfig.BackupInterval.Interval,
|
||||
)
|
||||
|
||||
go s.backupService.MakeBackup(backupConfig.DatabaseID, remainedBackupTryCount == 1)
|
||||
s.logger.Info(
|
||||
"Successfully triggered scheduled backup",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRemainedBackupTryCount returns the number of remaining backup tries for a given backup.
|
||||
// If the backup is not failed or the backup config does not allow retries, it returns 0.
|
||||
// If the backup is failed and the backup config allows retries, it returns the number of remaining tries.
|
||||
// If the backup is failed and the backup config does not allow retries, it returns 0.
|
||||
func (s *BackupBackgroundService) GetRemainedBackupTryCount(lastBackup *Backup) int {
|
||||
if lastBackup == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
if lastBackup.Status != BackupStatusFailed {
|
||||
return 0
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(lastBackup.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
if !backupConfig.IsRetryIfFailed {
|
||||
return 0
|
||||
}
|
||||
|
||||
maxFailedTriesCount := backupConfig.MaxFailedTriesCount
|
||||
|
||||
lastBackups, err := s.backupRepository.FindByDatabaseIDWithLimit(
|
||||
lastBackup.DatabaseID,
|
||||
maxFailedTriesCount,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to find last backups by database ID", "error", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
lastFailedBackups := make([]*Backup, 0)
|
||||
|
||||
for _, backup := range lastBackups {
|
||||
if backup.Status == BackupStatusFailed {
|
||||
lastFailedBackups = append(lastFailedBackups, backup)
|
||||
}
|
||||
}
|
||||
|
||||
return maxFailedTriesCount - len(lastFailedBackups)
|
||||
}
|
||||
@@ -1,389 +0,0 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/util/period"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Enable backups for the database
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add old backup
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusCompleted,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
// Wait for backup to complete (runs in goroutine)
|
||||
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 2)
|
||||
}
|
||||
|
||||
func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Enable backups for the database
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add recent backup (1 hour ago)
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusCompleted,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1) // Should still be 1 backup, no new backup created
|
||||
}
|
||||
|
||||
func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T) {
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Enable backups for the database with retries disabled
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
backupConfig.IsRetryIfFailed = false
|
||||
backupConfig.MaxFailedTriesCount = 0
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add failed backup
|
||||
failMessage := "backup failed"
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1) // Should still be 1 backup, no retry attempted
|
||||
}
|
||||
|
||||
func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Enable backups for the database with retries enabled
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
backupConfig.IsRetryIfFailed = true
|
||||
backupConfig.MaxFailedTriesCount = 3
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add failed backup
|
||||
failMessage := "backup failed"
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
// Wait for backup to complete (runs in goroutine)
|
||||
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 2) // Should have 2 backups, retry was attempted
|
||||
}
|
||||
|
||||
func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *testing.T) {
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Enable backups for the database with retries enabled
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
backupConfig.IsRetryIfFailed = true
|
||||
backupConfig.MaxFailedTriesCount = 3
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
failMessage := "backup failed"
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
}
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 3) // Should have 3 backups, not more than max
|
||||
}
|
||||
|
||||
func Test_MakeBackgroundBackupWhenBakupsDisabled_BackupSkipped(t *testing.T) {
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = false
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add old backup that would trigger new backup if enabled
|
||||
backupRepository.Save(&Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusCompleted,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1)
|
||||
}
|
||||
@@ -1,60 +0,0 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type BackupContextManager struct {
|
||||
mu sync.RWMutex
|
||||
cancelFuncs map[uuid.UUID]context.CancelFunc
|
||||
cancelledBackups map[uuid.UUID]bool
|
||||
}
|
||||
|
||||
func NewBackupContextManager() *BackupContextManager {
|
||||
return &BackupContextManager{
|
||||
cancelFuncs: make(map[uuid.UUID]context.CancelFunc),
|
||||
cancelledBackups: make(map[uuid.UUID]bool),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) RegisterBackup(backupID uuid.UUID, cancelFunc context.CancelFunc) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.cancelFuncs[backupID] = cancelFunc
|
||||
delete(m.cancelledBackups, backupID)
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) CancelBackup(backupID uuid.UUID) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.cancelledBackups[backupID] {
|
||||
return nil
|
||||
}
|
||||
|
||||
cancelFunc, exists := m.cancelFuncs[backupID]
|
||||
if exists {
|
||||
cancelFunc()
|
||||
delete(m.cancelFuncs, backupID)
|
||||
}
|
||||
|
||||
m.cancelledBackups[backupID] = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) IsCancelled(backupID uuid.UUID) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.cancelledBackups[backupID]
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) UnregisterBackup(backupID uuid.UUID) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.cancelFuncs, backupID)
|
||||
delete(m.cancelledBackups, backupID)
|
||||
}
|
||||
344
backend/internal/features/backups/backups/backuping/backuper.go
Normal file
344
backend/internal/features/backups/backups/backuping/backuper.go
Normal file
@@ -0,0 +1,344 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"databasus-backend/internal/config"
|
||||
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/storages"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type BackuperNode struct {
|
||||
databaseService *databases.DatabaseService
|
||||
fieldEncryptor util_encryption.FieldEncryptor
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
backupRepository *backups_core.BackupRepository
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
storageService *storages.StorageService
|
||||
notificationSender backups_core.NotificationSender
|
||||
backupCancelManager *backups_cancellation.BackupCancelManager
|
||||
nodesRegistry *BackupNodesRegistry
|
||||
logger *slog.Logger
|
||||
createBackupUseCase backups_core.CreateBackupUsecase
|
||||
nodeID uuid.UUID
|
||||
|
||||
lastHeartbeat time.Time
|
||||
}
|
||||
|
||||
func (n *BackuperNode) Run(ctx context.Context) {
|
||||
n.lastHeartbeat = time.Now().UTC()
|
||||
|
||||
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
|
||||
|
||||
backupNode := BackupNode{
|
||||
ID: n.nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: time.Now().UTC(),
|
||||
}
|
||||
|
||||
if err := n.nodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
|
||||
n.logger.Error("Failed to register node in registry", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
|
||||
n.MakeBackup(backupID, isCallNotifier)
|
||||
if err := n.nodesRegistry.PublishBackupCompletion(n.nodeID.String(), backupID); err != nil {
|
||||
n.logger.Error(
|
||||
"Failed to publish backup completion",
|
||||
"error",
|
||||
err,
|
||||
"backupID",
|
||||
backupID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if err := n.nodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID.String(), backupHandler); err != nil {
|
||||
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := n.nodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
|
||||
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ticker := time.NewTicker(15 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
|
||||
|
||||
if err := n.nodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
|
||||
n.logger.Error("Failed to unregister node from registry", "error", err)
|
||||
}
|
||||
|
||||
return
|
||||
case <-ticker.C:
|
||||
n.sendHeartbeat(&backupNode)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (n *BackuperNode) IsBackuperRunning() bool {
|
||||
return n.lastHeartbeat.After(time.Now().UTC().Add(-5 * time.Minute))
|
||||
}
|
||||
|
||||
func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
|
||||
backup, err := n.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get backup by ID", "backupId", backupID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
databaseID := backup.DatabaseID
|
||||
|
||||
database, err := n.databaseService.GetDatabaseByID(databaseID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get database by ID", "databaseId", databaseID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
backupConfig, err := n.backupConfigService.GetBackupConfigByDbId(databaseID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if backupConfig.StorageID == nil {
|
||||
n.logger.Error("Backup config storage ID is not defined")
|
||||
return
|
||||
}
|
||||
|
||||
storage, err := n.storageService.GetStorageByID(*backupConfig.StorageID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get storage by ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now().UTC()
|
||||
|
||||
backupProgressListener := func(
|
||||
completedMBs float64,
|
||||
) {
|
||||
backup.BackupSizeMb = completedMBs
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to update backup progress", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
n.backupCancelManager.RegisterBackup(backup.ID, cancel)
|
||||
defer n.backupCancelManager.UnregisterBackup(backup.ID)
|
||||
|
||||
backupMetadata, err := n.createBackupUseCase.Execute(
|
||||
ctx,
|
||||
backup.ID,
|
||||
backupConfig,
|
||||
database,
|
||||
storage,
|
||||
backupProgressListener,
|
||||
)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
|
||||
// Check if backup was cancelled (not due to shutdown)
|
||||
isCancelled := strings.Contains(errMsg, "backup cancelled") ||
|
||||
strings.Contains(errMsg, "context canceled") ||
|
||||
errors.Is(err, context.Canceled)
|
||||
isShutdown := strings.Contains(errMsg, "shutdown")
|
||||
|
||||
if isCancelled && !isShutdown {
|
||||
backup.Status = backups_core.BackupStatusCanceled
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save cancelled backup", "error", err)
|
||||
}
|
||||
|
||||
// Delete partial backup from storage
|
||||
storage, storageErr := n.storageService.GetStorageByID(backup.StorageID)
|
||||
if storageErr == nil {
|
||||
if deleteErr := storage.DeleteFile(n.fieldEncryptor, backup.ID); deleteErr != nil {
|
||||
n.logger.Error(
|
||||
"Failed to delete partial backup file",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
deleteErr,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
backup.FailMessage = &errMsg
|
||||
backup.Status = backups_core.BackupStatusFailed
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
if updateErr := n.databaseService.SetBackupError(databaseID, errMsg); updateErr != nil {
|
||||
n.logger.Error(
|
||||
"Failed to update database last backup time",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"error",
|
||||
updateErr,
|
||||
)
|
||||
}
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save backup", "error", err)
|
||||
}
|
||||
|
||||
n.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupFailed,
|
||||
&errMsg,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
backup.Status = backups_core.BackupStatusCompleted
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
// Update backup with encryption metadata if provided
|
||||
if backupMetadata != nil {
|
||||
backup.EncryptionSalt = backupMetadata.EncryptionSalt
|
||||
backup.EncryptionIV = backupMetadata.EncryptionIV
|
||||
backup.Encryption = backupMetadata.Encryption
|
||||
}
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save backup", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Update database last backup time
|
||||
now := time.Now().UTC()
|
||||
if updateErr := n.databaseService.SetLastBackupTime(databaseID, now); updateErr != nil {
|
||||
n.logger.Error(
|
||||
"Failed to update database last backup time",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"error",
|
||||
updateErr,
|
||||
)
|
||||
}
|
||||
|
||||
if backup.Status != backups_core.BackupStatusCompleted && !isCallNotifier {
|
||||
return
|
||||
}
|
||||
|
||||
n.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupSuccess,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
func (n *BackuperNode) SendBackupNotification(
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
backup *backups_core.Backup,
|
||||
notificationType backups_config.BackupNotificationType,
|
||||
errorMessage *string,
|
||||
) {
|
||||
database, err := n.databaseService.GetDatabaseByID(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
workspace, err := n.workspaceService.GetWorkspaceByID(*database.WorkspaceID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, notifier := range database.Notifiers {
|
||||
if !slices.Contains(
|
||||
backupConfig.SendNotificationsOn,
|
||||
notificationType,
|
||||
) {
|
||||
continue
|
||||
}
|
||||
|
||||
title := ""
|
||||
switch notificationType {
|
||||
case backups_config.NotificationBackupFailed:
|
||||
title = fmt.Sprintf(
|
||||
"❌ Backup failed for database \"%s\" (workspace \"%s\")",
|
||||
database.Name,
|
||||
workspace.Name,
|
||||
)
|
||||
case backups_config.NotificationBackupSuccess:
|
||||
title = fmt.Sprintf(
|
||||
"✅ Backup completed for database \"%s\" (workspace \"%s\")",
|
||||
database.Name,
|
||||
workspace.Name,
|
||||
)
|
||||
}
|
||||
|
||||
message := ""
|
||||
if errorMessage != nil {
|
||||
message = *errorMessage
|
||||
} else {
|
||||
// Format size conditionally
|
||||
var sizeStr string
|
||||
if backup.BackupSizeMb < 1024 {
|
||||
sizeStr = fmt.Sprintf("%.2f MB", backup.BackupSizeMb)
|
||||
} else {
|
||||
sizeGB := backup.BackupSizeMb / 1024
|
||||
sizeStr = fmt.Sprintf("%.2f GB", sizeGB)
|
||||
}
|
||||
|
||||
// Format duration as "0m 0s 0ms"
|
||||
totalMs := backup.BackupDurationMs
|
||||
minutes := totalMs / (1000 * 60)
|
||||
seconds := (totalMs % (1000 * 60)) / 1000
|
||||
durationStr := fmt.Sprintf("%dm %ds", minutes, seconds)
|
||||
|
||||
message = fmt.Sprintf(
|
||||
"Backup completed successfully in %s.\nCompressed backup size: %s",
|
||||
durationStr,
|
||||
sizeStr,
|
||||
)
|
||||
}
|
||||
|
||||
n.notificationSender.SendNotification(
|
||||
¬ifier,
|
||||
title,
|
||||
message,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (n *BackuperNode) sendHeartbeat(backupNode *BackupNode) {
|
||||
n.lastHeartbeat = time.Now().UTC()
|
||||
backupNode.LastHeartbeat = time.Now().UTC()
|
||||
if err := n.nodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *backupNode); err != nil {
|
||||
n.logger.Error("Failed to send heartbeat", "error", err)
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package backups
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -8,17 +8,15 @@ import (
|
||||
"time"
|
||||
|
||||
common "databasus-backend/internal/features/backups/backups/common"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -26,6 +24,7 @@ import (
|
||||
)
|
||||
|
||||
func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
@@ -50,23 +49,19 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
|
||||
t.Run("BackupFailed_FailNotificationSent", func(t *testing.T) {
|
||||
mockNotificationSender := &MockNotificationSender{}
|
||||
backupService := &BackupService{
|
||||
databases.GetDatabaseService(),
|
||||
storages.GetStorageService(),
|
||||
backupRepository,
|
||||
notifiers.GetNotifierService(),
|
||||
mockNotificationSender,
|
||||
backups_config.GetBackupConfigService(),
|
||||
encryption_secrets.GetSecretKeyService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
&CreateFailedBackupUsecase{},
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
nil,
|
||||
NewBackupContextManager(),
|
||||
nil,
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.notificationSender = mockNotificationSender
|
||||
backuperNode.createBackupUseCase = &CreateFailedBackupUsecase{}
|
||||
|
||||
// Create a backup record directly that will be looked up by MakeBackup
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err := backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Set up expectations
|
||||
mockNotificationSender.On("SendNotification",
|
||||
@@ -79,7 +74,7 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
}),
|
||||
).Once()
|
||||
|
||||
backupService.MakeBackup(database.ID, true)
|
||||
backuperNode.MakeBackup(backup.ID, true)
|
||||
|
||||
// Verify all expectations were met
|
||||
mockNotificationSender.AssertExpectations(t)
|
||||
@@ -87,6 +82,19 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
|
||||
t.Run("BackupSuccess_SuccessNotificationSent", func(t *testing.T) {
|
||||
mockNotificationSender := &MockNotificationSender{}
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.notificationSender = mockNotificationSender
|
||||
backuperNode.createBackupUseCase = &CreateSuccessBackupUsecase{}
|
||||
|
||||
// Create a backup record directly that will be looked up by MakeBackup
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err := backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Set up expectations
|
||||
mockNotificationSender.On("SendNotification",
|
||||
@@ -99,25 +107,7 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
}),
|
||||
).Once()
|
||||
|
||||
backupService := &BackupService{
|
||||
databases.GetDatabaseService(),
|
||||
storages.GetStorageService(),
|
||||
backupRepository,
|
||||
notifiers.GetNotifierService(),
|
||||
mockNotificationSender,
|
||||
backups_config.GetBackupConfigService(),
|
||||
encryption_secrets.GetSecretKeyService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
&CreateSuccessBackupUsecase{},
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
nil,
|
||||
NewBackupContextManager(),
|
||||
nil,
|
||||
}
|
||||
|
||||
backupService.MakeBackup(database.ID, true)
|
||||
backuperNode.MakeBackup(backup.ID, true)
|
||||
|
||||
// Verify all expectations were met
|
||||
mockNotificationSender.AssertExpectations(t)
|
||||
@@ -125,23 +115,19 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
|
||||
t.Run("BackupSuccess_VerifyNotificationContent", func(t *testing.T) {
|
||||
mockNotificationSender := &MockNotificationSender{}
|
||||
backupService := &BackupService{
|
||||
databases.GetDatabaseService(),
|
||||
storages.GetStorageService(),
|
||||
backupRepository,
|
||||
notifiers.GetNotifierService(),
|
||||
mockNotificationSender,
|
||||
backups_config.GetBackupConfigService(),
|
||||
encryption_secrets.GetSecretKeyService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
&CreateSuccessBackupUsecase{},
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
nil,
|
||||
NewBackupContextManager(),
|
||||
nil,
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.notificationSender = mockNotificationSender
|
||||
backuperNode.createBackupUseCase = &CreateSuccessBackupUsecase{}
|
||||
|
||||
// Create a backup record directly that will be looked up by MakeBackup
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err := backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// capture arguments
|
||||
var capturedNotifier *notifiers.Notifier
|
||||
@@ -158,7 +144,7 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
capturedMessage = args.Get(2).(string)
|
||||
}).Once()
|
||||
|
||||
backupService.MakeBackup(database.ID, true)
|
||||
backuperNode.MakeBackup(backup.ID, true)
|
||||
|
||||
// Verify expectations were met
|
||||
mockNotificationSender.AssertExpectations(t)
|
||||
77
backend/internal/features/backups/backups/backuping/di.go
Normal file
77
backend/internal/features/backups/backups/backuping/di.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/usecases"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var backupRepository = &backups_core.BackupRepository{}
|
||||
|
||||
var backupCancelManager = backups_cancellation.GetBackupCancelManager()
|
||||
|
||||
var nodesRegistry = &BackupNodesRegistry{
|
||||
cache_utils.GetValkeyClient(),
|
||||
logger.GetLogger(),
|
||||
cache_utils.DefaultCacheTimeout,
|
||||
cache_utils.NewPubSubManager(),
|
||||
cache_utils.NewPubSubManager(),
|
||||
}
|
||||
|
||||
func getNodeID() uuid.UUID {
|
||||
nodeIDStr := config.GetEnv().NodeID
|
||||
nodeID, err := uuid.Parse(nodeIDStr)
|
||||
if err != nil {
|
||||
logger.GetLogger().Error("Failed to parse node ID from config", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
return nodeID
|
||||
}
|
||||
|
||||
var backuperNode = &BackuperNode{
|
||||
databases.GetDatabaseService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
notifiers.GetNotifierService(),
|
||||
backupCancelManager,
|
||||
nodesRegistry,
|
||||
logger.GetLogger(),
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
getNodeID(),
|
||||
time.Time{},
|
||||
}
|
||||
|
||||
var backupsScheduler = &BackupsScheduler{
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
backupCancelManager,
|
||||
nodesRegistry,
|
||||
time.Now().UTC(),
|
||||
logger.GetLogger(),
|
||||
make(map[uuid.UUID]BackupToNodeRelation),
|
||||
backuperNode,
|
||||
}
|
||||
|
||||
func GetBackupsScheduler() *BackupsScheduler {
|
||||
return backupsScheduler
|
||||
}
|
||||
|
||||
func GetBackuperNode() *BackuperNode {
|
||||
return backuperNode
|
||||
}
|
||||
34
backend/internal/features/backups/backups/backuping/dto.go
Normal file
34
backend/internal/features/backups/backups/backuping/dto.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type BackupNode struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
ThroughputMBs int `json:"throughputMBs"`
|
||||
LastHeartbeat time.Time `json:"lastHeartbeat"`
|
||||
}
|
||||
|
||||
type BackupNodeStats struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
ActiveBackups int `json:"activeBackups"`
|
||||
}
|
||||
|
||||
type BackupSubmitMessage struct {
|
||||
NodeID string `json:"nodeId"`
|
||||
BackupID string `json:"backupId"`
|
||||
IsCallNotifier bool `json:"isCallNotifier"`
|
||||
}
|
||||
|
||||
type BackupCompletionMessage struct {
|
||||
NodeID string `json:"nodeId"`
|
||||
BackupID string `json:"backupId"`
|
||||
}
|
||||
|
||||
type BackupToNodeRelation struct {
|
||||
NodeID uuid.UUID `json:"nodeId"`
|
||||
BackupsIDs []uuid.UUID `json:"backupsIds"`
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package backups
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
448
backend/internal/features/backups/backups/backuping/registry.go
Normal file
448
backend/internal/features/backups/backups/backuping/registry.go
Normal file
@@ -0,0 +1,448 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/valkey-io/valkey-go"
|
||||
)
|
||||
|
||||
const (
|
||||
nodeInfoKeyPrefix = "node:"
|
||||
nodeInfoKeySuffix = ":info"
|
||||
nodeActiveBackupsPrefix = "node:"
|
||||
nodeActiveBackupsSuffix = ":active_backups"
|
||||
backupSubmitChannel = "backup:submit"
|
||||
backupCompletionChannel = "backup:completion"
|
||||
)
|
||||
|
||||
// BackupNodesRegistry helps to sync backups scheduler and backup nodes.
|
||||
// Features:
|
||||
// - Track node availability and load level
|
||||
// - Assign from scheduler to node backups needed to be processed
|
||||
// - Notify scheduler from node about backup completion
|
||||
type BackupNodesRegistry struct {
|
||||
client valkey.Client
|
||||
logger *slog.Logger
|
||||
timeout time.Duration
|
||||
pubsubBackups *cache_utils.PubSubManager
|
||||
pubsubCompletions *cache_utils.PubSubManager
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) GetAvailableNodes() ([]BackupNode, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
var allKeys []string
|
||||
cursor := uint64(0)
|
||||
pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix
|
||||
|
||||
for {
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return nil, fmt.Errorf("failed to scan node keys: %w", result.Error())
|
||||
}
|
||||
|
||||
scanResult, err := result.AsScanEntry()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse scan result: %w", err)
|
||||
}
|
||||
|
||||
allKeys = append(allKeys, scanResult.Elements...)
|
||||
|
||||
cursor = scanResult.Cursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(allKeys) == 0 {
|
||||
return []BackupNode{}, nil
|
||||
}
|
||||
|
||||
keyDataMap, err := r.pipelineGetKeys(allKeys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pipeline get node keys: %w", err)
|
||||
}
|
||||
|
||||
var nodes []BackupNode
|
||||
for key, data := range keyDataMap {
|
||||
var node BackupNode
|
||||
if err := json.Unmarshal(data, &node); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal node data", "key", key, "error", err)
|
||||
continue
|
||||
}
|
||||
nodes = append(nodes, node)
|
||||
}
|
||||
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) GetBackupNodesStats() ([]BackupNodeStats, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
var allKeys []string
|
||||
cursor := uint64(0)
|
||||
pattern := nodeActiveBackupsPrefix + "*" + nodeActiveBackupsSuffix
|
||||
|
||||
for {
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(100).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return nil, fmt.Errorf("failed to scan active backups keys: %w", result.Error())
|
||||
}
|
||||
|
||||
scanResult, err := result.AsScanEntry()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse scan result: %w", err)
|
||||
}
|
||||
|
||||
allKeys = append(allKeys, scanResult.Elements...)
|
||||
|
||||
cursor = scanResult.Cursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(allKeys) == 0 {
|
||||
return []BackupNodeStats{}, nil
|
||||
}
|
||||
|
||||
keyDataMap, err := r.pipelineGetKeys(allKeys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pipeline get active backups keys: %w", err)
|
||||
}
|
||||
|
||||
var stats []BackupNodeStats
|
||||
for key, data := range keyDataMap {
|
||||
nodeID := r.extractNodeIDFromKey(key, nodeActiveBackupsPrefix, nodeActiveBackupsSuffix)
|
||||
|
||||
count, err := r.parseIntFromBytes(data)
|
||||
if err != nil {
|
||||
r.logger.Warn("Failed to parse active backups count", "key", key, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
stat := BackupNodeStats{
|
||||
ID: nodeID,
|
||||
ActiveBackups: int(count),
|
||||
}
|
||||
stats = append(stats, stat)
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) IncrementBackupsInProgress(nodeID string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID, nodeActiveBackupsSuffix)
|
||||
result := r.client.Do(ctx, r.client.B().Incr().Key(key).Build())
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to increment backups in progress for node %s: %w",
|
||||
nodeID,
|
||||
result.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) DecrementBackupsInProgress(nodeID string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID, nodeActiveBackupsSuffix)
|
||||
result := r.client.Do(ctx, r.client.B().Decr().Key(key).Build())
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to decrement backups in progress for node %s: %w",
|
||||
nodeID,
|
||||
result.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
newValue, err := result.AsInt64()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse decremented value for node %s: %w", nodeID, err)
|
||||
}
|
||||
|
||||
if newValue < 0 {
|
||||
setCtx, setCancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
r.client.Do(setCtx, r.client.B().Set().Key(key).Value("0").Build())
|
||||
setCancel()
|
||||
r.logger.Warn("Active backups counter went below 0, reset to 0", "nodeID", nodeID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) HearthbeatNodeInRegistry(now time.Time, backupNode BackupNode) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
backupNode.LastHeartbeat = now
|
||||
|
||||
data, err := json.Marshal(backupNode)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal backup node: %w", err)
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, backupNode.ID.String(), nodeInfoKeySuffix)
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Set().Key(key).Value(string(data)).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf("failed to register node %s: %w", backupNode.ID, result.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) UnregisterNodeFromRegistry(backupNode BackupNode) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, backupNode.ID.String(), nodeInfoKeySuffix)
|
||||
counterKey := fmt.Sprintf(
|
||||
"%s%s%s",
|
||||
nodeActiveBackupsPrefix,
|
||||
backupNode.ID.String(),
|
||||
nodeActiveBackupsSuffix,
|
||||
)
|
||||
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Del().Key(infoKey, counterKey).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf("failed to unregister node %s: %w", backupNode.ID, result.Error())
|
||||
}
|
||||
|
||||
r.logger.Info("Unregistered node from registry", "nodeID", backupNode.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) AssignBackupToNode(
|
||||
targetNodeID string,
|
||||
backupID uuid.UUID,
|
||||
isCallNotifier bool,
|
||||
) error {
|
||||
ctx := context.Background()
|
||||
|
||||
message := BackupSubmitMessage{
|
||||
NodeID: targetNodeID,
|
||||
BackupID: backupID.String(),
|
||||
IsCallNotifier: isCallNotifier,
|
||||
}
|
||||
|
||||
messageJSON, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal backup submit message: %w", err)
|
||||
}
|
||||
|
||||
err = r.pubsubBackups.Publish(ctx, backupSubmitChannel, string(messageJSON))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to publish backup submit message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) SubscribeNodeForBackupsAssignment(
|
||||
nodeID string,
|
||||
handler func(backupID uuid.UUID, isCallNotifier bool),
|
||||
) error {
|
||||
ctx := context.Background()
|
||||
|
||||
wrappedHandler := func(message string) {
|
||||
var msg BackupSubmitMessage
|
||||
if err := json.Unmarshal([]byte(message), &msg); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal backup submit message", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if msg.NodeID != nodeID {
|
||||
return
|
||||
}
|
||||
|
||||
backupID, err := uuid.Parse(msg.BackupID)
|
||||
if err != nil {
|
||||
r.logger.Warn(
|
||||
"Failed to parse backup ID from message",
|
||||
"backupId",
|
||||
msg.BackupID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
handler(backupID, msg.IsCallNotifier)
|
||||
}
|
||||
|
||||
err := r.pubsubBackups.Subscribe(ctx, backupSubmitChannel, wrappedHandler)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to subscribe to backup submit channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Subscribed to backup submit channel", "nodeID", nodeID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) UnsubscribeNodeForBackupsAssignments() error {
|
||||
err := r.pubsubBackups.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unsubscribe from backup submit channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Unsubscribed from backup submit channel")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) PublishBackupCompletion(nodeID string, backupID uuid.UUID) error {
|
||||
ctx := context.Background()
|
||||
|
||||
message := BackupCompletionMessage{
|
||||
NodeID: nodeID,
|
||||
BackupID: backupID.String(),
|
||||
}
|
||||
|
||||
messageJSON, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal backup completion message: %w", err)
|
||||
}
|
||||
|
||||
err = r.pubsubCompletions.Publish(ctx, backupCompletionChannel, string(messageJSON))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to publish backup completion message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) SubscribeForBackupsCompletions(
|
||||
handler func(nodeID string, backupID uuid.UUID),
|
||||
) error {
|
||||
ctx := context.Background()
|
||||
|
||||
wrappedHandler := func(message string) {
|
||||
var msg BackupCompletionMessage
|
||||
if err := json.Unmarshal([]byte(message), &msg); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal backup completion message", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
backupID, err := uuid.Parse(msg.BackupID)
|
||||
if err != nil {
|
||||
r.logger.Warn(
|
||||
"Failed to parse backup ID from completion message",
|
||||
"backupId",
|
||||
msg.BackupID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
handler(msg.NodeID, backupID)
|
||||
}
|
||||
|
||||
err := r.pubsubCompletions.Subscribe(ctx, backupCompletionChannel, wrappedHandler)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to subscribe to backup completion channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Subscribed to backup completion channel")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) UnsubscribeForBackupsCompletions() error {
|
||||
err := r.pubsubCompletions.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unsubscribe from backup completion channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Unsubscribed from backup completion channel")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uuid.UUID {
|
||||
nodeIDStr := strings.TrimPrefix(key, prefix)
|
||||
nodeIDStr = strings.TrimSuffix(nodeIDStr, suffix)
|
||||
|
||||
nodeID, err := uuid.Parse(nodeIDStr)
|
||||
if err != nil {
|
||||
r.logger.Warn("Failed to parse node ID from key", "key", key, "error", err)
|
||||
return uuid.Nil
|
||||
}
|
||||
|
||||
return nodeID
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, error) {
|
||||
if len(keys) == 0 {
|
||||
return make(map[string][]byte), nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
commands := make([]valkey.Completed, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
commands = append(commands, r.client.B().Get().Key(key).Build())
|
||||
}
|
||||
|
||||
results := r.client.DoMulti(ctx, commands...)
|
||||
|
||||
keyDataMap := make(map[string][]byte, len(keys))
|
||||
for i, result := range results {
|
||||
if result.Error() != nil {
|
||||
r.logger.Warn("Failed to get key in pipeline", "key", keys[i], "error", result.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
data, err := result.AsBytes()
|
||||
if err != nil {
|
||||
r.logger.Warn("Failed to parse key data in pipeline", "key", keys[i], "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
keyDataMap[keys[i]] = data
|
||||
}
|
||||
|
||||
return keyDataMap, nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
|
||||
str := string(data)
|
||||
var count int64
|
||||
_, err := fmt.Sscanf(str, "%d", &count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to parse integer from bytes: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
@@ -0,0 +1,904 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/logger"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_HearthbeatNodeInRegistry_RegistersNodeWithTTL(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
defer cleanupTestNode(registry, node)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
nodes, err := registry.GetAvailableNodes()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, node.ID, nodes[0].ID)
|
||||
assert.Equal(t, node.ThroughputMBs, nodes[0].ThroughputMBs)
|
||||
}
|
||||
|
||||
func Test_UnregisterNodeFromRegistry_RemovesNodeAndCounter(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.UnregisterNodeFromRegistry(node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
nodes, err := registry.GetAvailableNodes()
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, nodes)
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, stats)
|
||||
}
|
||||
|
||||
func Test_GetAvailableNodes_ReturnsAllRegisteredNodes(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node1 := createTestBackupNode()
|
||||
node2 := createTestBackupNode()
|
||||
node3 := createTestBackupNode()
|
||||
defer cleanupTestNode(registry, node1)
|
||||
defer cleanupTestNode(registry, node2)
|
||||
defer cleanupTestNode(registry, node3)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
|
||||
assert.NoError(t, err)
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
|
||||
assert.NoError(t, err)
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
nodes, err := registry.GetAvailableNodes()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, nodes, 3)
|
||||
|
||||
nodeIDs := make(map[uuid.UUID]bool)
|
||||
for _, node := range nodes {
|
||||
nodeIDs[node.ID] = true
|
||||
}
|
||||
assert.True(t, nodeIDs[node1.ID])
|
||||
assert.True(t, nodeIDs[node2.ID])
|
||||
assert.True(t, nodeIDs[node3.ID])
|
||||
}
|
||||
|
||||
func Test_GetAvailableNodes_WhenNoNodesExist_ReturnsEmptySlice(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
|
||||
nodes, err := registry.GetAvailableNodes()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, nodes)
|
||||
assert.Empty(t, nodes)
|
||||
}
|
||||
|
||||
func Test_IncrementBackupsInProgress_IncrementsCounter(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
defer cleanupTestNode(registry, node)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, stats, 1)
|
||||
assert.Equal(t, node.ID, stats[0].ID)
|
||||
assert.Equal(t, 1, stats[0].ActiveBackups)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err = registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, stats, 1)
|
||||
assert.Equal(t, 2, stats[0].ActiveBackups)
|
||||
}
|
||||
|
||||
func Test_DecrementBackupsInProgress_DecrementsCounter(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
defer cleanupTestNode(registry, node)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node.ID.String())
|
||||
assert.NoError(t, err)
|
||||
err = registry.IncrementBackupsInProgress(node.ID.String())
|
||||
assert.NoError(t, err)
|
||||
err = registry.IncrementBackupsInProgress(node.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, stats[0].ActiveBackups)
|
||||
|
||||
err = registry.DecrementBackupsInProgress(node.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err = registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, stats[0].ActiveBackups)
|
||||
|
||||
err = registry.DecrementBackupsInProgress(node.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err = registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, stats[0].ActiveBackups)
|
||||
}
|
||||
|
||||
func Test_DecrementBackupsInProgress_WhenNegative_ResetsToZero(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
defer cleanupTestNode(registry, node)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.DecrementBackupsInProgress(node.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, stats, 1)
|
||||
assert.Equal(t, 0, stats[0].ActiveBackups)
|
||||
}
|
||||
|
||||
func Test_GetBackupNodesStats_ReturnsStatsForAllNodes(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node1 := createTestBackupNode()
|
||||
node2 := createTestBackupNode()
|
||||
node3 := createTestBackupNode()
|
||||
defer cleanupTestNode(registry, node1)
|
||||
defer cleanupTestNode(registry, node2)
|
||||
defer cleanupTestNode(registry, node3)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
|
||||
assert.NoError(t, err)
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
|
||||
assert.NoError(t, err)
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node1.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node2.ID.String())
|
||||
assert.NoError(t, err)
|
||||
err = registry.IncrementBackupsInProgress(node2.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node3.ID.String())
|
||||
assert.NoError(t, err)
|
||||
err = registry.IncrementBackupsInProgress(node3.ID.String())
|
||||
assert.NoError(t, err)
|
||||
err = registry.IncrementBackupsInProgress(node3.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, stats, 3)
|
||||
|
||||
statsMap := make(map[uuid.UUID]int)
|
||||
for _, stat := range stats {
|
||||
statsMap[stat.ID] = stat.ActiveBackups
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, statsMap[node1.ID])
|
||||
assert.Equal(t, 2, statsMap[node2.ID])
|
||||
assert.Equal(t, 3, statsMap[node3.ID])
|
||||
}
|
||||
|
||||
func Test_GetBackupNodesStats_WhenNoStats_ReturnsEmptySlice(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, stats)
|
||||
assert.Empty(t, stats)
|
||||
}
|
||||
|
||||
func Test_MultipleNodes_RegisteredAndQueriedCorrectly(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node1 := createTestBackupNode()
|
||||
node1.ThroughputMBs = 50
|
||||
node2 := createTestBackupNode()
|
||||
node2.ThroughputMBs = 100
|
||||
node3 := createTestBackupNode()
|
||||
node3.ThroughputMBs = 150
|
||||
defer cleanupTestNode(registry, node1)
|
||||
defer cleanupTestNode(registry, node2)
|
||||
defer cleanupTestNode(registry, node3)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
|
||||
assert.NoError(t, err)
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
|
||||
assert.NoError(t, err)
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
nodes, err := registry.GetAvailableNodes()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, nodes, 3)
|
||||
|
||||
nodeMap := make(map[uuid.UUID]BackupNode)
|
||||
for _, node := range nodes {
|
||||
nodeMap[node.ID] = node
|
||||
}
|
||||
|
||||
assert.Equal(t, 50, nodeMap[node1.ID].ThroughputMBs)
|
||||
assert.Equal(t, 100, nodeMap[node2.ID].ThroughputMBs)
|
||||
assert.Equal(t, 150, nodeMap[node3.ID].ThroughputMBs)
|
||||
}
|
||||
|
||||
func Test_BackupCounters_TrackedSeparatelyPerNode(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node1 := createTestBackupNode()
|
||||
node2 := createTestBackupNode()
|
||||
defer cleanupTestNode(registry, node1)
|
||||
defer cleanupTestNode(registry, node2)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
|
||||
assert.NoError(t, err)
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node1.ID.String())
|
||||
assert.NoError(t, err)
|
||||
err = registry.IncrementBackupsInProgress(node1.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node2.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, stats, 2)
|
||||
|
||||
statsMap := make(map[uuid.UUID]int)
|
||||
for _, stat := range stats {
|
||||
statsMap[stat.ID] = stat.ActiveBackups
|
||||
}
|
||||
|
||||
assert.Equal(t, 2, statsMap[node1.ID])
|
||||
assert.Equal(t, 1, statsMap[node2.ID])
|
||||
|
||||
err = registry.DecrementBackupsInProgress(node1.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err = registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
|
||||
statsMap = make(map[uuid.UUID]int)
|
||||
for _, stat := range stats {
|
||||
statsMap[stat.ID] = stat.ActiveBackups
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, statsMap[node1.ID])
|
||||
assert.Equal(t, 1, statsMap[node2.ID])
|
||||
}
|
||||
|
||||
func Test_GetAvailableNodes_SkipsInvalidJsonData(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
defer cleanupTestNode(registry, node)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
defer cancel()
|
||||
|
||||
invalidKey := nodeInfoKeyPrefix + uuid.New().String() + nodeInfoKeySuffix
|
||||
registry.client.Do(
|
||||
ctx,
|
||||
registry.client.B().Set().Key(invalidKey).Value("invalid json data").Build(),
|
||||
)
|
||||
defer func() {
|
||||
cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
defer cleanupCancel()
|
||||
registry.client.Do(cleanupCtx, registry.client.B().Del().Key(invalidKey).Build())
|
||||
}()
|
||||
|
||||
nodes, err := registry.GetAvailableNodes()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, node.ID, nodes[0].ID)
|
||||
}
|
||||
|
||||
func Test_PipelineGetKeys_HandlesEmptyKeysList(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
|
||||
keyDataMap, err := registry.pipelineGetKeys([]string{})
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, keyDataMap)
|
||||
assert.Empty(t, keyDataMap)
|
||||
}
|
||||
|
||||
func Test_HearthbeatNodeInRegistry_UpdatesLastHeartbeat(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
originalHeartbeat := node.LastHeartbeat
|
||||
defer cleanupTestNode(registry, node)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
nodes, err := registry.GetAvailableNodes()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.True(t, nodes[0].LastHeartbeat.After(originalHeartbeat))
|
||||
}
|
||||
|
||||
func createTestRegistry() *BackupNodesRegistry {
|
||||
return &BackupNodesRegistry{
|
||||
cache_utils.GetValkeyClient(),
|
||||
logger.GetLogger(),
|
||||
cache_utils.DefaultCacheTimeout,
|
||||
cache_utils.NewPubSubManager(),
|
||||
cache_utils.NewPubSubManager(),
|
||||
}
|
||||
}
|
||||
|
||||
func createTestBackupNode() BackupNode {
|
||||
return BackupNode{
|
||||
ID: uuid.New(),
|
||||
ThroughputMBs: 100,
|
||||
LastHeartbeat: time.Now().UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
func cleanupTestNode(registry *BackupNodesRegistry, node BackupNode) {
|
||||
registry.UnregisterNodeFromRegistry(node)
|
||||
}
|
||||
|
||||
func Test_AssignBackupTonode_PublishesJsonMessageToChannel(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
backupID := uuid.New()
|
||||
|
||||
err := registry.AssignBackupToNode(node.ID.String(), backupID, true)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_SubscribeNodeForBackupsAssignment_ReceivesSubmittedBackupsForMatchingNode(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
backupID := uuid.New()
|
||||
defer registry.UnsubscribeNodeForBackupsAssignments()
|
||||
|
||||
receivedBackupID := make(chan uuid.UUID, 1)
|
||||
handler := func(id uuid.UUID, isCallNotifier bool) {
|
||||
receivedBackupID <- id
|
||||
}
|
||||
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.AssignBackupToNode(node.ID.String(), backupID, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case received := <-receivedBackupID:
|
||||
assert.Equal(t, backupID, received)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Timeout waiting for backup message")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SubscribeNodeForBackupsAssignment_FiltersOutBackupsForDifferentNode(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node1 := createTestBackupNode()
|
||||
node2 := createTestBackupNode()
|
||||
backupID := uuid.New()
|
||||
defer registry.UnsubscribeNodeForBackupsAssignments()
|
||||
|
||||
receivedBackupID := make(chan uuid.UUID, 1)
|
||||
handler := func(id uuid.UUID, isCallNotifier bool) {
|
||||
receivedBackupID <- id
|
||||
}
|
||||
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node1.ID.String(), handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.AssignBackupToNode(node2.ID.String(), backupID, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-receivedBackupID:
|
||||
t.Fatal("Should not receive backup for different node")
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SubscribeNodeForBackupsAssignment_ParsesJsonAndBackupIdCorrectly(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
backupID1 := uuid.New()
|
||||
backupID2 := uuid.New()
|
||||
defer registry.UnsubscribeNodeForBackupsAssignments()
|
||||
|
||||
receivedBackups := make(chan uuid.UUID, 2)
|
||||
handler := func(id uuid.UUID, isCallNotifier bool) {
|
||||
receivedBackups <- id
|
||||
}
|
||||
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.AssignBackupToNode(node.ID.String(), backupID1, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.AssignBackupToNode(node.ID.String(), backupID2, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
received1 := <-receivedBackups
|
||||
received2 := <-receivedBackups
|
||||
|
||||
receivedIDs := []uuid.UUID{received1, received2}
|
||||
assert.Contains(t, receivedIDs, backupID1)
|
||||
assert.Contains(t, receivedIDs, backupID2)
|
||||
}
|
||||
|
||||
func Test_SubscribeNodeForBackupsAssignment_HandlesInvalidJson(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
defer registry.UnsubscribeNodeForBackupsAssignments()
|
||||
|
||||
receivedBackupID := make(chan uuid.UUID, 1)
|
||||
handler := func(id uuid.UUID, isCallNotifier bool) {
|
||||
receivedBackupID <- id
|
||||
}
|
||||
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
ctx := context.Background()
|
||||
err = registry.pubsubBackups.Publish(ctx, "backup:submit", "invalid json")
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-receivedBackupID:
|
||||
t.Fatal("Should not receive backup for invalid JSON")
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func Test_UnsubscribeNodeForBackupsAssignments_StopsReceivingMessages(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
backupID1 := uuid.New()
|
||||
backupID2 := uuid.New()
|
||||
|
||||
receivedBackupID := make(chan uuid.UUID, 2)
|
||||
handler := func(id uuid.UUID, isCallNotifier bool) {
|
||||
receivedBackupID <- id
|
||||
}
|
||||
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.AssignBackupToNode(node.ID.String(), backupID1, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
received := <-receivedBackupID
|
||||
assert.Equal(t, backupID1, received)
|
||||
|
||||
err = registry.UnsubscribeNodeForBackupsAssignments()
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.AssignBackupToNode(node.ID.String(), backupID2, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-receivedBackupID:
|
||||
t.Fatal("Should not receive backup after unsubscribe")
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SubscribeNodeForBackupsAssignment_WhenAlreadySubscribed_ReturnsError(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
defer registry.UnsubscribeNodeForBackupsAssignments()
|
||||
|
||||
handler := func(id uuid.UUID, isCallNotifier bool) {}
|
||||
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "already subscribed")
|
||||
}
|
||||
|
||||
func Test_MultipleNodes_EachReceivesOnlyTheirBackups(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry1 := createTestRegistry()
|
||||
registry2 := createTestRegistry()
|
||||
registry3 := createTestRegistry()
|
||||
|
||||
node1 := createTestBackupNode()
|
||||
node2 := createTestBackupNode()
|
||||
node3 := createTestBackupNode()
|
||||
|
||||
backupID1 := uuid.New()
|
||||
backupID2 := uuid.New()
|
||||
backupID3 := uuid.New()
|
||||
|
||||
defer registry1.UnsubscribeNodeForBackupsAssignments()
|
||||
defer registry2.UnsubscribeNodeForBackupsAssignments()
|
||||
defer registry3.UnsubscribeNodeForBackupsAssignments()
|
||||
|
||||
receivedBackups1 := make(chan uuid.UUID, 3)
|
||||
receivedBackups2 := make(chan uuid.UUID, 3)
|
||||
receivedBackups3 := make(chan uuid.UUID, 3)
|
||||
|
||||
handler1 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups1 <- id }
|
||||
handler2 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups2 <- id }
|
||||
handler3 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups3 <- id }
|
||||
|
||||
err := registry1.SubscribeNodeForBackupsAssignment(node1.ID.String(), handler1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry2.SubscribeNodeForBackupsAssignment(node2.ID.String(), handler2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry3.SubscribeNodeForBackupsAssignment(node3.ID.String(), handler3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
submitRegistry := createTestRegistry()
|
||||
err = submitRegistry.AssignBackupToNode(node1.ID.String(), backupID1, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = submitRegistry.AssignBackupToNode(node2.ID.String(), backupID2, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = submitRegistry.AssignBackupToNode(node3.ID.String(), backupID3, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case received := <-receivedBackups1:
|
||||
assert.Equal(t, backupID1, received)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Node 1 timeout waiting for backup message")
|
||||
}
|
||||
|
||||
select {
|
||||
case received := <-receivedBackups2:
|
||||
assert.Equal(t, backupID2, received)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Node 2 timeout waiting for backup message")
|
||||
}
|
||||
|
||||
select {
|
||||
case received := <-receivedBackups3:
|
||||
assert.Equal(t, backupID3, received)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Node 3 timeout waiting for backup message")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-receivedBackups1:
|
||||
t.Fatal("Node 1 should not receive additional backups")
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
}
|
||||
|
||||
select {
|
||||
case <-receivedBackups2:
|
||||
t.Fatal("Node 2 should not receive additional backups")
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
}
|
||||
|
||||
select {
|
||||
case <-receivedBackups3:
|
||||
t.Fatal("Node 3 should not receive additional backups")
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func Test_PublishBackupCompletion_PublishesMessageToChannel(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
backupID := uuid.New()
|
||||
|
||||
err := registry.PublishBackupCompletion(node.ID.String(), backupID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_SubscribeForBackupsCompletions_ReceivesCompletedBackups(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
backupID := uuid.New()
|
||||
defer registry.UnsubscribeForBackupsCompletions()
|
||||
|
||||
receivedBackupID := make(chan uuid.UUID, 1)
|
||||
receivedNodeID := make(chan string, 1)
|
||||
handler := func(nodeID string, backupID uuid.UUID) {
|
||||
receivedNodeID <- nodeID
|
||||
receivedBackupID <- backupID
|
||||
}
|
||||
|
||||
err := registry.SubscribeForBackupsCompletions(handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.PublishBackupCompletion(node.ID.String(), backupID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case receivedNode := <-receivedNodeID:
|
||||
assert.Equal(t, node.ID.String(), receivedNode)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Timeout waiting for node ID")
|
||||
}
|
||||
|
||||
select {
|
||||
case received := <-receivedBackupID:
|
||||
assert.Equal(t, backupID, received)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Timeout waiting for backup completion message")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SubscribeForBackupsCompletions_ParsesJsonCorrectly(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
backupID1 := uuid.New()
|
||||
backupID2 := uuid.New()
|
||||
defer registry.UnsubscribeForBackupsCompletions()
|
||||
|
||||
receivedBackups := make(chan uuid.UUID, 2)
|
||||
handler := func(nodeID string, backupID uuid.UUID) {
|
||||
receivedBackups <- backupID
|
||||
}
|
||||
|
||||
err := registry.SubscribeForBackupsCompletions(handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.PublishBackupCompletion(node.ID.String(), backupID1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.PublishBackupCompletion(node.ID.String(), backupID2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
received1 := <-receivedBackups
|
||||
received2 := <-receivedBackups
|
||||
|
||||
receivedIDs := []uuid.UUID{received1, received2}
|
||||
assert.Contains(t, receivedIDs, backupID1)
|
||||
assert.Contains(t, receivedIDs, backupID2)
|
||||
}
|
||||
|
||||
func Test_SubscribeForBackupsCompletions_HandlesInvalidJson(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
defer registry.UnsubscribeForBackupsCompletions()
|
||||
|
||||
receivedBackupID := make(chan uuid.UUID, 1)
|
||||
handler := func(nodeID string, backupID uuid.UUID) {
|
||||
receivedBackupID <- backupID
|
||||
}
|
||||
|
||||
err := registry.SubscribeForBackupsCompletions(handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
ctx := context.Background()
|
||||
err = registry.pubsubCompletions.Publish(ctx, "backup:completion", "invalid json")
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-receivedBackupID:
|
||||
t.Fatal("Should not receive backup for invalid JSON")
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func Test_UnsubscribeForBackupsCompletions_StopsReceivingMessages(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
backupID1 := uuid.New()
|
||||
backupID2 := uuid.New()
|
||||
|
||||
receivedBackupID := make(chan uuid.UUID, 2)
|
||||
handler := func(nodeID string, backupID uuid.UUID) {
|
||||
receivedBackupID <- backupID
|
||||
}
|
||||
|
||||
err := registry.SubscribeForBackupsCompletions(handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.PublishBackupCompletion(node.ID.String(), backupID1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
received := <-receivedBackupID
|
||||
assert.Equal(t, backupID1, received)
|
||||
|
||||
err = registry.UnsubscribeForBackupsCompletions()
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.PublishBackupCompletion(node.ID.String(), backupID2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-receivedBackupID:
|
||||
t.Fatal("Should not receive backup after unsubscribe")
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SubscribeForBackupsCompletions_WhenAlreadySubscribed_ReturnsError(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
defer registry.UnsubscribeForBackupsCompletions()
|
||||
|
||||
handler := func(nodeID string, backupID uuid.UUID) {}
|
||||
|
||||
err := registry.SubscribeForBackupsCompletions(handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.SubscribeForBackupsCompletions(handler)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "already subscribed")
|
||||
}
|
||||
|
||||
func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry1 := createTestRegistry()
|
||||
registry2 := createTestRegistry()
|
||||
registry3 := createTestRegistry()
|
||||
|
||||
node1 := createTestBackupNode()
|
||||
node2 := createTestBackupNode()
|
||||
node3 := createTestBackupNode()
|
||||
|
||||
backupID1 := uuid.New()
|
||||
backupID2 := uuid.New()
|
||||
backupID3 := uuid.New()
|
||||
|
||||
defer registry1.UnsubscribeForBackupsCompletions()
|
||||
defer registry2.UnsubscribeForBackupsCompletions()
|
||||
defer registry3.UnsubscribeForBackupsCompletions()
|
||||
|
||||
receivedBackups1 := make(chan uuid.UUID, 3)
|
||||
receivedBackups2 := make(chan uuid.UUID, 3)
|
||||
receivedBackups3 := make(chan uuid.UUID, 3)
|
||||
|
||||
handler1 := func(nodeID string, backupID uuid.UUID) { receivedBackups1 <- backupID }
|
||||
handler2 := func(nodeID string, backupID uuid.UUID) { receivedBackups2 <- backupID }
|
||||
handler3 := func(nodeID string, backupID uuid.UUID) { receivedBackups3 <- backupID }
|
||||
|
||||
err := registry1.SubscribeForBackupsCompletions(handler1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry2.SubscribeForBackupsCompletions(handler2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry3.SubscribeForBackupsCompletions(handler3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
publishRegistry := createTestRegistry()
|
||||
err = publishRegistry.PublishBackupCompletion(node1.ID.String(), backupID1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = publishRegistry.PublishBackupCompletion(node2.ID.String(), backupID2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = publishRegistry.PublishBackupCompletion(node3.ID.String(), backupID3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
receivedAll1 := []uuid.UUID{}
|
||||
receivedAll2 := []uuid.UUID{}
|
||||
receivedAll3 := []uuid.UUID{}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
select {
|
||||
case received := <-receivedBackups1:
|
||||
receivedAll1 = append(receivedAll1, received)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Subscriber 1 timeout waiting for completion message")
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
select {
|
||||
case received := <-receivedBackups2:
|
||||
receivedAll2 = append(receivedAll2, received)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Subscriber 2 timeout waiting for completion message")
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
select {
|
||||
case received := <-receivedBackups3:
|
||||
receivedAll3 = append(receivedAll3, received)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Subscriber 3 timeout waiting for completion message")
|
||||
}
|
||||
}
|
||||
|
||||
assert.Contains(t, receivedAll1, backupID1)
|
||||
assert.Contains(t, receivedAll1, backupID2)
|
||||
assert.Contains(t, receivedAll1, backupID3)
|
||||
|
||||
assert.Contains(t, receivedAll2, backupID1)
|
||||
assert.Contains(t, receivedAll2, backupID2)
|
||||
assert.Contains(t, receivedAll2, backupID3)
|
||||
|
||||
assert.Contains(t, receivedAll3, backupID1)
|
||||
assert.Contains(t, receivedAll3, backupID2)
|
||||
assert.Contains(t, receivedAll3, backupID3)
|
||||
}
|
||||
600
backend/internal/features/backups/backups/backuping/scheduler.go
Normal file
600
backend/internal/features/backups/backups/backuping/scheduler.go
Normal file
@@ -0,0 +1,600 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"databasus-backend/internal/config"
|
||||
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/storages"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/period"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type BackupsScheduler struct {
|
||||
backupRepository *backups_core.BackupRepository
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
storageService *storages.StorageService
|
||||
backupCancelManager *backups_cancellation.BackupCancelManager
|
||||
nodesRegistry *BackupNodesRegistry
|
||||
|
||||
lastBackupTime time.Time
|
||||
logger *slog.Logger
|
||||
|
||||
backupToNodeRelations map[uuid.UUID]BackupToNodeRelation
|
||||
backuperNode *BackuperNode
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) Run(ctx context.Context) {
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
|
||||
if config.GetEnv().IsManyNodesMode {
|
||||
// wait other nodes to start
|
||||
time.Sleep(1 * time.Minute)
|
||||
}
|
||||
|
||||
if err := s.failBackupsInProgress(); err != nil {
|
||||
s.logger.Error("Failed to fail backups in progress", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if err := s.nodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted); err != nil {
|
||||
s.logger.Error("Failed to subscribe to backup completions", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := s.nodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
|
||||
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.cleanOldBackups(); err != nil {
|
||||
s.logger.Error("Failed to clean old backups", "error", err)
|
||||
}
|
||||
|
||||
if err := s.checkDeadNodesAndFailBackups(); err != nil {
|
||||
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
|
||||
}
|
||||
|
||||
if err := s.runPendingBackups(); err != nil {
|
||||
s.logger.Error("Failed to run pending backups", "error", err)
|
||||
}
|
||||
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) IsSchedulerRunning() bool {
|
||||
// if last backup time is more than 5 minutes ago, return false
|
||||
return s.lastBackupTime.After(time.Now().UTC().Add(-5 * time.Minute))
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) failBackupsInProgress() error {
|
||||
backupsInProgress, err := s.backupRepository.FindByStatus(backups_core.BackupStatusInProgress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Println("Backups in progress", len(backupsInProgress))
|
||||
|
||||
for _, backup := range backupsInProgress {
|
||||
if err := s.backupCancelManager.CancelBackup(backup.ID); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to cancel backup via context manager",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(backup.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
failMessage := "Backup failed due to application restart"
|
||||
backup.FailMessage = &failMessage
|
||||
backup.Status = backups_core.BackupStatusFailed
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
s.backuperNode.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupFailed,
|
||||
&failMessage,
|
||||
)
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool) {
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(databaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if backupConfig.StorageID == nil {
|
||||
s.logger.Error("Backup config storage ID is nil", "databaseId", databaseID)
|
||||
return
|
||||
}
|
||||
|
||||
leastBusyNodeID, err := s.calculateLeastBusyNode()
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to calculate least busy node",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: backupConfig.DatabaseID,
|
||||
StorageID: *backupConfig.StorageID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
BackupSizeMb: 0,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to save backup",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.nodesRegistry.IncrementBackupsInProgress(leastBusyNodeID.String()); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to increment backups in progress",
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.nodesRegistry.AssignBackupToNode(leastBusyNodeID.String(), backup.ID, isCallNotifier); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to submit backup",
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
if decrementErr := s.nodesRegistry.DecrementBackupsInProgress(leastBusyNodeID.String()); decrementErr != nil {
|
||||
s.logger.Error(
|
||||
"Failed to decrement backups in progress after submit failure",
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
"error",
|
||||
decrementErr,
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if relation, exists := s.backupToNodeRelations[*leastBusyNodeID]; exists {
|
||||
relation.BackupsIDs = append(relation.BackupsIDs, backup.ID)
|
||||
s.backupToNodeRelations[*leastBusyNodeID] = relation
|
||||
} else {
|
||||
s.backupToNodeRelations[*leastBusyNodeID] = BackupToNodeRelation{
|
||||
NodeID: *leastBusyNodeID,
|
||||
BackupsIDs: []uuid.UUID{backup.ID},
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Info(
|
||||
"Successfully triggered scheduled backup",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
)
|
||||
}
|
||||
|
||||
// GetRemainedBackupTryCount returns the number of remaining backup tries for a given backup.
|
||||
// If the backup is not failed or the backup config does not allow retries, it returns 0.
|
||||
// If the backup is failed and the backup config allows retries, it returns the number of remaining tries.
|
||||
// If the backup is failed and the backup config does not allow retries, it returns 0.
|
||||
func (s *BackupsScheduler) GetRemainedBackupTryCount(lastBackup *backups_core.Backup) int {
|
||||
if lastBackup == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
if lastBackup.Status != backups_core.BackupStatusFailed {
|
||||
return 0
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(lastBackup.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
if !backupConfig.IsRetryIfFailed {
|
||||
return 0
|
||||
}
|
||||
|
||||
maxFailedTriesCount := backupConfig.MaxFailedTriesCount
|
||||
|
||||
lastBackups, err := s.backupRepository.FindByDatabaseIDWithLimit(
|
||||
lastBackup.DatabaseID,
|
||||
maxFailedTriesCount,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to find last backups by database ID", "error", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
lastFailedBackups := make([]*backups_core.Backup, 0)
|
||||
|
||||
for _, backup := range lastBackups {
|
||||
if backup.Status == backups_core.BackupStatusFailed {
|
||||
lastFailedBackups = append(lastFailedBackups, backup)
|
||||
}
|
||||
}
|
||||
|
||||
return maxFailedTriesCount - len(lastFailedBackups)
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) cleanOldBackups() error {
|
||||
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
backupStorePeriod := backupConfig.StorePeriod
|
||||
|
||||
if backupStorePeriod == period.PeriodForever {
|
||||
continue
|
||||
}
|
||||
|
||||
storeDuration := backupStorePeriod.ToDuration()
|
||||
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
|
||||
|
||||
oldBackups, err := s.backupRepository.FindBackupsBeforeDate(
|
||||
backupConfig.DatabaseID,
|
||||
dateBeforeBackupsShouldBeDeleted,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to find old backups for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, backup := range oldBackups {
|
||||
storage, err := s.storageService.GetStorageByID(backup.StorageID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to get storage by ID",
|
||||
"storageId",
|
||||
backup.StorageID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
err = storage.DeleteFile(encryptor, backup.ID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
|
||||
}
|
||||
|
||||
if err := s.backupRepository.DeleteByID(backup.ID); err != nil {
|
||||
s.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
s.logger.Info(
|
||||
"Deleted old backup",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) runPendingBackups() error {
|
||||
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
if backupConfig.BackupInterval == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
lastBackup, err := s.backupRepository.FindLastByDatabaseID(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to get last backup for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
var lastBackupTime *time.Time
|
||||
if lastBackup != nil {
|
||||
lastBackupTime = &lastBackup.CreatedAt
|
||||
}
|
||||
|
||||
remainedBackupTryCount := s.GetRemainedBackupTryCount(lastBackup)
|
||||
|
||||
if backupConfig.BackupInterval.ShouldTriggerBackup(time.Now().UTC(), lastBackupTime) ||
|
||||
remainedBackupTryCount > 0 {
|
||||
s.logger.Info(
|
||||
"Triggering scheduled backup",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"intervalType",
|
||||
backupConfig.BackupInterval.Interval,
|
||||
)
|
||||
|
||||
s.StartBackup(backupConfig.DatabaseID, remainedBackupTryCount == 1)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
|
||||
nodes, err := s.nodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get available nodes: %w", err)
|
||||
}
|
||||
|
||||
if len(nodes) == 0 {
|
||||
return nil, fmt.Errorf("no nodes available")
|
||||
}
|
||||
|
||||
stats, err := s.nodesRegistry.GetBackupNodesStats()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get backup nodes stats: %w", err)
|
||||
}
|
||||
|
||||
statsMap := make(map[uuid.UUID]int)
|
||||
for _, stat := range stats {
|
||||
statsMap[stat.ID] = stat.ActiveBackups
|
||||
}
|
||||
|
||||
var bestNode *BackupNode
|
||||
var bestScore float64 = -1
|
||||
|
||||
now := time.Now().UTC()
|
||||
for i := range nodes {
|
||||
node := &nodes[i]
|
||||
|
||||
if now.Sub(node.LastHeartbeat) > 2*time.Minute {
|
||||
continue
|
||||
}
|
||||
|
||||
activeBackups := statsMap[node.ID]
|
||||
|
||||
var score float64
|
||||
if node.ThroughputMBs > 0 {
|
||||
score = float64(activeBackups) / float64(node.ThroughputMBs)
|
||||
} else {
|
||||
score = float64(activeBackups) * 1000
|
||||
}
|
||||
|
||||
if bestNode == nil || score < bestScore {
|
||||
bestNode = node
|
||||
bestScore = score
|
||||
}
|
||||
}
|
||||
|
||||
if bestNode == nil {
|
||||
return nil, fmt.Errorf("no suitable nodes available")
|
||||
}
|
||||
|
||||
return &bestNode.ID, nil
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUID) {
|
||||
nodeID, err := uuid.Parse(nodeIDStr)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to parse node ID from completion message",
|
||||
"nodeId",
|
||||
nodeIDStr,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
relation, exists := s.backupToNodeRelations[nodeID]
|
||||
if !exists {
|
||||
s.logger.Warn(
|
||||
"Received completion for unknown node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
newBackupIDs := make([]uuid.UUID, 0)
|
||||
found := false
|
||||
for _, id := range relation.BackupsIDs {
|
||||
if id == backupID {
|
||||
found = true
|
||||
continue
|
||||
}
|
||||
newBackupIDs = append(newBackupIDs, id)
|
||||
}
|
||||
|
||||
if !found {
|
||||
s.logger.Warn(
|
||||
"Backup not found in node's backup list",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if len(newBackupIDs) == 0 {
|
||||
delete(s.backupToNodeRelations, nodeID)
|
||||
} else {
|
||||
relation.BackupsIDs = newBackupIDs
|
||||
s.backupToNodeRelations[nodeID] = relation
|
||||
}
|
||||
|
||||
if err := s.nodesRegistry.DecrementBackupsInProgress(nodeIDStr); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to decrement backups in progress",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) checkDeadNodesAndFailBackups() error {
|
||||
nodes, err := s.nodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get available nodes: %w", err)
|
||||
}
|
||||
|
||||
aliveNodeIDs := make(map[uuid.UUID]bool)
|
||||
now := time.Now().UTC()
|
||||
|
||||
for _, node := range nodes {
|
||||
if now.Sub(node.LastHeartbeat) <= 2*time.Minute {
|
||||
aliveNodeIDs[node.ID] = true
|
||||
}
|
||||
}
|
||||
|
||||
for nodeID, relation := range s.backupToNodeRelations {
|
||||
if aliveNodeIDs[nodeID] {
|
||||
continue
|
||||
}
|
||||
|
||||
s.logger.Warn(
|
||||
"Node is dead, failing its backups",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupCount",
|
||||
len(relation.BackupsIDs),
|
||||
)
|
||||
|
||||
for _, backupID := range relation.BackupsIDs {
|
||||
backup, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to find backup for dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
failMessage := "Backup failed due to node unavailability"
|
||||
backup.FailMessage = &failMessage
|
||||
backup.Status = backups_core.BackupStatusFailed
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to save failed backup for dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := s.nodesRegistry.DecrementBackupsInProgress(nodeID.String()); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to decrement backups in progress for dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
s.logger.Info(
|
||||
"Failed backup due to dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
)
|
||||
}
|
||||
|
||||
delete(s.backupToNodeRelations, nodeID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,709 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/period"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_RunPendingBackups_WhenLastBackupWasYesterday_CreatesNewBackup(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Enable backups for the database
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add old backup
|
||||
backupRepository.Save(&backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupsScheduler().runPendingBackups()
|
||||
|
||||
// Wait for backup to complete (runs in goroutine)
|
||||
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 2)
|
||||
|
||||
// Wait for any cleanup operations to complete before defer cleanup runs
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_RunPendingBackups_WhenLastBackupWasRecentlyCompleted_SkipsBackup(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Enable backups for the database
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add recent backup (1 hour ago)
|
||||
backupRepository.Save(&backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupsScheduler().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1) // Should still be 1 backup, no new backup created
|
||||
|
||||
// Wait for any cleanup operations to complete before defer cleanup runs
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_RunPendingBackups_WhenLastBackupFailedAndRetriesDisabled_SkipsBackup(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Enable backups for the database with retries disabled
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
backupConfig.IsRetryIfFailed = false
|
||||
backupConfig.MaxFailedTriesCount = 0
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add failed backup
|
||||
failMessage := "backup failed"
|
||||
backupRepository.Save(&backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: backups_core.BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupsScheduler().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1) // Should still be 1 backup, no retry attempted
|
||||
|
||||
// Wait for any cleanup operations to complete before defer cleanup runs
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_RunPendingBackups_WhenLastBackupFailedAndRetriesEnabled_CreatesNewBackup(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Enable backups for the database with retries enabled
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
backupConfig.IsRetryIfFailed = true
|
||||
backupConfig.MaxFailedTriesCount = 3
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add failed backup
|
||||
failMessage := "backup failed"
|
||||
backupRepository.Save(&backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: backups_core.BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupsScheduler().runPendingBackups()
|
||||
|
||||
// Wait for backup to complete (runs in goroutine)
|
||||
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 2) // Should have 2 backups, retry was attempted
|
||||
|
||||
// Wait for any cleanup operations to complete before defer cleanup runs
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_RunPendingBackups_WhenFailedBackupsExceedMaxRetries_SkipsBackup(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
// Enable backups for the database with retries enabled
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
backupConfig.IsRetryIfFailed = true
|
||||
backupConfig.MaxFailedTriesCount = 3
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
failMessage := "backup failed"
|
||||
|
||||
for range 3 {
|
||||
backupRepository.Save(&backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: backups_core.BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
}
|
||||
|
||||
GetBackupsScheduler().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// assertions
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 3) // Should have 3 backups, not more than max
|
||||
|
||||
// Wait for any cleanup operations to complete before defer cleanup runs
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_RunPendingBackups_WhenBackupsDisabled_SkipsBackup(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = false
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add old backup that would trigger new backup if enabled
|
||||
backupRepository.Save(&backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupsScheduler().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1)
|
||||
|
||||
// Wait for any cleanup operations to complete before defer cleanup runs
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegistry(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Register mock node without subscribing to backups (simulates node crash after registration)
|
||||
mockNodeID := uuid.New()
|
||||
err = CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Scheduler assigns backup to mock node
|
||||
GetBackupsScheduler().StartBackup(database.ID, false)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1)
|
||||
assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status)
|
||||
|
||||
// Verify Valkey counter was incremented when backup was assigned
|
||||
stats, err := nodesRegistry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
foundStat := false
|
||||
for _, stat := range stats {
|
||||
if stat.ID == mockNodeID {
|
||||
assert.Equal(t, 1, stat.ActiveBackups)
|
||||
foundStat = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, foundStat, "Node stats should be present")
|
||||
|
||||
// Simulate node death by setting heartbeat older than 2-minute threshold
|
||||
oldHeartbeat := time.Now().UTC().Add(-3 * time.Minute)
|
||||
err = UpdateNodeHeartbeatDirectly(mockNodeID, 100, oldHeartbeat)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Trigger dead node detection
|
||||
err = GetBackupsScheduler().checkDeadNodesAndFailBackups()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify backup was failed with appropriate error message
|
||||
backups, err = backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1)
|
||||
assert.Equal(t, backups_core.BackupStatusFailed, backups[0].Status)
|
||||
assert.NotNil(t, backups[0].FailMessage)
|
||||
assert.Contains(t, *backups[0].FailMessage, "node unavailability")
|
||||
|
||||
// Verify Valkey counter was decremented after backup failed
|
||||
stats, err = nodesRegistry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
for _, stat := range stats {
|
||||
if stat.ID == mockNodeID {
|
||||
assert.Equal(t, 0, stat.ActiveBackups)
|
||||
}
|
||||
}
|
||||
|
||||
// Node info should still exist in registry (not removed by checkDeadNodesAndFailBackups)
|
||||
node, err := GetNodeFromRegistry(mockNodeID)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, node)
|
||||
assert.Equal(t, mockNodeID, node.ID)
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
|
||||
t.Run("Nodes with same throughput", func(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
node1ID := uuid.New()
|
||||
node2ID := uuid.New()
|
||||
node3ID := uuid.New()
|
||||
now := time.Now().UTC()
|
||||
|
||||
err := CreateMockNodeInRegistry(node1ID, 100, now)
|
||||
assert.NoError(t, err)
|
||||
err = CreateMockNodeInRegistry(node2ID, 100, now)
|
||||
assert.NoError(t, err)
|
||||
err = CreateMockNodeInRegistry(node3ID, 100, now)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for range 5 {
|
||||
err = nodesRegistry.IncrementBackupsInProgress(node1ID.String())
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
for range 2 {
|
||||
err = nodesRegistry.IncrementBackupsInProgress(node2ID.String())
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
for range 8 {
|
||||
err = nodesRegistry.IncrementBackupsInProgress(node3ID.String())
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
leastBusyNodeID, err := GetBackupsScheduler().calculateLeastBusyNode()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, leastBusyNodeID)
|
||||
assert.Equal(t, node2ID, *leastBusyNodeID)
|
||||
})
|
||||
|
||||
t.Run("Nodes with different throughput", func(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
node100MBsID := uuid.New()
|
||||
node50MBsID := uuid.New()
|
||||
now := time.Now().UTC()
|
||||
|
||||
err := CreateMockNodeInRegistry(node100MBsID, 100, now)
|
||||
assert.NoError(t, err)
|
||||
err = CreateMockNodeInRegistry(node50MBsID, 50, now)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for range 10 {
|
||||
err = nodesRegistry.IncrementBackupsInProgress(node100MBsID.String())
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
err = nodesRegistry.IncrementBackupsInProgress(node50MBsID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
leastBusyNodeID, err := GetBackupsScheduler().calculateLeastBusyNode()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, leastBusyNodeID)
|
||||
assert.Equal(t, node50MBsID, *leastBusyNodeID)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_FailBackupsInProgress_WhenSchedulerStarts_CancelsBackupsAndUpdatesStatus(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create two in-progress backups that should be failed on scheduler restart
|
||||
backup1 := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
BackupSizeMb: 10.5,
|
||||
CreatedAt: time.Now().UTC().Add(-30 * time.Minute),
|
||||
}
|
||||
err = backupRepository.Save(backup1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backup2 := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
BackupSizeMb: 5.2,
|
||||
CreatedAt: time.Now().UTC().Add(-15 * time.Minute),
|
||||
}
|
||||
err = backupRepository.Save(backup2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create a completed backup to verify it's not affected by failBackupsInProgress
|
||||
completedBackup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 20.0,
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
}
|
||||
err = backupRepository.Save(completedBackup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Trigger the scheduler's failBackupsInProgress logic
|
||||
// This should cancel in-progress backups and mark them as failed
|
||||
err = GetBackupsScheduler().failBackupsInProgress()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify all backups exist and were processed correctly
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 3)
|
||||
|
||||
var failedCount int
|
||||
var completedCount int
|
||||
for _, backup := range backups {
|
||||
switch backup.Status {
|
||||
case backups_core.BackupStatusFailed:
|
||||
failedCount++
|
||||
// Verify fail message indicates application restart
|
||||
assert.NotNil(t, backup.FailMessage)
|
||||
assert.Equal(t, "Backup failed due to application restart", *backup.FailMessage)
|
||||
// Verify backup size was reset to 0
|
||||
assert.Equal(t, float64(0), backup.BackupSizeMb)
|
||||
case backups_core.BackupStatusCompleted:
|
||||
completedCount++
|
||||
}
|
||||
}
|
||||
|
||||
// Verify correct number of backups in each state
|
||||
assert.Equal(t, 2, failedCount)
|
||||
assert.Equal(t, 1, completedCount)
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
206
backend/internal/features/backups/backups/backuping/testing.go
Normal file
206
backend/internal/features/backups/backups/backuping/testing.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/usecases"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func CreateTestRouter() *gin.Engine {
|
||||
router := workspaces_testing.CreateTestRouter(
|
||||
workspaces_controllers.GetWorkspaceController(),
|
||||
workspaces_controllers.GetMembershipController(),
|
||||
databases.GetDatabaseController(),
|
||||
backups_config.GetBackupConfigController(),
|
||||
)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func CreateTestBackuperNode() *BackuperNode {
|
||||
return &BackuperNode{
|
||||
databases.GetDatabaseService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
notifiers.GetNotifierService(),
|
||||
backupCancelManager,
|
||||
nodesRegistry,
|
||||
logger.GetLogger(),
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
uuid.New(),
|
||||
time.Time{},
|
||||
}
|
||||
}
|
||||
|
||||
// WaitForBackupCompletion waits for a new backup to be created and completed (or failed)
|
||||
// for the given database. It checks for backups with count greater than expectedInitialCount.
|
||||
func WaitForBackupCompletion(
|
||||
t *testing.T,
|
||||
databaseID uuid.UUID,
|
||||
expectedInitialCount int,
|
||||
timeout time.Duration,
|
||||
) {
|
||||
deadline := time.Now().UTC().Add(timeout)
|
||||
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
backups, err := backupRepository.FindByDatabaseID(databaseID)
|
||||
if err != nil {
|
||||
t.Logf("WaitForBackupCompletion: error finding backups: %v", err)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
t.Logf(
|
||||
"WaitForBackupCompletion: found %d backups (expected > %d)",
|
||||
len(backups),
|
||||
expectedInitialCount,
|
||||
)
|
||||
|
||||
if len(backups) > expectedInitialCount {
|
||||
// Check if the newest backup has completed or failed
|
||||
newestBackup := backups[0]
|
||||
t.Logf("WaitForBackupCompletion: newest backup status: %s", newestBackup.Status)
|
||||
|
||||
if newestBackup.Status == backups_core.BackupStatusCompleted ||
|
||||
newestBackup.Status == backups_core.BackupStatusFailed ||
|
||||
newestBackup.Status == backups_core.BackupStatusCanceled {
|
||||
t.Logf(
|
||||
"WaitForBackupCompletion: backup finished with status %s",
|
||||
newestBackup.Status,
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Logf("WaitForBackupCompletion: timeout waiting for backup to complete")
|
||||
}
|
||||
|
||||
// StartBackuperNodeForTest starts a BackuperNode in a goroutine for testing.
|
||||
// The node registers itself in the registry and subscribes to backup assignments.
|
||||
// Returns a context cancel function that should be deferred to stop the node.
|
||||
func StartBackuperNodeForTest(t *testing.T, backuperNode *BackuperNode) context.CancelFunc {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
backuperNode.Run(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Poll registry for node presence instead of fixed sleep
|
||||
deadline := time.Now().UTC().Add(5 * time.Second)
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
nodes, err := nodesRegistry.GetAvailableNodes()
|
||||
if err == nil {
|
||||
for _, node := range nodes {
|
||||
if node.ID == backuperNode.nodeID {
|
||||
t.Logf("BackuperNode registered in registry: %s", backuperNode.nodeID)
|
||||
|
||||
return func() {
|
||||
cancel()
|
||||
select {
|
||||
case <-done:
|
||||
t.Log("BackuperNode stopped gracefully")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Log("BackuperNode stop timeout")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Fatalf("BackuperNode failed to register in registry within timeout")
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopBackuperNodeForTest stops the BackuperNode by canceling its context.
|
||||
// It waits for the node to unregister from the registry.
|
||||
func StopBackuperNodeForTest(t *testing.T, cancel context.CancelFunc, backuperNode *BackuperNode) {
|
||||
cancel()
|
||||
|
||||
// Wait for node to unregister from registry
|
||||
deadline := time.Now().UTC().Add(2 * time.Second)
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
nodes, err := nodesRegistry.GetAvailableNodes()
|
||||
if err == nil {
|
||||
found := false
|
||||
for _, node := range nodes {
|
||||
if node.ID == backuperNode.nodeID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Logf("BackuperNode unregistered from registry: %s", backuperNode.nodeID)
|
||||
return
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Logf("BackuperNode stop completed for %s", backuperNode.nodeID)
|
||||
}
|
||||
|
||||
func CreateMockNodeInRegistry(nodeID uuid.UUID, throughputMBs int, lastHeartbeat time.Time) error {
|
||||
backupNode := BackupNode{
|
||||
ID: nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: lastHeartbeat,
|
||||
}
|
||||
|
||||
return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
|
||||
}
|
||||
|
||||
func UpdateNodeHeartbeatDirectly(
|
||||
nodeID uuid.UUID,
|
||||
throughputMBs int,
|
||||
lastHeartbeat time.Time,
|
||||
) error {
|
||||
backupNode := BackupNode{
|
||||
ID: nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: lastHeartbeat,
|
||||
}
|
||||
|
||||
return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
|
||||
}
|
||||
|
||||
func GetNodeFromRegistry(nodeID uuid.UUID) (*BackupNode, error) {
|
||||
nodes, err := nodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
if node.ID == nodeID {
|
||||
return &node, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("node not found")
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package backups_cancellation
|
||||
|
||||
import (
|
||||
"context"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const backupCancelChannel = "backup:cancel"
|
||||
|
||||
type BackupCancelManager struct {
|
||||
mu sync.RWMutex
|
||||
cancelFuncs map[uuid.UUID]context.CancelFunc
|
||||
pubsub *cache_utils.PubSubManager
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (m *BackupCancelManager) StartSubscription() {
|
||||
ctx := context.Background()
|
||||
|
||||
handler := func(message string) {
|
||||
backupID, err := uuid.Parse(message)
|
||||
if err != nil {
|
||||
m.logger.Error("Invalid backup ID in cancel message", "message", message, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
cancelFunc, exists := m.cancelFuncs[backupID]
|
||||
if exists {
|
||||
cancelFunc()
|
||||
delete(m.cancelFuncs, backupID)
|
||||
m.logger.Info("Cancelled backup via Pub/Sub", "backupID", backupID)
|
||||
}
|
||||
}
|
||||
|
||||
err := m.pubsub.Subscribe(ctx, backupCancelChannel, handler)
|
||||
if err != nil {
|
||||
m.logger.Error("Failed to subscribe to backup cancel channel", "error", err)
|
||||
} else {
|
||||
m.logger.Info("Successfully subscribed to backup cancel channel")
|
||||
}
|
||||
}
|
||||
|
||||
func (m *BackupCancelManager) RegisterBackup(backupID uuid.UUID, cancelFunc context.CancelFunc) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.cancelFuncs[backupID] = cancelFunc
|
||||
m.logger.Debug("Registered backup", "backupID", backupID)
|
||||
}
|
||||
|
||||
func (m *BackupCancelManager) CancelBackup(backupID uuid.UUID) error {
|
||||
ctx := context.Background()
|
||||
|
||||
err := m.pubsub.Publish(ctx, backupCancelChannel, backupID.String())
|
||||
if err != nil {
|
||||
m.logger.Error("Failed to publish cancel message", "backupID", backupID, "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
m.logger.Info("Published backup cancel message", "backupID", backupID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *BackupCancelManager) UnregisterBackup(backupID uuid.UUID) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.cancelFuncs, backupID)
|
||||
m.logger.Debug("Unregistered backup", "backupID", backupID)
|
||||
}
|
||||
@@ -0,0 +1,200 @@
|
||||
package backups_cancellation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_RegisterBackup_BackupRegisteredSuccessfully(t *testing.T) {
|
||||
manager := backupCancelManager
|
||||
|
||||
backupID := uuid.New()
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
manager.RegisterBackup(backupID, cancel)
|
||||
|
||||
manager.mu.RLock()
|
||||
_, exists := manager.cancelFuncs[backupID]
|
||||
manager.mu.RUnlock()
|
||||
assert.True(t, exists, "Backup should be registered")
|
||||
}
|
||||
|
||||
func Test_UnregisterBackup_BackupUnregisteredSuccessfully(t *testing.T) {
|
||||
manager := backupCancelManager
|
||||
|
||||
backupID := uuid.New()
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
manager.RegisterBackup(backupID, cancel)
|
||||
manager.UnregisterBackup(backupID)
|
||||
|
||||
manager.mu.RLock()
|
||||
_, exists := manager.cancelFuncs[backupID]
|
||||
manager.mu.RUnlock()
|
||||
assert.False(t, exists, "Backup should be unregistered")
|
||||
}
|
||||
|
||||
func Test_CancelBackup_OnSameInstance_BackupCancelledViaPubSub(t *testing.T) {
|
||||
manager := backupCancelManager
|
||||
|
||||
backupID := uuid.New()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
cancelled := false
|
||||
var mu sync.Mutex
|
||||
|
||||
wrappedCancel := func() {
|
||||
mu.Lock()
|
||||
cancelled = true
|
||||
mu.Unlock()
|
||||
cancel()
|
||||
}
|
||||
|
||||
manager.RegisterBackup(backupID, wrappedCancel)
|
||||
manager.StartSubscription()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err := manager.CancelBackup(backupID)
|
||||
assert.NoError(t, err, "Cancel should not return error")
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
wasCancelled := cancelled
|
||||
mu.Unlock()
|
||||
|
||||
assert.True(t, wasCancelled, "Cancel function should have been called")
|
||||
assert.Error(t, ctx.Err(), "Context should be cancelled")
|
||||
}
|
||||
|
||||
func Test_CancelBackup_FromDifferentInstance_BackupCancelledOnRunningInstance(t *testing.T) {
|
||||
manager1 := backupCancelManager
|
||||
manager2 := backupCancelManager
|
||||
|
||||
backupID := uuid.New()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
cancelled := false
|
||||
var mu sync.Mutex
|
||||
|
||||
wrappedCancel := func() {
|
||||
mu.Lock()
|
||||
cancelled = true
|
||||
mu.Unlock()
|
||||
cancel()
|
||||
}
|
||||
|
||||
manager1.RegisterBackup(backupID, wrappedCancel)
|
||||
|
||||
manager1.StartSubscription()
|
||||
manager2.StartSubscription()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err := manager2.CancelBackup(backupID)
|
||||
assert.NoError(t, err, "Cancel should not return error")
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
wasCancelled := cancelled
|
||||
mu.Unlock()
|
||||
|
||||
assert.True(t, wasCancelled, "Cancel function should have been called on instance 1")
|
||||
assert.Error(t, ctx.Err(), "Context should be cancelled")
|
||||
}
|
||||
|
||||
func Test_CancelBackup_WhenBackupDoesNotExist_NoErrorReturned(t *testing.T) {
|
||||
manager := backupCancelManager
|
||||
|
||||
manager.StartSubscription()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
nonExistentID := uuid.New()
|
||||
err := manager.CancelBackup(nonExistentID)
|
||||
assert.NoError(t, err, "Cancelling non-existent backup should not error")
|
||||
}
|
||||
|
||||
func Test_CancelBackup_WithMultipleBackups_AllBackupsCancelled(t *testing.T) {
|
||||
manager := backupCancelManager
|
||||
|
||||
numBackups := 5
|
||||
backupIDs := make([]uuid.UUID, numBackups)
|
||||
contexts := make([]context.Context, numBackups)
|
||||
cancels := make([]context.CancelFunc, numBackups)
|
||||
cancelledFlags := make([]bool, numBackups)
|
||||
var mu sync.Mutex
|
||||
|
||||
for i := 0; i < numBackups; i++ {
|
||||
backupIDs[i] = uuid.New()
|
||||
contexts[i], cancels[i] = context.WithCancel(context.Background())
|
||||
|
||||
idx := i
|
||||
wrappedCancel := func() {
|
||||
mu.Lock()
|
||||
cancelledFlags[idx] = true
|
||||
mu.Unlock()
|
||||
cancels[idx]()
|
||||
}
|
||||
|
||||
manager.RegisterBackup(backupIDs[i], wrappedCancel)
|
||||
}
|
||||
|
||||
manager.StartSubscription()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
for i := 0; i < numBackups; i++ {
|
||||
err := manager.CancelBackup(backupIDs[i])
|
||||
assert.NoError(t, err, "Cancel should not return error")
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
mu.Lock()
|
||||
for i := 0; i < numBackups; i++ {
|
||||
assert.True(t, cancelledFlags[i], "Backup %d should be cancelled", i)
|
||||
assert.Error(t, contexts[i].Err(), "Context %d should be cancelled", i)
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
func Test_CancelBackup_AfterUnregister_BackupNotCancelled(t *testing.T) {
|
||||
manager := backupCancelManager
|
||||
|
||||
backupID := uuid.New()
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
cancelled := false
|
||||
var mu sync.Mutex
|
||||
|
||||
wrappedCancel := func() {
|
||||
mu.Lock()
|
||||
cancelled = true
|
||||
mu.Unlock()
|
||||
cancel()
|
||||
}
|
||||
|
||||
manager.RegisterBackup(backupID, wrappedCancel)
|
||||
manager.StartSubscription()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
manager.UnregisterBackup(backupID)
|
||||
|
||||
err := manager.CancelBackup(backupID)
|
||||
assert.NoError(t, err, "Cancel should not return error")
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
wasCancelled := cancelled
|
||||
mu.Unlock()
|
||||
|
||||
assert.False(t, wasCancelled, "Cancel function should not be called after unregister")
|
||||
}
|
||||
25
backend/internal/features/backups/backups/cancellation/di.go
Normal file
25
backend/internal/features/backups/backups/cancellation/di.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package backups_cancellation
|
||||
|
||||
import (
|
||||
"context"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/logger"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var backupCancelManager = &BackupCancelManager{
|
||||
sync.RWMutex{},
|
||||
make(map[uuid.UUID]context.CancelFunc),
|
||||
cache_utils.NewPubSubManager(),
|
||||
logger.GetLogger(),
|
||||
}
|
||||
|
||||
func GetBackupCancelManager() *BackupCancelManager {
|
||||
return backupCancelManager
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
backupCancelManager.StartSubscription()
|
||||
}
|
||||
@@ -1,11 +1,15 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"context"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
"databasus-backend/internal/features/databases"
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -170,9 +174,10 @@ func (c *BackupController) CancelBackup(ctx *gin.Context) {
|
||||
// @Description Generate a token for downloading a backup file (valid for 5 minutes)
|
||||
// @Tags backups
|
||||
// @Param id path string true "Backup ID"
|
||||
// @Success 200 {object} GenerateDownloadTokenResponse
|
||||
// @Success 200 {object} backups_download.GenerateDownloadTokenResponse
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 409 {object} map[string]string "Download already in progress"
|
||||
// @Router /backups/{id}/download-token [post]
|
||||
func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
@@ -189,6 +194,15 @@ func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
|
||||
|
||||
response, err := c.backupService.GenerateDownloadToken(user, id)
|
||||
if err != nil {
|
||||
if err == backups_download.ErrDownloadAlreadyInProgress {
|
||||
ctx.JSON(
|
||||
http.StatusConflict,
|
||||
gin.H{
|
||||
"error": "Download already in progress for some of backups. Please wait until previous download completed or cancel it",
|
||||
},
|
||||
)
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -198,14 +212,22 @@ func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
|
||||
|
||||
// GetFile
|
||||
// @Summary Download a backup file
|
||||
// @Description Download the backup file for the specified backup using a download token
|
||||
// @Description Download the backup file for the specified backup using a download token.
|
||||
// @Description
|
||||
// @Description **Download Concurrency Control:**
|
||||
// @Description - Only one download per user is allowed at a time
|
||||
// @Description - If a download is already in progress, returns 409 Conflict
|
||||
// @Description - Downloads are tracked using cache with 5-second TTL and 3-second heartbeat
|
||||
// @Description - Browser cancellations automatically release the download lock
|
||||
// @Description - Server crashes are handled via automatic cache expiry (5 seconds)
|
||||
// @Tags backups
|
||||
// @Param id path string true "Backup ID"
|
||||
// @Param token query string true "Download token"
|
||||
// @Success 200 {file} file
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 500
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 409 {object} map[string]string "Download already in progress"
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /backups/{id}/file [get]
|
||||
func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
token := ctx.Query("token")
|
||||
@@ -214,7 +236,6 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Get backup ID from URL
|
||||
backupIDParam := ctx.Param("id")
|
||||
backupID, err := uuid.Parse(backupIDParam)
|
||||
if err != nil {
|
||||
@@ -222,13 +243,22 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
downloadToken, err := c.backupService.ValidateDownloadToken(token)
|
||||
downloadToken, rateLimiter, err := c.backupService.ValidateDownloadToken(token)
|
||||
if err != nil {
|
||||
if err == backups_download.ErrDownloadAlreadyInProgress {
|
||||
ctx.JSON(
|
||||
http.StatusConflict,
|
||||
gin.H{
|
||||
"error": "download already in progress for this user. Please wait until previous download completed or cancel it",
|
||||
},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"})
|
||||
return
|
||||
}
|
||||
|
||||
// Verify token is for the requested backup
|
||||
if downloadToken.BackupID != backupID {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"})
|
||||
return
|
||||
@@ -238,18 +268,28 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
downloadToken.BackupID,
|
||||
)
|
||||
if err != nil {
|
||||
c.backupService.UnregisterDownload(downloadToken.UserID)
|
||||
c.backupService.ReleaseDownloadLock(downloadToken.UserID)
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
rateLimitedReader := backups_download.NewRateLimitedReader(fileReader, rateLimiter)
|
||||
|
||||
heartbeatCtx, cancelHeartbeat := context.WithCancel(context.Background())
|
||||
defer func() {
|
||||
if err := fileReader.Close(); err != nil {
|
||||
cancelHeartbeat()
|
||||
c.backupService.UnregisterDownload(downloadToken.UserID)
|
||||
c.backupService.ReleaseDownloadLock(downloadToken.UserID)
|
||||
if err := rateLimitedReader.Close(); err != nil {
|
||||
fmt.Printf("Error closing file reader: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go c.startDownloadHeartbeat(heartbeatCtx, downloadToken.UserID)
|
||||
|
||||
filename := c.generateBackupFilename(backup, database)
|
||||
|
||||
// Set Content-Length for progress tracking
|
||||
if backup.BackupSizeMb > 0 {
|
||||
sizeBytes := int64(backup.BackupSizeMb * 1024 * 1024)
|
||||
ctx.Header("Content-Length", fmt.Sprintf("%d", sizeBytes))
|
||||
@@ -261,13 +301,12 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
fmt.Sprintf("attachment; filename=\"%s\"", filename),
|
||||
)
|
||||
|
||||
_, err = io.Copy(ctx.Writer, fileReader)
|
||||
_, err = io.Copy(ctx.Writer, rateLimitedReader)
|
||||
if err != nil {
|
||||
fmt.Printf("Error streaming file: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Write audit log after successful download
|
||||
c.backupService.WriteAuditLogForDownload(downloadToken.UserID, backup, database)
|
||||
}
|
||||
|
||||
@@ -276,7 +315,7 @@ type MakeBackupRequest struct {
|
||||
}
|
||||
|
||||
func (c *BackupController) generateBackupFilename(
|
||||
backup *Backup,
|
||||
backup *backups_core.Backup,
|
||||
database *databases.Database,
|
||||
) string {
|
||||
// Format timestamp as YYYY-MM-DD_HH-mm-ss
|
||||
@@ -333,3 +372,17 @@ func sanitizeFilename(name string) string {
|
||||
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func (c *BackupController) startDownloadHeartbeat(ctx context.Context, userID uuid.UUID) {
|
||||
ticker := time.NewTicker(backups_download.GetDownloadHeartbeatInterval())
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
c.backupService.RefreshDownloadLock(userID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,7 +18,8 @@ import (
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups/download_token"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
@@ -478,7 +479,7 @@ func Test_GenerateDownloadToken_PermissionsEnforced(t *testing.T) {
|
||||
)
|
||||
|
||||
if tt.expectSuccess {
|
||||
var response GenerateDownloadTokenResponse
|
||||
var response backups_download.GenerateDownloadTokenResponse
|
||||
err := json.Unmarshal(testResp.Body, &response)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, response.Token)
|
||||
@@ -499,7 +500,7 @@ func Test_DownloadBackup_WithValidToken_Success(t *testing.T) {
|
||||
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
// Generate download token
|
||||
var tokenResponse GenerateDownloadTokenResponse
|
||||
var tokenResponse backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
@@ -620,7 +621,7 @@ func Test_DownloadBackup_TokenUsedOnce_CannotReuseToken(t *testing.T) {
|
||||
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
// Generate download token
|
||||
var tokenResponse GenerateDownloadTokenResponse
|
||||
var tokenResponse backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
@@ -683,7 +684,7 @@ func Test_DownloadBackup_WithDifferentBackupToken_Unauthorized(t *testing.T) {
|
||||
backup2 := createTestBackup(database2, owner)
|
||||
|
||||
// Generate token for backup1
|
||||
var tokenResponse GenerateDownloadTokenResponse
|
||||
var tokenResponse backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
@@ -714,7 +715,7 @@ func Test_DownloadBackup_AuditLogWritten(t *testing.T) {
|
||||
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
// Generate download token
|
||||
var tokenResponse GenerateDownloadTokenResponse
|
||||
var tokenResponse backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
@@ -806,7 +807,7 @@ func Test_DownloadBackup_ProperFilenameForPostgreSQL(t *testing.T) {
|
||||
backup := createTestBackup(database, owner)
|
||||
|
||||
// Generate download token
|
||||
var tokenResponse GenerateDownloadTokenResponse
|
||||
var tokenResponse backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
@@ -897,22 +898,22 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
|
||||
_, err = configService.SaveBackupConfig(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backup := &Backup{
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: BackupStatusInProgress,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
BackupSizeMb: 0,
|
||||
BackupDurationMs: 0,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
repo := &BackupRepository{}
|
||||
repo := &backups_core.BackupRepository{}
|
||||
err = repo.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Register a cancellable context for the backup
|
||||
GetBackupService().backupContextManager.RegisterBackup(backup.ID, func() {})
|
||||
GetBackupService().backupCancelManager.RegisterBackup(backup.ID, func() {})
|
||||
|
||||
resp := test_utils.MakePostRequest(
|
||||
t,
|
||||
@@ -949,6 +950,189 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
|
||||
assert.True(t, foundCancelLog, "Cancel audit log should be created")
|
||||
}
|
||||
|
||||
func Test_ConcurrentDownloadPrevention(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
var token1Response backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token1Response,
|
||||
)
|
||||
|
||||
var token2Response backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token2Response,
|
||||
)
|
||||
|
||||
downloadInProgress := make(chan bool, 1)
|
||||
downloadComplete := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf(
|
||||
"/api/v1/backups/%s/file?token=%s",
|
||||
backup.ID.String(),
|
||||
token1Response.Token,
|
||||
),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
downloadComplete <- true
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
service := GetBackupService()
|
||||
if !service.IsDownloadInProgress(owner.UserID) {
|
||||
t.Log("Warning: First download completed before we could test concurrency")
|
||||
<-downloadComplete
|
||||
return
|
||||
}
|
||||
|
||||
downloadInProgress <- true
|
||||
|
||||
resp := test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), token2Response.Token),
|
||||
"",
|
||||
http.StatusConflict,
|
||||
)
|
||||
|
||||
var errorResponse map[string]string
|
||||
err := json.Unmarshal(resp.Body, &errorResponse)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, errorResponse["error"], "download already in progress")
|
||||
|
||||
<-downloadComplete
|
||||
<-downloadInProgress
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
var token3Response backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token3Response,
|
||||
)
|
||||
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), token3Response.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
|
||||
t.Log("Database:", database.Name)
|
||||
t.Log(
|
||||
"Successfully prevented concurrent downloads and allowed subsequent downloads after completion",
|
||||
)
|
||||
}
|
||||
|
||||
func Test_GenerateDownloadToken_BlockedWhenDownloadInProgress(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
var token1Response backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token1Response,
|
||||
)
|
||||
|
||||
downloadComplete := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf(
|
||||
"/api/v1/backups/%s/file?token=%s",
|
||||
backup.ID.String(),
|
||||
token1Response.Token,
|
||||
),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
downloadComplete <- true
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
service := GetBackupService()
|
||||
if !service.IsDownloadInProgress(owner.UserID) {
|
||||
t.Log("Warning: First download completed before we could test token generation blocking")
|
||||
<-downloadComplete
|
||||
return
|
||||
}
|
||||
|
||||
resp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusConflict,
|
||||
)
|
||||
|
||||
var errorResponse map[string]string
|
||||
err := json.Unmarshal(resp.Body, &errorResponse)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, errorResponse["error"], "download already in progress")
|
||||
|
||||
<-downloadComplete
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
var token2Response backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token2Response,
|
||||
)
|
||||
|
||||
assert.NotEmpty(t, token2Response.Token)
|
||||
assert.NotEqual(t, token1Response.Token, token2Response.Token)
|
||||
|
||||
t.Log("Database:", database.Name)
|
||||
t.Log(
|
||||
"Successfully blocked token generation during download and allowed generation after completion",
|
||||
)
|
||||
}
|
||||
|
||||
func createTestRouter() *gin.Engine {
|
||||
return CreateTestRouter()
|
||||
}
|
||||
@@ -1038,7 +1222,7 @@ func createTestDatabaseWithBackups(
|
||||
workspace *workspaces_models.Workspace,
|
||||
owner *users_dto.SignInResponseDTO,
|
||||
router *gin.Engine,
|
||||
) (*databases.Database, *Backup) {
|
||||
) (*databases.Database, *backups_core.Backup) {
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
@@ -1064,7 +1248,7 @@ func createTestDatabaseWithBackups(
|
||||
func createTestBackup(
|
||||
database *databases.Database,
|
||||
owner *users_dto.SignInResponseDTO,
|
||||
) *Backup {
|
||||
) *backups_core.Backup {
|
||||
userService := users_services.GetUserService()
|
||||
user, err := userService.GetUserFromToken(owner.Token)
|
||||
if err != nil {
|
||||
@@ -1076,17 +1260,17 @@ func createTestBackup(
|
||||
panic("No storage found for workspace")
|
||||
}
|
||||
|
||||
backup := &Backup{
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storages[0].ID,
|
||||
Status: BackupStatusCompleted,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10.5,
|
||||
BackupDurationMs: 1000,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
repo := &BackupRepository{}
|
||||
repo := &backups_core.BackupRepository{}
|
||||
if err := repo.Save(backup); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@@ -1116,7 +1300,7 @@ func createExpiredDownloadToken(backupID, userID uuid.UUID) string {
|
||||
}
|
||||
|
||||
// Manually update the token to be expired
|
||||
repo := &download_token.DownloadTokenRepository{}
|
||||
repo := &backups_download.DownloadTokenRepository{}
|
||||
downloadToken, err := repo.FindByToken(token)
|
||||
if err != nil || downloadToken == nil {
|
||||
panic(fmt.Sprintf("Failed to find generated token: %v", err))
|
||||
@@ -1130,3 +1314,267 @@ func createExpiredDownloadToken(backupID, userID uuid.UUID) string {
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
func Test_BandwidthThrottling_SingleDownload_Uses75Percent(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
bandwidthManager := backups_download.GetBandwidthManager()
|
||||
initialCount := bandwidthManager.GetActiveDownloadCount()
|
||||
|
||||
var tokenResponse backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&tokenResponse,
|
||||
)
|
||||
|
||||
downloadStarted := make(chan bool, 1)
|
||||
downloadComplete := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf(
|
||||
"/api/v1/backups/%s/file?token=%s",
|
||||
backup.ID.String(),
|
||||
tokenResponse.Token,
|
||||
),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
downloadComplete <- true
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
activeCount := bandwidthManager.GetActiveDownloadCount()
|
||||
if activeCount > initialCount {
|
||||
downloadStarted <- true
|
||||
assert.Equal(t, initialCount+1, activeCount, "Should have one active download")
|
||||
}
|
||||
|
||||
<-downloadComplete
|
||||
if len(downloadStarted) > 0 {
|
||||
<-downloadStarted
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
finalCount := bandwidthManager.GetActiveDownloadCount()
|
||||
assert.Equal(t, initialCount, finalCount, "Download should be unregistered after completion")
|
||||
}
|
||||
|
||||
func Test_BandwidthThrottling_MultipleDownloads_ShareBandwidth(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
owner3 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner1, router)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
owner2,
|
||||
users_enums.WorkspaceRoleMember,
|
||||
owner1.Token,
|
||||
router,
|
||||
)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
owner3,
|
||||
users_enums.WorkspaceRoleMember,
|
||||
owner1.Token,
|
||||
router,
|
||||
)
|
||||
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner1.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
configService := backups_config.GetBackupConfigService()
|
||||
config, err := configService.GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
config.IsBackupsEnabled = true
|
||||
config.StorageID = &storage.ID
|
||||
config.Storage = storage
|
||||
_, err = configService.SaveBackupConfig(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backup1 := createTestBackup(database, owner1)
|
||||
backup2 := createTestBackup(database, owner2)
|
||||
backup3 := createTestBackup(database, owner3)
|
||||
|
||||
var token1, token2, token3 backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup1.ID.String()),
|
||||
"Bearer "+owner1.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token1,
|
||||
)
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup2.ID.String()),
|
||||
"Bearer "+owner2.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token2,
|
||||
)
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup3.ID.String()),
|
||||
"Bearer "+owner3.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token3,
|
||||
)
|
||||
|
||||
bandwidthManager := backups_download.GetBandwidthManager()
|
||||
initialCount := bandwidthManager.GetActiveDownloadCount()
|
||||
|
||||
complete1 := make(chan bool, 1)
|
||||
complete2 := make(chan bool, 1)
|
||||
complete3 := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup1.ID.String(), token1.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
complete1 <- true
|
||||
}()
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup2.ID.String(), token2.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
complete2 <- true
|
||||
}()
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup3.ID.String(), token3.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
complete3 <- true
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
<-complete1
|
||||
<-complete2
|
||||
<-complete3
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
finalCount := bandwidthManager.GetActiveDownloadCount()
|
||||
assert.Equal(t, initialCount, finalCount, "All downloads should be unregistered")
|
||||
}
|
||||
|
||||
func Test_BandwidthThrottling_DynamicAdjustment(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner1, router)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
owner2,
|
||||
users_enums.WorkspaceRoleMember,
|
||||
owner1.Token,
|
||||
router,
|
||||
)
|
||||
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner1.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
configService := backups_config.GetBackupConfigService()
|
||||
config, err := configService.GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
config.IsBackupsEnabled = true
|
||||
config.StorageID = &storage.ID
|
||||
config.Storage = storage
|
||||
_, err = configService.SaveBackupConfig(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backup1 := createTestBackup(database, owner1)
|
||||
backup2 := createTestBackup(database, owner2)
|
||||
|
||||
var token1, token2 backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup1.ID.String()),
|
||||
"Bearer "+owner1.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token1,
|
||||
)
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup2.ID.String()),
|
||||
"Bearer "+owner2.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token2,
|
||||
)
|
||||
|
||||
bandwidthManager := backups_download.GetBandwidthManager()
|
||||
initialCount := bandwidthManager.GetActiveDownloadCount()
|
||||
|
||||
complete1 := make(chan bool, 1)
|
||||
complete2 := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup1.ID.String(), token1.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
complete1 <- true
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup2.ID.String(), token2.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
complete2 <- true
|
||||
}()
|
||||
|
||||
<-complete1
|
||||
<-complete2
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
finalCount := bandwidthManager.GetActiveDownloadCount()
|
||||
assert.Equal(t, initialCount, finalCount, "All downloads completed and unregistered")
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package backups
|
||||
package backups_core
|
||||
|
||||
type BackupStatus string
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package backups
|
||||
package backups_core
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -1,4 +1,4 @@
|
||||
package backups
|
||||
package backups_core
|
||||
|
||||
import (
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
@@ -1,4 +1,4 @@
|
||||
package backups
|
||||
package backups_core
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/storage"
|
||||
@@ -1,10 +1,11 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups/download_token"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
"databasus-backend/internal/features/backups/backups/usecases"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -16,48 +17,31 @@ import (
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var backupRepository = &BackupRepository{}
|
||||
var backupRepository = &backups_core.BackupRepository{}
|
||||
|
||||
var backupContextManager = NewBackupContextManager()
|
||||
var backupCancelManager = backups_cancellation.GetBackupCancelManager()
|
||||
|
||||
var backupService = &BackupService{
|
||||
databases.GetDatabaseService(),
|
||||
storages.GetStorageService(),
|
||||
backupRepository,
|
||||
notifiers.GetNotifierService(),
|
||||
notifiers.GetNotifierService(),
|
||||
backups_config.GetBackupConfigService(),
|
||||
encryption_secrets.GetSecretKeyService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
audit_logs.GetAuditLogService(),
|
||||
backupContextManager,
|
||||
download_token.GetDownloadTokenService(),
|
||||
}
|
||||
|
||||
var backupBackgroundService = &BackupBackgroundService{
|
||||
backupService,
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
time.Now().UTC(),
|
||||
logger.GetLogger(),
|
||||
databaseService: databases.GetDatabaseService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
backupRepository: backupRepository,
|
||||
notifierService: notifiers.GetNotifierService(),
|
||||
notificationSender: notifiers.GetNotifierService(),
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
secretKeyService: encryption_secrets.GetSecretKeyService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
createBackupUseCase: usecases.GetCreateBackupUsecase(),
|
||||
logger: logger.GetLogger(),
|
||||
backupRemoveListeners: []backups_core.BackupRemoveListener{},
|
||||
workspaceService: workspaces_services.GetWorkspaceService(),
|
||||
auditLogService: audit_logs.GetAuditLogService(),
|
||||
backupCancelManager: backupCancelManager,
|
||||
downloadTokenService: backups_download.GetDownloadTokenService(),
|
||||
backupSchedulerService: backuping.GetBackupsScheduler(),
|
||||
}
|
||||
|
||||
var backupController = &BackupController{
|
||||
backupService,
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
backups_config.
|
||||
GetBackupConfigService().
|
||||
SetDatabaseStorageChangeListener(backupService)
|
||||
|
||||
databases.GetDatabaseService().AddDbRemoveListener(backupService)
|
||||
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
|
||||
backupService: backupService,
|
||||
}
|
||||
|
||||
func GetBackupService() *BackupService {
|
||||
@@ -68,10 +52,11 @@ func GetBackupController() *BackupController {
|
||||
return backupController
|
||||
}
|
||||
|
||||
func GetBackupBackgroundService() *BackupBackgroundService {
|
||||
return backupBackgroundService
|
||||
}
|
||||
func SetupDependencies() {
|
||||
backups_config.
|
||||
GetBackupConfigService().
|
||||
SetDatabaseStorageChangeListener(backupService)
|
||||
|
||||
func GetDownloadTokenBackgroundService() *download_token.DownloadTokenBackgroundService {
|
||||
return download_token.GetDownloadTokenBackgroundService()
|
||||
databases.GetDatabaseService().AddDbRemoveListener(backupService)
|
||||
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
|
||||
}
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
)
|
||||
|
||||
type DownloadTokenBackgroundService struct {
|
||||
downloadTokenService *DownloadTokenService
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *DownloadTokenBackgroundService) Run(ctx context.Context) {
|
||||
s.logger.Info("Starting download token cleanup background service")
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
|
||||
s.logger.Error("Failed to clean expired download tokens", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type BandwidthManager struct {
|
||||
mu sync.RWMutex
|
||||
activeDownloads map[uuid.UUID]*activeDownload
|
||||
maxTotalBytesPerSecond int64
|
||||
bytesPerSecondPerDownload int64
|
||||
}
|
||||
|
||||
type activeDownload struct {
|
||||
userID uuid.UUID
|
||||
rateLimiter *RateLimiter
|
||||
}
|
||||
|
||||
func NewBandwidthManager(throughputMBs int) *BandwidthManager {
|
||||
// Use 75% of total throughput
|
||||
maxBytes := int64(throughputMBs) * 1024 * 1024 * 75 / 100
|
||||
|
||||
return &BandwidthManager{
|
||||
activeDownloads: make(map[uuid.UUID]*activeDownload),
|
||||
maxTotalBytesPerSecond: maxBytes,
|
||||
bytesPerSecondPerDownload: maxBytes,
|
||||
}
|
||||
}
|
||||
|
||||
func (bm *BandwidthManager) RegisterDownload(userID uuid.UUID) (*RateLimiter, error) {
|
||||
bm.mu.Lock()
|
||||
defer bm.mu.Unlock()
|
||||
|
||||
if _, exists := bm.activeDownloads[userID]; exists {
|
||||
return nil, fmt.Errorf("download already registered for user %s", userID)
|
||||
}
|
||||
|
||||
rateLimiter := NewRateLimiter(bm.bytesPerSecondPerDownload)
|
||||
|
||||
bm.activeDownloads[userID] = &activeDownload{
|
||||
userID: userID,
|
||||
rateLimiter: rateLimiter,
|
||||
}
|
||||
|
||||
bm.recalculateRates()
|
||||
|
||||
return rateLimiter, nil
|
||||
}
|
||||
|
||||
func (bm *BandwidthManager) UnregisterDownload(userID uuid.UUID) {
|
||||
bm.mu.Lock()
|
||||
defer bm.mu.Unlock()
|
||||
|
||||
delete(bm.activeDownloads, userID)
|
||||
bm.recalculateRates()
|
||||
}
|
||||
|
||||
func (bm *BandwidthManager) GetActiveDownloadCount() int {
|
||||
bm.mu.RLock()
|
||||
defer bm.mu.RUnlock()
|
||||
return len(bm.activeDownloads)
|
||||
}
|
||||
|
||||
func (bm *BandwidthManager) recalculateRates() {
|
||||
activeCount := len(bm.activeDownloads)
|
||||
|
||||
if activeCount == 0 {
|
||||
bm.bytesPerSecondPerDownload = bm.maxTotalBytesPerSecond
|
||||
return
|
||||
}
|
||||
|
||||
newRate := bm.maxTotalBytesPerSecond / int64(activeCount)
|
||||
bm.bytesPerSecondPerDownload = newRate
|
||||
|
||||
for _, download := range bm.activeDownloads {
|
||||
download.rateLimiter.UpdateRate(newRate)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,150 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_BandwidthManager_RegisterSingleDownload(t *testing.T) {
|
||||
throughputMBs := 100
|
||||
manager := NewBandwidthManager(throughputMBs)
|
||||
|
||||
expectedBytesPerSec := int64(100 * 1024 * 1024 * 75 / 100)
|
||||
assert.Equal(t, expectedBytesPerSec, manager.maxTotalBytesPerSecond)
|
||||
assert.Equal(t, expectedBytesPerSec, manager.bytesPerSecondPerDownload)
|
||||
|
||||
userID := uuid.New()
|
||||
rateLimiter, err := manager.RegisterDownload(userID)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, rateLimiter)
|
||||
|
||||
assert.Equal(t, 1, manager.GetActiveDownloadCount())
|
||||
assert.Equal(t, expectedBytesPerSec, manager.bytesPerSecondPerDownload)
|
||||
assert.Equal(t, expectedBytesPerSec, rateLimiter.bytesPerSecond)
|
||||
}
|
||||
|
||||
func Test_BandwidthManager_RegisterMultipleDownloads_BandwidthShared(t *testing.T) {
|
||||
throughputMBs := 100
|
||||
manager := NewBandwidthManager(throughputMBs)
|
||||
|
||||
maxBytes := int64(100 * 1024 * 1024 * 75 / 100)
|
||||
|
||||
user1 := uuid.New()
|
||||
rateLimiter1, err := manager.RegisterDownload(user1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, maxBytes, rateLimiter1.bytesPerSecond)
|
||||
|
||||
user2 := uuid.New()
|
||||
rateLimiter2, err := manager.RegisterDownload(user2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
expectedPerDownload := maxBytes / 2
|
||||
assert.Equal(t, expectedPerDownload, manager.bytesPerSecondPerDownload)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter2.bytesPerSecond)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter2.bytesPerSecond)
|
||||
|
||||
user3 := uuid.New()
|
||||
rateLimiter3, err := manager.RegisterDownload(user3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
expectedPerDownload = maxBytes / 3
|
||||
assert.Equal(t, expectedPerDownload, manager.bytesPerSecondPerDownload)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter2.bytesPerSecond)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter3.bytesPerSecond)
|
||||
assert.Equal(t, 3, manager.GetActiveDownloadCount())
|
||||
}
|
||||
|
||||
func Test_BandwidthManager_UnregisterDownload_BandwidthRebalanced(t *testing.T) {
|
||||
throughputMBs := 100
|
||||
manager := NewBandwidthManager(throughputMBs)
|
||||
|
||||
maxBytes := int64(100 * 1024 * 1024 * 75 / 100)
|
||||
|
||||
user1 := uuid.New()
|
||||
rateLimiter1, _ := manager.RegisterDownload(user1)
|
||||
|
||||
user2 := uuid.New()
|
||||
_, _ = manager.RegisterDownload(user2)
|
||||
|
||||
user3 := uuid.New()
|
||||
rateLimiter3, _ := manager.RegisterDownload(user3)
|
||||
|
||||
assert.Equal(t, 3, manager.GetActiveDownloadCount())
|
||||
expectedPerDownload := maxBytes / 3
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
|
||||
|
||||
manager.UnregisterDownload(user2)
|
||||
|
||||
assert.Equal(t, 2, manager.GetActiveDownloadCount())
|
||||
expectedPerDownload = maxBytes / 2
|
||||
assert.Equal(t, expectedPerDownload, manager.bytesPerSecondPerDownload)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter3.bytesPerSecond)
|
||||
|
||||
manager.UnregisterDownload(user1)
|
||||
|
||||
assert.Equal(t, 1, manager.GetActiveDownloadCount())
|
||||
assert.Equal(t, maxBytes, manager.bytesPerSecondPerDownload)
|
||||
assert.Equal(t, maxBytes, rateLimiter3.bytesPerSecond)
|
||||
|
||||
manager.UnregisterDownload(user3)
|
||||
assert.Equal(t, 0, manager.GetActiveDownloadCount())
|
||||
assert.Equal(t, maxBytes, manager.bytesPerSecondPerDownload)
|
||||
}
|
||||
|
||||
func Test_BandwidthManager_RegisterDuplicateUser_ReturnsError(t *testing.T) {
|
||||
manager := NewBandwidthManager(100)
|
||||
|
||||
userID := uuid.New()
|
||||
_, err := manager.RegisterDownload(userID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = manager.RegisterDownload(userID)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "download already registered")
|
||||
}
|
||||
|
||||
func Test_RateLimiter_TokenBucketBasic(t *testing.T) {
|
||||
bytesPerSec := int64(1024 * 1024)
|
||||
limiter := NewRateLimiter(bytesPerSec)
|
||||
|
||||
assert.Equal(t, bytesPerSec, limiter.bytesPerSecond)
|
||||
assert.Equal(t, bytesPerSec*2, limiter.bucketSize)
|
||||
|
||||
start := time.Now()
|
||||
limiter.Wait(512 * 1024)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.Less(t, elapsed, 100*time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_RateLimiter_UpdateRate(t *testing.T) {
|
||||
limiter := NewRateLimiter(1024 * 1024)
|
||||
|
||||
assert.Equal(t, int64(1024*1024), limiter.bytesPerSecond)
|
||||
|
||||
newRate := int64(2 * 1024 * 1024)
|
||||
limiter.UpdateRate(newRate)
|
||||
|
||||
assert.Equal(t, newRate, limiter.bytesPerSecond)
|
||||
assert.Equal(t, newRate*2, limiter.bucketSize)
|
||||
}
|
||||
|
||||
func Test_RateLimiter_ThrottlesCorrectly(t *testing.T) {
|
||||
bytesPerSec := int64(1024 * 1024)
|
||||
limiter := NewRateLimiter(bytesPerSec)
|
||||
|
||||
limiter.availableTokens = 0
|
||||
|
||||
start := time.Now()
|
||||
limiter.Wait(bytesPerSec / 2)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.GreaterOrEqual(t, elapsed, 400*time.Millisecond)
|
||||
assert.LessOrEqual(t, elapsed, 700*time.Millisecond)
|
||||
}
|
||||
48
backend/internal/features/backups/backups/download/di.go
Normal file
48
backend/internal/features/backups/backups/download/di.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var downloadTokenRepository = &DownloadTokenRepository{}
|
||||
|
||||
var downloadTracker = NewDownloadTracker(cache_utils.GetValkeyClient())
|
||||
|
||||
var bandwidthManager *BandwidthManager
|
||||
var downloadTokenService *DownloadTokenService
|
||||
var downloadTokenBackgroundService *DownloadTokenBackgroundService
|
||||
|
||||
func init() {
|
||||
env := config.GetEnv()
|
||||
throughputMBs := env.NodeNetworkThroughputMBs
|
||||
if throughputMBs == 0 {
|
||||
throughputMBs = 125
|
||||
}
|
||||
bandwidthManager = NewBandwidthManager(throughputMBs)
|
||||
|
||||
downloadTokenService = &DownloadTokenService{
|
||||
downloadTokenRepository,
|
||||
logger.GetLogger(),
|
||||
downloadTracker,
|
||||
bandwidthManager,
|
||||
}
|
||||
|
||||
downloadTokenBackgroundService = &DownloadTokenBackgroundService{
|
||||
downloadTokenService,
|
||||
logger.GetLogger(),
|
||||
}
|
||||
}
|
||||
|
||||
func GetDownloadTokenService() *DownloadTokenService {
|
||||
return downloadTokenService
|
||||
}
|
||||
|
||||
func GetDownloadTokenBackgroundService() *DownloadTokenBackgroundService {
|
||||
return downloadTokenBackgroundService
|
||||
}
|
||||
|
||||
func GetBandwidthManager() *BandwidthManager {
|
||||
return bandwidthManager
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
package backups_download
|
||||
|
||||
import "github.com/google/uuid"
|
||||
|
||||
type GenerateDownloadTokenResponse struct {
|
||||
Token string `json:"token"`
|
||||
Filename string `json:"filename"`
|
||||
BackupID uuid.UUID `json:"backupId"`
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package download_token
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"time"
|
||||
@@ -0,0 +1,101 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RateLimiter struct {
|
||||
mu sync.Mutex
|
||||
bytesPerSecond int64
|
||||
bucketSize int64
|
||||
availableTokens float64
|
||||
lastRefill time.Time
|
||||
}
|
||||
|
||||
func NewRateLimiter(bytesPerSecond int64) *RateLimiter {
|
||||
if bytesPerSecond <= 0 {
|
||||
bytesPerSecond = 1024 * 1024 * 100
|
||||
}
|
||||
|
||||
return &RateLimiter{
|
||||
bytesPerSecond: bytesPerSecond,
|
||||
bucketSize: bytesPerSecond * 2,
|
||||
availableTokens: float64(bytesPerSecond * 2),
|
||||
lastRefill: time.Now().UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) UpdateRate(bytesPerSecond int64) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
if bytesPerSecond <= 0 {
|
||||
bytesPerSecond = 1024 * 1024 * 100
|
||||
}
|
||||
|
||||
rl.bytesPerSecond = bytesPerSecond
|
||||
rl.bucketSize = bytesPerSecond * 2
|
||||
|
||||
if rl.availableTokens > float64(rl.bucketSize) {
|
||||
rl.availableTokens = float64(rl.bucketSize)
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) Wait(bytes int64) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
for {
|
||||
now := time.Now().UTC()
|
||||
elapsed := now.Sub(rl.lastRefill).Seconds()
|
||||
|
||||
tokensToAdd := elapsed * float64(rl.bytesPerSecond)
|
||||
rl.availableTokens += tokensToAdd
|
||||
if rl.availableTokens > float64(rl.bucketSize) {
|
||||
rl.availableTokens = float64(rl.bucketSize)
|
||||
}
|
||||
rl.lastRefill = now
|
||||
|
||||
if rl.availableTokens >= float64(bytes) {
|
||||
rl.availableTokens -= float64(bytes)
|
||||
return
|
||||
}
|
||||
|
||||
tokensNeeded := float64(bytes) - rl.availableTokens
|
||||
waitTime := time.Duration(tokensNeeded/float64(rl.bytesPerSecond)*1000) * time.Millisecond
|
||||
|
||||
if waitTime < time.Millisecond {
|
||||
waitTime = time.Millisecond
|
||||
}
|
||||
|
||||
rl.mu.Unlock()
|
||||
time.Sleep(waitTime)
|
||||
rl.mu.Lock()
|
||||
}
|
||||
}
|
||||
|
||||
type RateLimitedReader struct {
|
||||
reader io.ReadCloser
|
||||
rateLimiter *RateLimiter
|
||||
}
|
||||
|
||||
func NewRateLimitedReader(reader io.ReadCloser, limiter *RateLimiter) *RateLimitedReader {
|
||||
return &RateLimitedReader{
|
||||
reader: reader,
|
||||
rateLimiter: limiter,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RateLimitedReader) Read(p []byte) (n int, err error) {
|
||||
n, err = r.reader.Read(p)
|
||||
if n > 0 {
|
||||
r.rateLimiter.Wait(int64(n))
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *RateLimitedReader) Close() error {
|
||||
return r.reader.Close()
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package download_token
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
105
backend/internal/features/backups/backups/download/service.go
Normal file
105
backend/internal/features/backups/backups/download/service.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type DownloadTokenService struct {
|
||||
repository *DownloadTokenRepository
|
||||
logger *slog.Logger
|
||||
downloadTracker *DownloadTracker
|
||||
bandwidthManager *BandwidthManager
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) Generate(backupID, userID uuid.UUID) (string, error) {
|
||||
if s.downloadTracker.IsDownloadInProgress(userID) {
|
||||
return "", ErrDownloadAlreadyInProgress
|
||||
}
|
||||
|
||||
token := GenerateSecureToken()
|
||||
|
||||
downloadToken := &DownloadToken{
|
||||
Token: token,
|
||||
BackupID: backupID,
|
||||
UserID: userID,
|
||||
ExpiresAt: time.Now().UTC().Add(5 * time.Minute),
|
||||
Used: false,
|
||||
}
|
||||
|
||||
if err := s.repository.Create(downloadToken); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
s.logger.Info("Generated download token", "backupId", backupID, "userId", userID)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) ValidateAndConsume(
|
||||
token string,
|
||||
) (*DownloadToken, *RateLimiter, error) {
|
||||
dt, err := s.repository.FindByToken(token)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if dt == nil {
|
||||
return nil, nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
if dt.Used {
|
||||
return nil, nil, errors.New("token already used")
|
||||
}
|
||||
|
||||
if time.Now().UTC().After(dt.ExpiresAt) {
|
||||
return nil, nil, errors.New("token expired")
|
||||
}
|
||||
|
||||
if err := s.downloadTracker.AcquireDownloadLock(dt.UserID); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
rateLimiter, err := s.bandwidthManager.RegisterDownload(dt.UserID)
|
||||
if err != nil {
|
||||
s.downloadTracker.ReleaseDownloadLock(dt.UserID)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
dt.Used = true
|
||||
if err := s.repository.Update(dt); err != nil {
|
||||
s.logger.Error("Failed to mark token as used", "error", err)
|
||||
}
|
||||
|
||||
s.logger.Info("Token validated and consumed", "backupId", dt.BackupID, "userId", dt.UserID)
|
||||
return dt, rateLimiter, nil
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) RefreshDownloadLock(userID uuid.UUID) {
|
||||
s.downloadTracker.RefreshDownloadLock(userID)
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) ReleaseDownloadLock(userID uuid.UUID) {
|
||||
s.downloadTracker.ReleaseDownloadLock(userID)
|
||||
s.logger.Info("Released download lock", "userId", userID)
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) IsDownloadInProgress(userID uuid.UUID) bool {
|
||||
return s.downloadTracker.IsDownloadInProgress(userID)
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) UnregisterDownload(userID uuid.UUID) {
|
||||
s.bandwidthManager.UnregisterDownload(userID)
|
||||
s.logger.Info("Unregistered from bandwidth manager", "userId", userID)
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) CleanExpiredTokens() error {
|
||||
now := time.Now().UTC()
|
||||
if err := s.repository.DeleteExpired(now); err != nil {
|
||||
return err
|
||||
}
|
||||
s.logger.Debug("Cleaned expired download tokens")
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/valkey-io/valkey-go"
|
||||
)
|
||||
|
||||
const (
|
||||
downloadLockPrefix = "backup_download_lock:"
|
||||
downloadLockTTL = 5 * time.Second
|
||||
downloadLockValue = "1"
|
||||
downloadHeartbeatDelay = 3 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDownloadAlreadyInProgress = errors.New("download already in progress for this user")
|
||||
)
|
||||
|
||||
type DownloadTracker struct {
|
||||
cache *cache_utils.CacheUtil[string]
|
||||
}
|
||||
|
||||
func NewDownloadTracker(client valkey.Client) *DownloadTracker {
|
||||
return &DownloadTracker{
|
||||
cache: cache_utils.NewCacheUtil[string](client, downloadLockPrefix),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *DownloadTracker) AcquireDownloadLock(userID uuid.UUID) error {
|
||||
key := userID.String()
|
||||
|
||||
existingLock := t.cache.Get(key)
|
||||
if existingLock != nil {
|
||||
return ErrDownloadAlreadyInProgress
|
||||
}
|
||||
|
||||
value := downloadLockValue
|
||||
t.cache.Set(key, &value)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *DownloadTracker) RefreshDownloadLock(userID uuid.UUID) {
|
||||
key := userID.String()
|
||||
value := downloadLockValue
|
||||
t.cache.Set(key, &value)
|
||||
}
|
||||
|
||||
func (t *DownloadTracker) ReleaseDownloadLock(userID uuid.UUID) {
|
||||
key := userID.String()
|
||||
t.cache.Invalidate(key)
|
||||
}
|
||||
|
||||
func (t *DownloadTracker) IsDownloadInProgress(userID uuid.UUID) bool {
|
||||
key := userID.String()
|
||||
existingLock := t.cache.Get(key)
|
||||
return existingLock != nil
|
||||
}
|
||||
|
||||
func GetDownloadHeartbeatInterval() time.Duration {
|
||||
return downloadHeartbeatDelay
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
package download_token
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
"log/slog"
|
||||
"time"
|
||||
)
|
||||
|
||||
type DownloadTokenBackgroundService struct {
|
||||
downloadTokenService *DownloadTokenService
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *DownloadTokenBackgroundService) Run() {
|
||||
s.logger.Info("Starting download token cleanup background service")
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
if config.IsShouldShutdown() {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
|
||||
s.logger.Error("Failed to clean expired download tokens", "error", err)
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Minute)
|
||||
}
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package download_token
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var downloadTokenRepository = &DownloadTokenRepository{}
|
||||
|
||||
var downloadTokenService = &DownloadTokenService{
|
||||
downloadTokenRepository,
|
||||
logger.GetLogger(),
|
||||
}
|
||||
|
||||
var downloadTokenBackgroundService = &DownloadTokenBackgroundService{
|
||||
downloadTokenService,
|
||||
logger.GetLogger(),
|
||||
}
|
||||
|
||||
func GetDownloadTokenService() *DownloadTokenService {
|
||||
return downloadTokenService
|
||||
}
|
||||
|
||||
func GetDownloadTokenBackgroundService() *DownloadTokenBackgroundService {
|
||||
return downloadTokenBackgroundService
|
||||
}
|
||||
@@ -1,69 +0,0 @@
|
||||
package download_token
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type DownloadTokenService struct {
|
||||
repository *DownloadTokenRepository
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) Generate(backupID, userID uuid.UUID) (string, error) {
|
||||
token := GenerateSecureToken()
|
||||
|
||||
downloadToken := &DownloadToken{
|
||||
Token: token,
|
||||
BackupID: backupID,
|
||||
UserID: userID,
|
||||
ExpiresAt: time.Now().UTC().Add(5 * time.Minute),
|
||||
Used: false,
|
||||
}
|
||||
|
||||
if err := s.repository.Create(downloadToken); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
s.logger.Info("Generated download token", "backupId", backupID, "userId", userID)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) ValidateAndConsume(token string) (*DownloadToken, error) {
|
||||
dt, err := s.repository.FindByToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if dt == nil {
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
if dt.Used {
|
||||
return nil, errors.New("token already used")
|
||||
}
|
||||
|
||||
if time.Now().UTC().After(dt.ExpiresAt) {
|
||||
return nil, errors.New("token expired")
|
||||
}
|
||||
|
||||
dt.Used = true
|
||||
if err := s.repository.Update(dt); err != nil {
|
||||
s.logger.Error("Failed to mark token as used", "error", err)
|
||||
}
|
||||
|
||||
s.logger.Info("Token validated and consumed", "backupId", dt.BackupID)
|
||||
return dt, nil
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) CleanExpiredTokens() error {
|
||||
now := time.Now().UTC()
|
||||
if err := s.repository.DeleteExpired(now); err != nil {
|
||||
return err
|
||||
}
|
||||
s.logger.Debug("Cleaned expired download tokens")
|
||||
return nil
|
||||
}
|
||||
@@ -1,10 +1,9 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/encryption"
|
||||
"io"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type GetBackupsRequest struct {
|
||||
@@ -14,23 +13,17 @@ type GetBackupsRequest struct {
|
||||
}
|
||||
|
||||
type GetBackupsResponse struct {
|
||||
Backups []*Backup `json:"backups"`
|
||||
Total int64 `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
Backups []*backups_core.Backup `json:"backups"`
|
||||
Total int64 `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
}
|
||||
|
||||
type GenerateDownloadTokenResponse struct {
|
||||
Token string `json:"token"`
|
||||
Filename string `json:"filename"`
|
||||
BackupID uuid.UUID `json:"backupId"`
|
||||
}
|
||||
|
||||
type decryptionReaderCloser struct {
|
||||
type DecryptionReaderCloser struct {
|
||||
*encryption.DecryptionReader
|
||||
baseReader io.ReadCloser
|
||||
BaseReader io.ReadCloser
|
||||
}
|
||||
|
||||
func (r *decryptionReaderCloser) Close() error {
|
||||
return r.baseReader.Close()
|
||||
func (r *DecryptionReaderCloser) Close() error {
|
||||
return r.BaseReader.Close()
|
||||
}
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups/download_token"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
"databasus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -29,26 +28,27 @@ import (
|
||||
type BackupService struct {
|
||||
databaseService *databases.DatabaseService
|
||||
storageService *storages.StorageService
|
||||
backupRepository *BackupRepository
|
||||
backupRepository *backups_core.BackupRepository
|
||||
notifierService *notifiers.NotifierService
|
||||
notificationSender NotificationSender
|
||||
notificationSender backups_core.NotificationSender
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
secretKeyService *encryption_secrets.SecretKeyService
|
||||
fieldEncryptor util_encryption.FieldEncryptor
|
||||
|
||||
createBackupUseCase CreateBackupUsecase
|
||||
createBackupUseCase backups_core.CreateBackupUsecase
|
||||
|
||||
logger *slog.Logger
|
||||
|
||||
backupRemoveListeners []BackupRemoveListener
|
||||
backupRemoveListeners []backups_core.BackupRemoveListener
|
||||
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
auditLogService *audit_logs.AuditLogService
|
||||
backupContextManager *BackupContextManager
|
||||
downloadTokenService *download_token.DownloadTokenService
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
auditLogService *audit_logs.AuditLogService
|
||||
backupCancelManager *backups_cancellation.BackupCancelManager
|
||||
downloadTokenService *backups_download.DownloadTokenService
|
||||
backupSchedulerService *backuping.BackupsScheduler
|
||||
}
|
||||
|
||||
func (s *BackupService) AddBackupRemoveListener(listener BackupRemoveListener) {
|
||||
func (s *BackupService) AddBackupRemoveListener(listener backups_core.BackupRemoveListener) {
|
||||
s.backupRemoveListeners = append(s.backupRemoveListeners, listener)
|
||||
}
|
||||
|
||||
@@ -91,7 +91,7 @@ func (s *BackupService) MakeBackupWithAuth(
|
||||
return errors.New("insufficient permissions to create backup for this database")
|
||||
}
|
||||
|
||||
go s.MakeBackup(databaseID, true)
|
||||
s.backupSchedulerService.StartBackup(databaseID, true)
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Backup manually initiated for database: %s", database.Name),
|
||||
@@ -175,7 +175,7 @@ func (s *BackupService) DeleteBackup(
|
||||
return errors.New("insufficient permissions to delete backup for this database")
|
||||
}
|
||||
|
||||
if backup.Status == BackupStatusInProgress {
|
||||
if backup.Status == backups_core.BackupStatusInProgress {
|
||||
return errors.New("backup is in progress")
|
||||
}
|
||||
|
||||
@@ -192,260 +192,7 @@ func (s *BackupService) DeleteBackup(
|
||||
return s.deleteBackup(backup)
|
||||
}
|
||||
|
||||
func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
|
||||
database, err := s.databaseService.GetDatabaseByID(databaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get database by ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
lastBackup, err := s.backupRepository.FindLastByDatabaseID(databaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to find last backup by database ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if lastBackup != nil && lastBackup.Status == BackupStatusInProgress {
|
||||
s.logger.Error("Backup is in progress")
|
||||
return
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(databaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if backupConfig.StorageID == nil {
|
||||
s.logger.Error("Backup config storage ID is not defined")
|
||||
return
|
||||
}
|
||||
|
||||
storage, err := s.storageService.GetStorageByID(*backupConfig.StorageID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get storage by ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
backup := &Backup{
|
||||
DatabaseID: databaseID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusInProgress,
|
||||
|
||||
BackupSizeMb: 0,
|
||||
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error("Failed to save backup", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now().UTC()
|
||||
|
||||
backupProgressListener := func(
|
||||
completedMBs float64,
|
||||
) {
|
||||
backup.BackupSizeMb = completedMBs
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error("Failed to update backup progress", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s.backupContextManager.RegisterBackup(backup.ID, cancel)
|
||||
defer s.backupContextManager.UnregisterBackup(backup.ID)
|
||||
|
||||
backupMetadata, err := s.createBackupUseCase.Execute(
|
||||
ctx,
|
||||
backup.ID,
|
||||
backupConfig,
|
||||
database,
|
||||
storage,
|
||||
backupProgressListener,
|
||||
)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
|
||||
// Check if backup was cancelled (not due to shutdown)
|
||||
isCancelled := strings.Contains(errMsg, "backup cancelled") ||
|
||||
strings.Contains(errMsg, "context canceled") ||
|
||||
errors.Is(err, context.Canceled)
|
||||
isShutdown := strings.Contains(errMsg, "shutdown")
|
||||
|
||||
if isCancelled && !isShutdown {
|
||||
backup.Status = BackupStatusCanceled
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error("Failed to save cancelled backup", "error", err)
|
||||
}
|
||||
|
||||
// Delete partial backup from storage
|
||||
storage, storageErr := s.storageService.GetStorageByID(backup.StorageID)
|
||||
if storageErr == nil {
|
||||
if deleteErr := storage.DeleteFile(s.fieldEncryptor, backup.ID); deleteErr != nil {
|
||||
s.logger.Error(
|
||||
"Failed to delete partial backup file",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
deleteErr,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
backup.FailMessage = &errMsg
|
||||
backup.Status = BackupStatusFailed
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
if updateErr := s.databaseService.SetBackupError(databaseID, errMsg); updateErr != nil {
|
||||
s.logger.Error(
|
||||
"Failed to update database last backup time",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"error",
|
||||
updateErr,
|
||||
)
|
||||
}
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error("Failed to save backup", "error", err)
|
||||
}
|
||||
|
||||
s.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupFailed,
|
||||
&errMsg,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
backup.Status = BackupStatusCompleted
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
// Update backup with encryption metadata if provided
|
||||
if backupMetadata != nil {
|
||||
backup.EncryptionSalt = backupMetadata.EncryptionSalt
|
||||
backup.EncryptionIV = backupMetadata.EncryptionIV
|
||||
backup.Encryption = backupMetadata.Encryption
|
||||
}
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error("Failed to save backup", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Update database last backup time
|
||||
now := time.Now().UTC()
|
||||
if updateErr := s.databaseService.SetLastBackupTime(databaseID, now); updateErr != nil {
|
||||
s.logger.Error(
|
||||
"Failed to update database last backup time",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"error",
|
||||
updateErr,
|
||||
)
|
||||
}
|
||||
|
||||
if backup.Status != BackupStatusCompleted && !isLastTry {
|
||||
return
|
||||
}
|
||||
|
||||
s.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupSuccess,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
func (s *BackupService) SendBackupNotification(
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
backup *Backup,
|
||||
notificationType backups_config.BackupNotificationType,
|
||||
errorMessage *string,
|
||||
) {
|
||||
database, err := s.databaseService.GetDatabaseByID(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
workspace, err := s.workspaceService.GetWorkspaceByID(*database.WorkspaceID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, notifier := range database.Notifiers {
|
||||
if !slices.Contains(
|
||||
backupConfig.SendNotificationsOn,
|
||||
notificationType,
|
||||
) {
|
||||
continue
|
||||
}
|
||||
|
||||
title := ""
|
||||
switch notificationType {
|
||||
case backups_config.NotificationBackupFailed:
|
||||
title = fmt.Sprintf(
|
||||
"❌ Backup failed for database \"%s\" (workspace \"%s\")",
|
||||
database.Name,
|
||||
workspace.Name,
|
||||
)
|
||||
case backups_config.NotificationBackupSuccess:
|
||||
title = fmt.Sprintf(
|
||||
"✅ Backup completed for database \"%s\" (workspace \"%s\")",
|
||||
database.Name,
|
||||
workspace.Name,
|
||||
)
|
||||
}
|
||||
|
||||
message := ""
|
||||
if errorMessage != nil {
|
||||
message = *errorMessage
|
||||
} else {
|
||||
// Format size conditionally
|
||||
var sizeStr string
|
||||
if backup.BackupSizeMb < 1024 {
|
||||
sizeStr = fmt.Sprintf("%.2f MB", backup.BackupSizeMb)
|
||||
} else {
|
||||
sizeGB := backup.BackupSizeMb / 1024
|
||||
sizeStr = fmt.Sprintf("%.2f GB", sizeGB)
|
||||
}
|
||||
|
||||
// Format duration as "0m 0s 0ms"
|
||||
totalMs := backup.BackupDurationMs
|
||||
minutes := totalMs / (1000 * 60)
|
||||
seconds := (totalMs % (1000 * 60)) / 1000
|
||||
durationStr := fmt.Sprintf("%dm %ds", minutes, seconds)
|
||||
|
||||
message = fmt.Sprintf(
|
||||
"Backup completed successfully in %s.\nCompressed backup size: %s",
|
||||
durationStr,
|
||||
sizeStr,
|
||||
)
|
||||
}
|
||||
|
||||
s.notificationSender.SendNotification(
|
||||
¬ifier,
|
||||
title,
|
||||
message,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupService) GetBackup(backupID uuid.UUID) (*Backup, error) {
|
||||
func (s *BackupService) GetBackup(backupID uuid.UUID) (*backups_core.Backup, error) {
|
||||
return s.backupRepository.FindByID(backupID)
|
||||
}
|
||||
|
||||
@@ -475,11 +222,11 @@ func (s *BackupService) CancelBackup(
|
||||
return errors.New("insufficient permissions to cancel backup for this database")
|
||||
}
|
||||
|
||||
if backup.Status != BackupStatusInProgress {
|
||||
if backup.Status != backups_core.BackupStatusInProgress {
|
||||
return errors.New("backup is not in progress")
|
||||
}
|
||||
|
||||
if err := s.backupContextManager.CancelBackup(backupID); err != nil {
|
||||
if err := s.backupCancelManager.CancelBackup(backupID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -499,7 +246,7 @@ func (s *BackupService) CancelBackup(
|
||||
func (s *BackupService) GetBackupFile(
|
||||
user *users_models.User,
|
||||
backupID uuid.UUID,
|
||||
) (io.ReadCloser, *Backup, *databases.Database, error) {
|
||||
) (io.ReadCloser, *backups_core.Backup, *databases.Database, error) {
|
||||
backup, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
@@ -545,7 +292,7 @@ func (s *BackupService) GetBackupFile(
|
||||
return reader, backup, database, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) deleteBackup(backup *Backup) error {
|
||||
func (s *BackupService) deleteBackup(backup *backups_core.Backup) error {
|
||||
for _, listener := range s.backupRemoveListeners {
|
||||
if err := listener.OnBeforeBackupRemove(backup); err != nil {
|
||||
return err
|
||||
@@ -571,7 +318,7 @@ func (s *BackupService) deleteBackup(backup *Backup) error {
|
||||
func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
|
||||
dbBackupsInProgress, err := s.backupRepository.FindByDatabaseIdAndStatus(
|
||||
databaseID,
|
||||
BackupStatusInProgress,
|
||||
backups_core.BackupStatusInProgress,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -680,16 +427,16 @@ func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, erro
|
||||
|
||||
s.logger.Info("Returning encrypted backup with decryption", "backupId", backupID)
|
||||
|
||||
return &decryptionReaderCloser{
|
||||
decryptionReader,
|
||||
fileReader,
|
||||
return &DecryptionReaderCloser{
|
||||
DecryptionReader: decryptionReader,
|
||||
BaseReader: fileReader,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) GenerateDownloadToken(
|
||||
user *users_models.User,
|
||||
backupID uuid.UUID,
|
||||
) (*GenerateDownloadTokenResponse, error) {
|
||||
) (*backups_download.GenerateDownloadTokenResponse, error) {
|
||||
backup, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -725,20 +472,22 @@ func (s *BackupService) GenerateDownloadToken(
|
||||
database.WorkspaceID,
|
||||
)
|
||||
|
||||
return &GenerateDownloadTokenResponse{
|
||||
return &backups_download.GenerateDownloadTokenResponse{
|
||||
Token: token,
|
||||
Filename: filename,
|
||||
BackupID: backupID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) ValidateDownloadToken(token string) (*download_token.DownloadToken, error) {
|
||||
func (s *BackupService) ValidateDownloadToken(
|
||||
token string,
|
||||
) (*backups_download.DownloadToken, *backups_download.RateLimiter, error) {
|
||||
return s.downloadTokenService.ValidateAndConsume(token)
|
||||
}
|
||||
|
||||
func (s *BackupService) GetBackupFileWithoutAuth(
|
||||
backupID uuid.UUID,
|
||||
) (io.ReadCloser, *Backup, *databases.Database, error) {
|
||||
) (io.ReadCloser, *backups_core.Backup, *databases.Database, error) {
|
||||
backup, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
@@ -759,7 +508,7 @@ func (s *BackupService) GetBackupFileWithoutAuth(
|
||||
|
||||
func (s *BackupService) WriteAuditLogForDownload(
|
||||
userID uuid.UUID,
|
||||
backup *Backup,
|
||||
backup *backups_core.Backup,
|
||||
database *databases.Database,
|
||||
) {
|
||||
s.auditLogService.WriteAuditLog(
|
||||
@@ -773,8 +522,24 @@ func (s *BackupService) WriteAuditLogForDownload(
|
||||
)
|
||||
}
|
||||
|
||||
func (s *BackupService) RefreshDownloadLock(userID uuid.UUID) {
|
||||
s.downloadTokenService.RefreshDownloadLock(userID)
|
||||
}
|
||||
|
||||
func (s *BackupService) ReleaseDownloadLock(userID uuid.UUID) {
|
||||
s.downloadTokenService.ReleaseDownloadLock(userID)
|
||||
}
|
||||
|
||||
func (s *BackupService) IsDownloadInProgress(userID uuid.UUID) bool {
|
||||
return s.downloadTokenService.IsDownloadInProgress(userID)
|
||||
}
|
||||
|
||||
func (s *BackupService) UnregisterDownload(userID uuid.UUID) {
|
||||
s.downloadTokenService.UnregisterDownload(userID)
|
||||
}
|
||||
|
||||
func (s *BackupService) generateBackupFilename(
|
||||
backup *Backup,
|
||||
backup *backups_core.Backup,
|
||||
database *databases.Database,
|
||||
) string {
|
||||
timestamp := backup.CreatedAt.Format("2006-01-02_15-04-05")
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
@@ -58,9 +59,9 @@ func WaitForBackupCompletion(
|
||||
newestBackup := backups[0]
|
||||
t.Logf("WaitForBackupCompletion: newest backup status: %s", newestBackup.Status)
|
||||
|
||||
if newestBackup.Status == BackupStatusCompleted ||
|
||||
newestBackup.Status == BackupStatusFailed ||
|
||||
newestBackup.Status == BackupStatusCanceled {
|
||||
if newestBackup.Status == backups_core.BackupStatusCompleted ||
|
||||
newestBackup.Status == backups_core.BackupStatusFailed ||
|
||||
newestBackup.Status == backups_core.BackupStatusCanceled {
|
||||
t.Logf(
|
||||
"WaitForBackupCompletion: backup finished with status %s",
|
||||
newestBackup.Status,
|
||||
|
||||
@@ -122,6 +122,7 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
|
||||
|
||||
if mdb.IsHttps {
|
||||
args = append(args, "--ssl")
|
||||
args = append(args, "--skip-ssl-verify-server-cert")
|
||||
}
|
||||
|
||||
if mdb.Database != nil && *mdb.Database != "" {
|
||||
@@ -265,11 +266,24 @@ func (uc *CreateMariadbBackupUsecase) createTempMyCnfFile(
|
||||
mdbConfig *mariadbtypes.MariadbDatabase,
|
||||
password string,
|
||||
) (string, error) {
|
||||
tempDir, err := os.MkdirTemp(config.GetEnv().TempFolder, "mycnf_"+uuid.New().String())
|
||||
tempFolder := config.GetEnv().TempFolder
|
||||
if err := os.MkdirAll(tempFolder, 0700); err != nil {
|
||||
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
|
||||
}
|
||||
if err := os.Chmod(tempFolder, 0700); err != nil {
|
||||
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
|
||||
}
|
||||
|
||||
tempDir, err := os.MkdirTemp(tempFolder, "mycnf_"+uuid.New().String())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temp directory: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(tempDir, 0700); err != nil {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
return "", fmt.Errorf("failed to set temp directory permissions: %w", err)
|
||||
}
|
||||
|
||||
myCnfFile := filepath.Join(tempDir, ".my.cnf")
|
||||
|
||||
content := fmt.Sprintf(`[client]
|
||||
@@ -287,6 +301,7 @@ port=%d
|
||||
|
||||
err = os.WriteFile(myCnfFile, []byte(content), 0600)
|
||||
if err != nil {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
return "", fmt.Errorf("failed to write .my.cnf: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -280,11 +280,24 @@ func (uc *CreateMysqlBackupUsecase) createTempMyCnfFile(
|
||||
myConfig *mysqltypes.MysqlDatabase,
|
||||
password string,
|
||||
) (string, error) {
|
||||
tempDir, err := os.MkdirTemp(config.GetEnv().TempFolder, "mycnf_"+uuid.New().String())
|
||||
tempFolder := config.GetEnv().TempFolder
|
||||
if err := os.MkdirAll(tempFolder, 0700); err != nil {
|
||||
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
|
||||
}
|
||||
if err := os.Chmod(tempFolder, 0700); err != nil {
|
||||
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
|
||||
}
|
||||
|
||||
tempDir, err := os.MkdirTemp(tempFolder, "mycnf_"+uuid.New().String())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temp directory: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(tempDir, 0700); err != nil {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
return "", fmt.Errorf("failed to set temp directory permissions: %w", err)
|
||||
}
|
||||
|
||||
myCnfFile := filepath.Join(tempDir, ".my.cnf")
|
||||
|
||||
content := fmt.Sprintf(`[client]
|
||||
@@ -300,6 +313,7 @@ port=%d
|
||||
|
||||
err = os.WriteFile(myCnfFile, []byte(content), 0600)
|
||||
if err != nil {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
return "", fmt.Errorf("failed to write .my.cnf: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -757,14 +757,28 @@ func (uc *CreatePostgresqlBackupUsecase) createTempPgpassFile(
|
||||
escapedPassword,
|
||||
)
|
||||
|
||||
tempDir, err := os.MkdirTemp("", "pgpass")
|
||||
tempFolder := config.GetEnv().TempFolder
|
||||
if err := os.MkdirAll(tempFolder, 0700); err != nil {
|
||||
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
|
||||
}
|
||||
if err := os.Chmod(tempFolder, 0700); err != nil {
|
||||
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
|
||||
}
|
||||
|
||||
tempDir, err := os.MkdirTemp(tempFolder, "pgpass_"+uuid.New().String())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temporary directory: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(tempDir, 0700); err != nil {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
return "", fmt.Errorf("failed to set temporary directory permissions: %w", err)
|
||||
}
|
||||
|
||||
pgpassFile := filepath.Join(tempDir, ".pgpass")
|
||||
err = os.WriteFile(pgpassFile, []byte(pgpassContent), 0600)
|
||||
if err != nil {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
return "", fmt.Errorf("failed to write temporary .pgpass file: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package mariadb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -14,7 +15,7 @@ import (
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/tools"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
@@ -398,8 +399,16 @@ func HasPrivilege(privileges, priv string) bool {
|
||||
|
||||
func (m *MariadbDatabase) buildDSN(password string, database string) string {
|
||||
tlsConfig := "false"
|
||||
|
||||
if m.IsHttps {
|
||||
tlsConfig = "skip-verify"
|
||||
err := mysql.RegisterTLSConfig("mariadb-skip-verify", &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
if err != nil {
|
||||
// Config might already be registered, which is fine
|
||||
_ = err
|
||||
}
|
||||
tlsConfig = "mariadb-skip-verify"
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
@@ -506,11 +515,13 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
|
||||
hasProcess := false
|
||||
hasAllPrivileges := false
|
||||
|
||||
escapedDB := strings.ReplaceAll(database, "_", "\\_")
|
||||
dbPattern := regexp.MustCompile(
|
||||
fmt.Sprintf("(?i)ON\\s+[`'\"]?(%s|\\*)[`'\"]?\\.\\*", regexp.QuoteMeta(escapedDB)),
|
||||
dbPatternStr := fmt.Sprintf(
|
||||
`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
|
||||
regexp.QuoteMeta(database),
|
||||
)
|
||||
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\.\*`)
|
||||
dbPattern := regexp.MustCompile(dbPatternStr)
|
||||
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)
|
||||
allPrivilegesPattern := regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`)
|
||||
|
||||
for rows.Next() {
|
||||
var grant string
|
||||
@@ -518,23 +529,26 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
|
||||
return "", fmt.Errorf("failed to scan grant: %w", err)
|
||||
}
|
||||
|
||||
if regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`).MatchString(grant) {
|
||||
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
|
||||
hasAllPrivileges = true
|
||||
}
|
||||
isRelevantGrant := globalPattern.MatchString(grant) || dbPattern.MatchString(grant)
|
||||
|
||||
if allPrivilegesPattern.MatchString(grant) && isRelevantGrant {
|
||||
hasAllPrivileges = true
|
||||
}
|
||||
|
||||
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
|
||||
if isRelevantGrant {
|
||||
for _, priv := range backupPrivileges {
|
||||
if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
|
||||
privPattern := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(priv) + `\b`)
|
||||
if privPattern.MatchString(grant) {
|
||||
detectedPrivileges[priv] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if globalPattern.MatchString(grant) &&
|
||||
regexp.MustCompile(`(?i)\bPROCESS\b`).MatchString(grant) {
|
||||
hasProcess = true
|
||||
if globalPattern.MatchString(grant) {
|
||||
processPattern := regexp.MustCompile(`(?i)\bPROCESS\b`)
|
||||
if processPattern.MatchString(grant) {
|
||||
hasProcess = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -562,9 +576,9 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
|
||||
}
|
||||
|
||||
// checkBackupPermissions verifies the user has sufficient privileges for mariadb-dump backup.
|
||||
// Required: SELECT, SHOW VIEW, PROCESS. Optional: LOCK TABLES, TRIGGER, EVENT.
|
||||
// Required: SELECT, SHOW VIEW
|
||||
func checkBackupPermissions(privileges string) error {
|
||||
requiredPrivileges := []string{"SELECT", "SHOW VIEW", "PROCESS"}
|
||||
requiredPrivileges := []string{"SELECT", "SHOW VIEW"}
|
||||
|
||||
var missingPrivileges []string
|
||||
for _, priv := range requiredPrivileges {
|
||||
@@ -575,7 +589,7 @@ func checkBackupPermissions(privileges string) error {
|
||||
|
||||
if len(missingPrivileges) > 0 {
|
||||
return fmt.Errorf(
|
||||
"insufficient permissions for backup. Missing: %s. Required: SELECT, SHOW VIEW, PROCESS",
|
||||
"insufficient permissions for backup. Missing: %s. Required: SELECT, SHOW VIEW",
|
||||
strings.Join(missingPrivileges, ", "),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -537,6 +537,163 @@ func Test_ReadOnlyUser_CannotDropOrAlterTables(t *testing.T) {
|
||||
dropUserSafe(container.DB, username)
|
||||
}
|
||||
|
||||
func Test_TestConnection_DatabaseSpecificPrivilegesWithGlobalProcess_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MariadbVersion
|
||||
port string
|
||||
}{
|
||||
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
|
||||
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
|
||||
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
|
||||
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
|
||||
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
|
||||
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
|
||||
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
|
||||
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
|
||||
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
|
||||
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
|
||||
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMariadbContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS privilege_test`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`CREATE TABLE privilege_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`INSERT INTO privilege_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
specificUsername := fmt.Sprintf("spec_%s", uuid.New().String()[:8])
|
||||
specificPassword := "specificpass123"
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
specificUsername,
|
||||
specificPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT SELECT, SHOW VIEW ON %s.* TO '%s'@'%%'",
|
||||
container.Database,
|
||||
specificUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT PROCESS ON *.* TO '%s'@'%%'",
|
||||
specificUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer dropUserSafe(container.DB, specificUsername)
|
||||
|
||||
mariadbModel := &MariadbDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: specificUsername,
|
||||
Password: specificPassword,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mariadbModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMariadbContainer(t, env.TestMariadb1011Port, tools.MariadbVersion1011)
|
||||
defer container.DB.Close()
|
||||
|
||||
underscoreDbName := "test_db_name"
|
||||
|
||||
_, err := container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
|
||||
}()
|
||||
|
||||
underscoreDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
container.Username, container.Password, container.Host, container.Port, underscoreDbName)
|
||||
underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
|
||||
assert.NoError(t, err)
|
||||
defer underscoreDB.Close()
|
||||
|
||||
_, err = underscoreDB.Exec(`
|
||||
CREATE TABLE underscore_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(`INSERT INTO underscore_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
underscoreUsername := fmt.Sprintf("under%s", uuid.New().String()[:8])
|
||||
underscorePassword := "underscorepass123"
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
underscoreUsername,
|
||||
underscorePassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"GRANT SELECT, SHOW VIEW ON `%s`.* TO '%s'@'%%'",
|
||||
underscoreDbName,
|
||||
underscoreUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer dropUserSafe(underscoreDB, underscoreUsername)
|
||||
|
||||
mariadbModel := &MariadbDatabase{
|
||||
Version: tools.MariadbVersion1011,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: underscoreUsername,
|
||||
Password: underscorePassword,
|
||||
Database: &underscoreDbName,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mariadbModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
type MariadbContainer struct {
|
||||
Host string
|
||||
Port int
|
||||
|
||||
@@ -2,6 +2,7 @@ package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -14,7 +15,7 @@ import (
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/tools"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
@@ -399,18 +400,30 @@ func HasPrivilege(privileges, priv string) bool {
|
||||
|
||||
func (m *MysqlDatabase) buildDSN(password string, database string) string {
|
||||
tlsConfig := "false"
|
||||
allowCleartext := ""
|
||||
|
||||
if m.IsHttps {
|
||||
tlsConfig = "skip-verify"
|
||||
err := mysql.RegisterTLSConfig("mysql-skip-verify", &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
if err != nil {
|
||||
// Config might already be registered, which is fine
|
||||
_ = err
|
||||
}
|
||||
|
||||
tlsConfig = "mysql-skip-verify"
|
||||
allowCleartext = "&allowCleartextPasswords=1"
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"%s:%s@tcp(%s:%d)/%s?parseTime=true&timeout=15s&tls=%s&charset=utf8mb4",
|
||||
"%s:%s@tcp(%s:%d)/%s?parseTime=true&timeout=15s&tls=%s&charset=utf8mb4%s",
|
||||
m.Username,
|
||||
password,
|
||||
m.Host,
|
||||
m.Port,
|
||||
database,
|
||||
tlsConfig,
|
||||
allowCleartext,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -476,11 +489,13 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
|
||||
hasProcess := false
|
||||
hasAllPrivileges := false
|
||||
|
||||
escapedDB := strings.ReplaceAll(database, "_", "\\_")
|
||||
dbPattern := regexp.MustCompile(
|
||||
fmt.Sprintf("(?i)ON\\s+[`'\"]?(%s|\\*)[`'\"]?\\.\\*", regexp.QuoteMeta(escapedDB)),
|
||||
dbPatternStr := fmt.Sprintf(
|
||||
`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
|
||||
regexp.QuoteMeta(database),
|
||||
)
|
||||
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\.\*`)
|
||||
dbPattern := regexp.MustCompile(dbPatternStr)
|
||||
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)
|
||||
allPrivilegesPattern := regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`)
|
||||
|
||||
for rows.Next() {
|
||||
var grant string
|
||||
@@ -488,23 +503,26 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
|
||||
return "", fmt.Errorf("failed to scan grant: %w", err)
|
||||
}
|
||||
|
||||
if regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`).MatchString(grant) {
|
||||
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
|
||||
hasAllPrivileges = true
|
||||
}
|
||||
isRelevantGrant := globalPattern.MatchString(grant) || dbPattern.MatchString(grant)
|
||||
|
||||
if allPrivilegesPattern.MatchString(grant) && isRelevantGrant {
|
||||
hasAllPrivileges = true
|
||||
}
|
||||
|
||||
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
|
||||
if isRelevantGrant {
|
||||
for _, priv := range backupPrivileges {
|
||||
if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
|
||||
privPattern := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(priv) + `\b`)
|
||||
if privPattern.MatchString(grant) {
|
||||
detectedPrivileges[priv] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if globalPattern.MatchString(grant) &&
|
||||
regexp.MustCompile(`(?i)\bPROCESS\b`).MatchString(grant) {
|
||||
hasProcess = true
|
||||
if globalPattern.MatchString(grant) {
|
||||
processPattern := regexp.MustCompile(`(?i)\bPROCESS\b`)
|
||||
if processPattern.MatchString(grant) {
|
||||
hasProcess = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -532,9 +550,9 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
|
||||
}
|
||||
|
||||
// checkBackupPermissions verifies the user has sufficient privileges for mysqldump backup.
|
||||
// Required: SELECT, SHOW VIEW, PROCESS. Optional: LOCK TABLES, TRIGGER, EVENT.
|
||||
// Required: SELECT, SHOW VIEW
|
||||
func checkBackupPermissions(privileges string) error {
|
||||
requiredPrivileges := []string{"SELECT", "SHOW VIEW", "PROCESS"}
|
||||
requiredPrivileges := []string{"SELECT", "SHOW VIEW"}
|
||||
|
||||
var missingPrivileges []string
|
||||
for _, priv := range requiredPrivileges {
|
||||
@@ -545,7 +563,7 @@ func checkBackupPermissions(privileges string) error {
|
||||
|
||||
if len(missingPrivileges) > 0 {
|
||||
return fmt.Errorf(
|
||||
"insufficient permissions for backup. Missing: %s. Required: SELECT, SHOW VIEW, PROCESS",
|
||||
"insufficient permissions for backup. Missing: %s. Required: SELECT, SHOW VIEW",
|
||||
strings.Join(missingPrivileges, ", "),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -518,6 +518,162 @@ func Test_ReadOnlyUser_CannotDropOrAlterTables(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_TestConnection_DatabaseSpecificPrivilegesWithGlobalProcess_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MysqlVersion
|
||||
port string
|
||||
}{
|
||||
{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
|
||||
{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
|
||||
{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
|
||||
{"MySQL 9", tools.MysqlVersion9, env.TestMysql90Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMysqlContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS privilege_test`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`CREATE TABLE privilege_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`INSERT INTO privilege_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
specificUsername := fmt.Sprintf("specific_%s", uuid.New().String()[:8])
|
||||
specificPassword := "specificpass123"
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
specificUsername,
|
||||
specificPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT SELECT, SHOW VIEW ON %s.* TO '%s'@'%%'",
|
||||
container.Database,
|
||||
specificUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT PROCESS ON *.* TO '%s'@'%%'",
|
||||
specificUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(
|
||||
fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", specificUsername),
|
||||
)
|
||||
}()
|
||||
|
||||
mysqlModel := &MysqlDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: specificUsername,
|
||||
Password: specificPassword,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mysqlModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMysqlContainer(t, env.TestMysql80Port, tools.MysqlVersion80)
|
||||
defer container.DB.Close()
|
||||
|
||||
underscoreDbName := "test_db_name"
|
||||
|
||||
_, err := container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
|
||||
}()
|
||||
|
||||
underscoreDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
container.Username, container.Password, container.Host, container.Port, underscoreDbName)
|
||||
underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
|
||||
assert.NoError(t, err)
|
||||
defer underscoreDB.Close()
|
||||
|
||||
_, err = underscoreDB.Exec(`
|
||||
CREATE TABLE underscore_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(`INSERT INTO underscore_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
underscoreUsername := fmt.Sprintf("under_%s", uuid.New().String()[:8])
|
||||
underscorePassword := "underscorepass123"
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
underscoreUsername,
|
||||
underscorePassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"GRANT SELECT, SHOW VIEW ON `%s`.* TO '%s'@'%%'",
|
||||
underscoreDbName,
|
||||
underscoreUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = underscoreDB.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", underscoreUsername))
|
||||
}()
|
||||
|
||||
mysqlModel := &MysqlDatabase{
|
||||
Version: tools.MysqlVersion80,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: underscoreUsername,
|
||||
Password: underscorePassword,
|
||||
Database: &underscoreDbName,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mysqlModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
type MysqlContainer struct {
|
||||
Host string
|
||||
Port int
|
||||
|
||||
@@ -671,7 +671,7 @@ func testSingleDatabaseConnection(
|
||||
postgresDb.Version = detectedVersion
|
||||
|
||||
// Verify user has sufficient permissions for backup operations
|
||||
if err := checkBackupPermissions(ctx, conn, *postgresDb.Database); err != nil {
|
||||
if err := checkBackupPermissions(ctx, conn, postgresDb.IncludeSchemas); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -707,7 +707,12 @@ func detectDatabaseVersion(ctx context.Context, conn *pgx.Conn) (tools.Postgresq
|
||||
|
||||
// checkBackupPermissions verifies the user has sufficient privileges for pg_dump backup.
|
||||
// Required privileges: CONNECT on database, USAGE on schemas, SELECT on tables.
|
||||
func checkBackupPermissions(ctx context.Context, conn *pgx.Conn, dbName string) error {
|
||||
// If includeSchemas is specified, only checks permissions on those schemas.
|
||||
func checkBackupPermissions(
|
||||
ctx context.Context,
|
||||
conn *pgx.Conn,
|
||||
includeSchemas []string,
|
||||
) error {
|
||||
var missingPrivileges []string
|
||||
|
||||
// Check CONNECT privilege on database
|
||||
@@ -723,14 +728,29 @@ func checkBackupPermissions(ctx context.Context, conn *pgx.Conn, dbName string)
|
||||
|
||||
// Check USAGE privilege on at least one non-system schema
|
||||
var schemaCount int
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_namespace n
|
||||
WHERE has_schema_privilege(current_user, n.nspname, 'USAGE')
|
||||
AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
|
||||
AND n.nspname NOT LIKE 'pg_temp_%'
|
||||
AND n.nspname NOT LIKE 'pg_toast_temp_%'
|
||||
`).Scan(&schemaCount)
|
||||
if len(includeSchemas) > 0 {
|
||||
// Check only the specified schemas
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_namespace n
|
||||
WHERE has_schema_privilege(current_user, n.nspname, 'USAGE')
|
||||
AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
|
||||
AND n.nspname NOT LIKE 'pg_temp_%'
|
||||
AND n.nspname NOT LIKE 'pg_toast_temp_%'
|
||||
AND n.nspname = ANY($1::text[])
|
||||
`, includeSchemas).Scan(&schemaCount)
|
||||
} else {
|
||||
// Check all non-system schemas
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_namespace n
|
||||
WHERE has_schema_privilege(current_user, n.nspname, 'USAGE')
|
||||
AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
|
||||
AND n.nspname NOT LIKE 'pg_temp_%'
|
||||
AND n.nspname NOT LIKE 'pg_toast_temp_%'
|
||||
`).Scan(&schemaCount)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot check schema privileges: %w", err)
|
||||
}
|
||||
@@ -741,11 +761,28 @@ func checkBackupPermissions(ctx context.Context, conn *pgx.Conn, dbName string)
|
||||
// Check SELECT privilege on at least one table (if tables exist)
|
||||
// Use pg_tables from pg_catalog which shows all tables regardless of user privileges
|
||||
var tableCount int
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_catalog.pg_tables t
|
||||
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
|
||||
`).Scan(&tableCount)
|
||||
|
||||
if len(includeSchemas) > 0 {
|
||||
// Check only tables in the specified schemas
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_catalog.pg_tables t
|
||||
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
|
||||
AND t.schemaname NOT LIKE 'pg_temp_%'
|
||||
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
|
||||
AND t.schemaname = ANY($1::text[])
|
||||
`, includeSchemas).Scan(&tableCount)
|
||||
} else {
|
||||
// Check all tables in non-system schemas
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_catalog.pg_tables t
|
||||
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
|
||||
AND t.schemaname NOT LIKE 'pg_temp_%'
|
||||
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
|
||||
`).Scan(&tableCount)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot check table count: %w", err)
|
||||
}
|
||||
@@ -753,12 +790,30 @@ func checkBackupPermissions(ctx context.Context, conn *pgx.Conn, dbName string)
|
||||
if tableCount > 0 {
|
||||
// Check if user has SELECT on at least one of these tables
|
||||
var selectableTableCount int
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_catalog.pg_tables t
|
||||
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
|
||||
AND has_table_privilege(current_user, quote_ident(t.schemaname) || '.' || quote_ident(t.tablename), 'SELECT')
|
||||
`).Scan(&selectableTableCount)
|
||||
|
||||
if len(includeSchemas) > 0 {
|
||||
// Check only tables in the specified schemas
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_catalog.pg_tables t
|
||||
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
|
||||
AND t.schemaname NOT LIKE 'pg_temp_%'
|
||||
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
|
||||
AND t.schemaname = ANY($1::text[])
|
||||
AND has_table_privilege(current_user, quote_ident(t.schemaname) || '.' || quote_ident(t.tablename), 'SELECT')
|
||||
`, includeSchemas).Scan(&selectableTableCount)
|
||||
} else {
|
||||
// Check all tables in non-system schemas
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_catalog.pg_tables t
|
||||
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
|
||||
AND t.schemaname NOT LIKE 'pg_temp_%'
|
||||
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
|
||||
AND has_table_privilege(current_user, quote_ident(t.schemaname) || '.' || quote_ident(t.tablename), 'SELECT')
|
||||
`).Scan(&selectableTableCount)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot check SELECT privileges: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package healthcheck_attempt
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
"context"
|
||||
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
|
||||
"log/slog"
|
||||
"time"
|
||||
@@ -13,18 +13,19 @@ type HealthcheckAttemptBackgroundService struct {
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *HealthcheckAttemptBackgroundService) Run() {
|
||||
func (s *HealthcheckAttemptBackgroundService) Run(ctx context.Context) {
|
||||
// first healthcheck immediately
|
||||
s.checkDatabases()
|
||||
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
if config.IsShouldShutdown() {
|
||||
break
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.checkDatabases()
|
||||
}
|
||||
|
||||
s.checkDatabases()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -675,6 +675,10 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
WebhookNotifier: &webhook_notifier.WebhookNotifier{
|
||||
WebhookURL: "https://webhook.example.com/test",
|
||||
WebhookMethod: webhook_notifier.WebhookMethodPOST,
|
||||
Headers: []webhook_notifier.WebhookHeader{
|
||||
{Key: "Authorization", Value: "Bearer my-secret-token"},
|
||||
{Key: "X-Custom-Header", Value: "custom-value"},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
@@ -687,14 +691,40 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
WebhookNotifier: &webhook_notifier.WebhookNotifier{
|
||||
WebhookURL: "https://webhook.example.com/updated",
|
||||
WebhookMethod: webhook_notifier.WebhookMethodGET,
|
||||
Headers: []webhook_notifier.WebhookHeader{
|
||||
{Key: "Authorization", Value: "Bearer updated-token"},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
|
||||
// No sensitive data to verify for webhook
|
||||
assert.NotEmpty(
|
||||
t,
|
||||
notifier.WebhookNotifier.WebhookURL,
|
||||
"WebhookURL should be visible",
|
||||
)
|
||||
// Verify header values are encrypted in DB
|
||||
assert.True(
|
||||
t,
|
||||
isEncrypted(notifier.WebhookNotifier.Headers[0].Value),
|
||||
"Header value should be encrypted in DB",
|
||||
)
|
||||
decrypted := decryptField(
|
||||
t,
|
||||
notifier.ID,
|
||||
notifier.WebhookNotifier.Headers[0].Value,
|
||||
)
|
||||
assert.Equal(t, "Bearer updated-token", decrypted)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
|
||||
// No sensitive data to hide for webhook
|
||||
assert.NotEmpty(
|
||||
t,
|
||||
notifier.WebhookNotifier.WebhookURL,
|
||||
"WebhookURL should be visible",
|
||||
)
|
||||
for _, header := range notifier.WebhookNotifier.Headers {
|
||||
assert.Empty(t, header.Value, "Header value should be hidden")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -905,7 +935,7 @@ func Test_CreateNotifier_AllSensitiveFieldsEncryptedInDB(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Webhook Notifier - WebhookURL encrypted",
|
||||
name: "Webhook Notifier - Header values encrypted, URL not encrypted",
|
||||
createNotifier: func(workspaceID uuid.UUID) *Notifier {
|
||||
return &Notifier{
|
||||
WorkspaceID: workspaceID,
|
||||
@@ -914,17 +944,48 @@ func Test_CreateNotifier_AllSensitiveFieldsEncryptedInDB(t *testing.T) {
|
||||
WebhookNotifier: &webhook_notifier.WebhookNotifier{
|
||||
WebhookURL: "https://webhook.example.com/test456",
|
||||
WebhookMethod: webhook_notifier.WebhookMethodPOST,
|
||||
Headers: []webhook_notifier.WebhookHeader{
|
||||
{Key: "Authorization", Value: "Bearer secret-token-12345"},
|
||||
{Key: "X-API-Key", Value: "api-key-67890"},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
|
||||
assert.True(
|
||||
assert.False(
|
||||
t,
|
||||
isEncrypted(notifier.WebhookNotifier.WebhookURL),
|
||||
"WebhookURL should be encrypted",
|
||||
"WebhookURL should NOT be encrypted",
|
||||
)
|
||||
decrypted := decryptField(t, notifier.ID, notifier.WebhookNotifier.WebhookURL)
|
||||
assert.Equal(t, "https://webhook.example.com/test456", decrypted)
|
||||
assert.Equal(
|
||||
t,
|
||||
"https://webhook.example.com/test456",
|
||||
notifier.WebhookNotifier.WebhookURL,
|
||||
)
|
||||
|
||||
assert.True(
|
||||
t,
|
||||
isEncrypted(notifier.WebhookNotifier.Headers[0].Value),
|
||||
"Header value should be encrypted",
|
||||
)
|
||||
decrypted1 := decryptField(
|
||||
t,
|
||||
notifier.ID,
|
||||
notifier.WebhookNotifier.Headers[0].Value,
|
||||
)
|
||||
assert.Equal(t, "Bearer secret-token-12345", decrypted1)
|
||||
|
||||
assert.True(
|
||||
t,
|
||||
isEncrypted(notifier.WebhookNotifier.Headers[1].Value),
|
||||
"Header value should be encrypted",
|
||||
)
|
||||
decrypted2 := decryptField(
|
||||
t,
|
||||
notifier.ID,
|
||||
notifier.WebhookNotifier.Headers[1].Value,
|
||||
)
|
||||
assert.Equal(t, "api-key-67890", decrypted2)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -21,6 +21,10 @@ type WebhookHeader struct {
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
// Before both WebhookURL, BodyTemplate and HeadersJSON were considered
|
||||
// as sensetive data and it was causing issues. Now only headers values
|
||||
// considered as sensetive data, but we try to decrypt webhook URL and
|
||||
// body template for backward combability
|
||||
type WebhookNotifier struct {
|
||||
NotifierID uuid.UUID `json:"notifierId" gorm:"primaryKey;column:notifier_id"`
|
||||
WebhookURL string `json:"webhookUrl" gorm:"not null;column:webhook_url"`
|
||||
@@ -58,6 +62,20 @@ func (t *WebhookNotifier) AfterFind(_ *gorm.DB) error {
|
||||
}
|
||||
}
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
|
||||
if t.WebhookURL != "" {
|
||||
if decrypted, err := encryptor.Decrypt(t.NotifierID, t.WebhookURL); err == nil {
|
||||
t.WebhookURL = decrypted
|
||||
}
|
||||
}
|
||||
|
||||
if t.BodyTemplate != nil && *t.BodyTemplate != "" {
|
||||
if decrypted, err := encryptor.Decrypt(t.NotifierID, *t.BodyTemplate); err == nil {
|
||||
t.BodyTemplate = &decrypted
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -79,22 +97,24 @@ func (t *WebhookNotifier) Send(
|
||||
heading string,
|
||||
message string,
|
||||
) error {
|
||||
webhookURL, err := encryptor.Decrypt(t.NotifierID, t.WebhookURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
|
||||
if err := t.decryptHeadersForSending(encryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch t.WebhookMethod {
|
||||
case WebhookMethodGET:
|
||||
return t.sendGET(webhookURL, heading, message, logger)
|
||||
return t.sendGET(t.WebhookURL, heading, message, logger)
|
||||
case WebhookMethodPOST:
|
||||
return t.sendPOST(webhookURL, heading, message, logger)
|
||||
return t.sendPOST(t.WebhookURL, heading, message, logger)
|
||||
default:
|
||||
return fmt.Errorf("unsupported webhook method: %s", t.WebhookMethod)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WebhookNotifier) HideSensitiveData() {
|
||||
for i := range t.Headers {
|
||||
t.Headers[i].Value = ""
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WebhookNotifier) Update(incoming *WebhookNotifier) {
|
||||
@@ -105,14 +125,15 @@ func (t *WebhookNotifier) Update(incoming *WebhookNotifier) {
|
||||
}
|
||||
|
||||
func (t *WebhookNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
if t.WebhookURL != "" {
|
||||
encrypted, err := encryptor.Encrypt(t.NotifierID, t.WebhookURL)
|
||||
for i := range t.Headers {
|
||||
if t.Headers[i].Value != "" {
|
||||
encrypted, err := encryptor.Encrypt(t.NotifierID, t.Headers[i].Value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt header value: %w", err)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt webhook URL: %w", err)
|
||||
t.Headers[i].Value = encrypted
|
||||
}
|
||||
|
||||
t.WebhookURL = encrypted
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -241,3 +262,15 @@ func escapeJSONString(s string) string {
|
||||
|
||||
return string(b[1 : len(b)-1])
|
||||
}
|
||||
|
||||
func (t *WebhookNotifier) decryptHeadersForSending(encryptor encryption.FieldEncryptor) error {
|
||||
for i := range t.Headers {
|
||||
if t.Headers[i].Value != "" {
|
||||
if decrypted, err := encryptor.Decrypt(t.NotifierID, t.Headers[i].Value); err == nil {
|
||||
t.Headers[i].Value = decrypted
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package restores
|
||||
|
||||
import (
|
||||
"context"
|
||||
"databasus-backend/internal/features/restores/enums"
|
||||
"log/slog"
|
||||
)
|
||||
@@ -10,7 +11,7 @@ type RestoreBackgroundService struct {
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *RestoreBackgroundService) Run() {
|
||||
func (s *RestoreBackgroundService) Run(ctx context.Context) {
|
||||
if err := s.failRestoresInProgress(); err != nil {
|
||||
s.logger.Error("Failed to fail restores in progress", "error", err)
|
||||
panic(err)
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"databasus-backend/internal/config"
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/databases/databases/mysql"
|
||||
@@ -274,7 +275,7 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
var backup *backups.Backup
|
||||
var backup *backups_core.Backup
|
||||
var request RestoreBackupRequest
|
||||
|
||||
if tc.dbType == databases.DatabaseTypePostgres {
|
||||
@@ -321,7 +322,7 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
|
||||
}
|
||||
|
||||
// Set huge backup size (10 TB) that would fail disk validation if checked
|
||||
repo := &backups.BackupRepository{}
|
||||
repo := &backups_core.BackupRepository{}
|
||||
backup.BackupSizeMb = 10485760.0
|
||||
err := repo.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
@@ -368,7 +369,7 @@ func createTestDatabaseWithBackupForRestore(
|
||||
workspace *workspaces_models.Workspace,
|
||||
owner *users_dto.SignInResponseDTO,
|
||||
router *gin.Engine,
|
||||
) (*databases.Database, *backups.Backup) {
|
||||
) (*databases.Database, *backups_core.Backup) {
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
@@ -504,7 +505,7 @@ func createTestStorage(workspaceID uuid.UUID) *storages.Storage {
|
||||
func createTestBackup(
|
||||
database *databases.Database,
|
||||
owner *users_dto.SignInResponseDTO,
|
||||
) *backups.Backup {
|
||||
) *backups_core.Backup {
|
||||
fieldEncryptor := util_encryption.GetFieldEncryptor()
|
||||
userService := users_services.GetUserService()
|
||||
user, err := userService.GetUserFromToken(owner.Token)
|
||||
@@ -517,17 +518,17 @@ func createTestBackup(
|
||||
panic("No storage found for workspace")
|
||||
}
|
||||
|
||||
backup := &backups.Backup{
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storages[0].ID,
|
||||
Status: backups.BackupStatusCompleted,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10.5,
|
||||
BackupDurationMs: 1000,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
repo := &backups.BackupRepository{}
|
||||
repo := &backups_core.BackupRepository{}
|
||||
if err := repo.Save(backup); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/restores/enums"
|
||||
"time"
|
||||
|
||||
@@ -13,7 +13,7 @@ type Restore struct {
|
||||
Status enums.RestoreStatus `json:"status" gorm:"column:status;type:text;not null"`
|
||||
|
||||
BackupID uuid.UUID `json:"backupId" gorm:"column:backup_id;type:uuid;not null"`
|
||||
Backup *backups.Backup
|
||||
Backup *backups_core.Backup
|
||||
|
||||
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package restores
|
||||
import (
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/disk"
|
||||
@@ -36,7 +37,7 @@ type RestoreService struct {
|
||||
diskService *disk.DiskService
|
||||
}
|
||||
|
||||
func (s *RestoreService) OnBeforeBackupRemove(backup *backups.Backup) error {
|
||||
func (s *RestoreService) OnBeforeBackupRemove(backup *backups_core.Backup) error {
|
||||
restores, err := s.restoreRepository.FindByBackupID(backup.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -153,10 +154,10 @@ func (s *RestoreService) RestoreBackupWithAuth(
|
||||
}
|
||||
|
||||
func (s *RestoreService) RestoreBackup(
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
requestDTO RestoreBackupRequest,
|
||||
) error {
|
||||
if backup.Status != backups.BackupStatusCompleted {
|
||||
if backup.Status != backups_core.BackupStatusCompleted {
|
||||
return errors.New("backup is not completed")
|
||||
}
|
||||
|
||||
@@ -370,7 +371,7 @@ func (s *RestoreService) validateVersionCompatibility(
|
||||
}
|
||||
|
||||
func (s *RestoreService) validateDiskSpace(
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
requestDTO RestoreBackupRequest,
|
||||
) error {
|
||||
// Only validate disk space for PostgreSQL when file-based restore is needed:
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -40,7 +40,7 @@ func (uc *RestoreMariadbBackupUsecase) Execute(
|
||||
restoringToDB *databases.Database,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore models.Restore,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
) error {
|
||||
if originalDB.Type != databases.DatabaseTypeMariadb {
|
||||
@@ -71,6 +71,7 @@ func (uc *RestoreMariadbBackupUsecase) Execute(
|
||||
|
||||
if mdb.IsHttps {
|
||||
args = append(args, "--ssl")
|
||||
args = append(args, "--skip-ssl-verify-server-cert")
|
||||
}
|
||||
|
||||
if mdb.Database != nil && *mdb.Database != "" {
|
||||
@@ -98,7 +99,7 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
|
||||
mariadbBin string,
|
||||
args []string,
|
||||
password string,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
mdbConfig *mariadbtypes.MariadbDatabase,
|
||||
) error {
|
||||
@@ -162,7 +163,7 @@ func (uc *RestoreMariadbBackupUsecase) executeMariadbRestore(
|
||||
args []string,
|
||||
myCnfFile string,
|
||||
backupReader io.ReadCloser,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
) error {
|
||||
fullArgs := append([]string{"--defaults-file=" + myCnfFile}, args...)
|
||||
|
||||
@@ -225,7 +226,7 @@ func (uc *RestoreMariadbBackupUsecase) executeMariadbRestore(
|
||||
|
||||
func (uc *RestoreMariadbBackupUsecase) setupDecryption(
|
||||
reader io.Reader,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
) (io.Reader, error) {
|
||||
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
|
||||
return nil, fmt.Errorf("backup is encrypted but missing encryption metadata")
|
||||
@@ -265,11 +266,24 @@ func (uc *RestoreMariadbBackupUsecase) createTempMyCnfFile(
|
||||
mdbConfig *mariadbtypes.MariadbDatabase,
|
||||
password string,
|
||||
) (string, error) {
|
||||
tempDir, err := os.MkdirTemp(config.GetEnv().TempFolder, "mycnf_"+uuid.New().String())
|
||||
tempFolder := config.GetEnv().TempFolder
|
||||
if err := os.MkdirAll(tempFolder, 0700); err != nil {
|
||||
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
|
||||
}
|
||||
if err := os.Chmod(tempFolder, 0700); err != nil {
|
||||
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
|
||||
}
|
||||
|
||||
tempDir, err := os.MkdirTemp(tempFolder, "mycnf_"+uuid.New().String())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temp directory: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(tempDir, 0700); err != nil {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
return "", fmt.Errorf("failed to set temp directory permissions: %w", err)
|
||||
}
|
||||
|
||||
myCnfFile := filepath.Join(tempDir, ".my.cnf")
|
||||
|
||||
content := fmt.Sprintf(`[client]
|
||||
@@ -287,6 +301,7 @@ port=%d
|
||||
|
||||
err = os.WriteFile(myCnfFile, []byte(content), 0600)
|
||||
if err != nil {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
return "", fmt.Errorf("failed to write .my.cnf: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -40,7 +40,7 @@ func (uc *RestoreMongodbBackupUsecase) Execute(
|
||||
restoringToDB *databases.Database,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore models.Restore,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
) error {
|
||||
if originalDB.Type != databases.DatabaseTypeMongodb {
|
||||
@@ -124,7 +124,7 @@ func (uc *RestoreMongodbBackupUsecase) buildMongorestoreArgs(
|
||||
func (uc *RestoreMongodbBackupUsecase) restoreFromStorage(
|
||||
mongorestoreBin string,
|
||||
args []string,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), restoreTimeout)
|
||||
@@ -166,7 +166,7 @@ func (uc *RestoreMongodbBackupUsecase) executeMongoRestore(
|
||||
mongorestoreBin string,
|
||||
args []string,
|
||||
backupReader io.ReadCloser,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
) error {
|
||||
cmd := exec.CommandContext(ctx, mongorestoreBin, args...)
|
||||
|
||||
@@ -231,7 +231,7 @@ func (uc *RestoreMongodbBackupUsecase) executeMongoRestore(
|
||||
|
||||
func (uc *RestoreMongodbBackupUsecase) setupDecryption(
|
||||
reader io.Reader,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
) (io.Reader, error) {
|
||||
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
|
||||
return nil, errors.New("encrypted backup missing salt or IV")
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -40,7 +40,7 @@ func (uc *RestoreMysqlBackupUsecase) Execute(
|
||||
restoringToDB *databases.Database,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore models.Restore,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
) error {
|
||||
if originalDB.Type != databases.DatabaseTypeMysql {
|
||||
@@ -98,7 +98,7 @@ func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
|
||||
mysqlBin string,
|
||||
args []string,
|
||||
password string,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
myConfig *mysqltypes.MysqlDatabase,
|
||||
) error {
|
||||
@@ -154,7 +154,7 @@ func (uc *RestoreMysqlBackupUsecase) executeMysqlRestore(
|
||||
args []string,
|
||||
myCnfFile string,
|
||||
backupReader io.ReadCloser,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
) error {
|
||||
fullArgs := append([]string{"--defaults-file=" + myCnfFile}, args...)
|
||||
|
||||
@@ -217,7 +217,7 @@ func (uc *RestoreMysqlBackupUsecase) executeMysqlRestore(
|
||||
|
||||
func (uc *RestoreMysqlBackupUsecase) setupDecryption(
|
||||
reader io.Reader,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
) (io.Reader, error) {
|
||||
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
|
||||
return nil, fmt.Errorf("backup is encrypted but missing encryption metadata")
|
||||
@@ -257,11 +257,24 @@ func (uc *RestoreMysqlBackupUsecase) createTempMyCnfFile(
|
||||
myConfig *mysqltypes.MysqlDatabase,
|
||||
password string,
|
||||
) (string, error) {
|
||||
tempDir, err := os.MkdirTemp(config.GetEnv().TempFolder, "mycnf_"+uuid.New().String())
|
||||
tempFolder := config.GetEnv().TempFolder
|
||||
if err := os.MkdirAll(tempFolder, 0700); err != nil {
|
||||
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
|
||||
}
|
||||
if err := os.Chmod(tempFolder, 0700); err != nil {
|
||||
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
|
||||
}
|
||||
|
||||
tempDir, err := os.MkdirTemp(tempFolder, "mycnf_"+uuid.New().String())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temp directory: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(tempDir, 0700); err != nil {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
return "", fmt.Errorf("failed to set temp directory permissions: %w", err)
|
||||
}
|
||||
|
||||
myCnfFile := filepath.Join(tempDir, ".my.cnf")
|
||||
|
||||
content := fmt.Sprintf(`[client]
|
||||
@@ -277,6 +290,7 @@ port=%d
|
||||
|
||||
err = os.WriteFile(myCnfFile, []byte(content), 0600)
|
||||
if err != nil {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
return "", fmt.Errorf("failed to write .my.cnf: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
"time"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -39,7 +39,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
|
||||
restoringToDB *databases.Database,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore models.Restore,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
isExcludeExtensions bool,
|
||||
) error {
|
||||
@@ -86,7 +86,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
|
||||
func (uc *RestorePostgresqlBackupUsecase) restoreCustomType(
|
||||
originalDB *databases.Database,
|
||||
pgBin string,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
pg *pgtypes.PostgresqlDatabase,
|
||||
isExcludeExtensions bool,
|
||||
@@ -113,7 +113,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreCustomType(
|
||||
func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
|
||||
originalDB *databases.Database,
|
||||
pgBin string,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
pg *pgtypes.PostgresqlDatabase,
|
||||
) error {
|
||||
@@ -321,7 +321,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
|
||||
func (uc *RestorePostgresqlBackupUsecase) restoreViaFile(
|
||||
originalDB *databases.Database,
|
||||
pgBin string,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
pg *pgtypes.PostgresqlDatabase,
|
||||
isExcludeExtensions bool,
|
||||
@@ -371,7 +371,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
|
||||
pgBin string,
|
||||
args []string,
|
||||
password string,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
pgConfig *pgtypes.PostgresqlDatabase,
|
||||
isExcludeExtensions bool,
|
||||
@@ -469,7 +469,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
|
||||
// downloadBackupToTempFile downloads backup data from storage to a temporary file
|
||||
func (uc *RestorePostgresqlBackupUsecase) downloadBackupToTempFile(
|
||||
ctx context.Context,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
) (string, func(), error) {
|
||||
// Create temporary directory for backup data
|
||||
@@ -925,14 +925,28 @@ func (uc *RestorePostgresqlBackupUsecase) createTempPgpassFile(
|
||||
escapedPassword,
|
||||
)
|
||||
|
||||
tempDir, err := os.MkdirTemp(config.GetEnv().TempFolder, "pgpass_"+uuid.New().String())
|
||||
tempFolder := config.GetEnv().TempFolder
|
||||
if err := os.MkdirAll(tempFolder, 0700); err != nil {
|
||||
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
|
||||
}
|
||||
if err := os.Chmod(tempFolder, 0700); err != nil {
|
||||
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
|
||||
}
|
||||
|
||||
tempDir, err := os.MkdirTemp(tempFolder, "pgpass_"+uuid.New().String())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temporary directory: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(tempDir, 0700); err != nil {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
return "", fmt.Errorf("failed to set temporary directory permissions: %w", err)
|
||||
}
|
||||
|
||||
pgpassFile := filepath.Join(tempDir, ".pgpass")
|
||||
err = os.WriteFile(pgpassFile, []byte(pgpassContent), 0600)
|
||||
if err != nil {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
return "", fmt.Errorf("failed to write temporary .pgpass file: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ package usecases
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/restores/models"
|
||||
@@ -26,7 +26,7 @@ func (uc *RestoreBackupUsecase) Execute(
|
||||
restore models.Restore,
|
||||
originalDB *databases.Database,
|
||||
restoringToDB *databases.Database,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
isExcludeExtensions bool,
|
||||
) error {
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
package system_healthcheck
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
"databasus-backend/internal/features/disk"
|
||||
)
|
||||
|
||||
var healthcheckService = &HealthcheckService{
|
||||
disk.GetDiskService(),
|
||||
backups.GetBackupBackgroundService(),
|
||||
backuping.GetBackupsScheduler(),
|
||||
backuping.GetBackuperNode(),
|
||||
}
|
||||
var healthcheckController = &HealthcheckController{
|
||||
healthcheckService,
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
package system_healthcheck
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
"databasus-backend/internal/features/disk"
|
||||
"databasus-backend/internal/storage"
|
||||
"errors"
|
||||
@@ -9,7 +10,8 @@ import (
|
||||
|
||||
type HealthcheckService struct {
|
||||
diskService *disk.DiskService
|
||||
backupBackgroundService *backups.BackupBackgroundService
|
||||
backupBackgroundService *backuping.BackupsScheduler
|
||||
backuperNode *backuping.BackuperNode
|
||||
}
|
||||
|
||||
func (s *HealthcheckService) IsHealthy() error {
|
||||
@@ -29,8 +31,16 @@ func (s *HealthcheckService) IsHealthy() error {
|
||||
return errors.New("cannot connect to the database")
|
||||
}
|
||||
|
||||
if !s.backupBackgroundService.IsBackupsWorkerRunning() {
|
||||
return errors.New("backups are not running for more than 5 minutes")
|
||||
if config.GetEnv().IsPrimaryNode {
|
||||
if !s.backupBackgroundService.IsSchedulerRunning() {
|
||||
return errors.New("backups are not running for more than 5 minutes")
|
||||
}
|
||||
}
|
||||
|
||||
if config.GetEnv().IsBackupNode {
|
||||
if !s.backuperNode.IsBackuperRunning() {
|
||||
return errors.New("backuper node is not running for more than 5 minutes")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
mariadbtypes "databasus-backend/internal/features/databases/databases/mariadb"
|
||||
@@ -189,7 +189,7 @@ func testMariadbBackupRestoreForVersion(
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := "restoreddb_mariadb"
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
|
||||
@@ -286,7 +286,7 @@ func testMariadbBackupRestoreWithEncryptionForVersion(
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)
|
||||
|
||||
newDBName := "restoreddb_mariadb_encrypted"
|
||||
@@ -394,7 +394,7 @@ func testMariadbBackupRestoreWithReadOnlyUserForVersion(
|
||||
createBackupViaAPI(t, router, updatedDatabase.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, updatedDatabase.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := "restoreddb_mariadb_readonly"
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
mongodbtypes "databasus-backend/internal/features/databases/databases/mongodb"
|
||||
@@ -161,7 +161,7 @@ func testMongodbBackupRestoreForVersion(
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := "restoreddb_mongodb_" + uuid.New().String()[:8]
|
||||
|
||||
@@ -239,7 +239,7 @@ func testMongodbBackupRestoreWithEncryptionForVersion(
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)
|
||||
|
||||
newDBName := "restoreddb_mongodb_enc_" + uuid.New().String()[:8]
|
||||
@@ -328,7 +328,7 @@ func testMongodbBackupRestoreWithReadOnlyUserForVersion(
|
||||
createBackupViaAPI(t, router, updatedDatabase.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, updatedDatabase.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := "restoreddb_mongodb_ro_" + uuid.New().String()[:8]
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
mysqltypes "databasus-backend/internal/features/databases/databases/mysql"
|
||||
@@ -164,7 +164,7 @@ func testMysqlBackupRestoreForVersion(t *testing.T, mysqlVersion tools.MysqlVers
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := "restoreddb_mysql"
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
|
||||
@@ -261,7 +261,7 @@ func testMysqlBackupRestoreWithEncryptionForVersion(
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)
|
||||
|
||||
newDBName := "restoreddb_mysql_encrypted"
|
||||
@@ -369,7 +369,7 @@ func testMysqlBackupRestoreWithReadOnlyUserForVersion(
|
||||
createBackupViaAPI(t, router, updatedDatabase.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, updatedDatabase.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := "restoreddb_mysql_readonly"
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
pgtypes "databasus-backend/internal/features/databases/databases/postgresql"
|
||||
@@ -190,7 +191,7 @@ func Test_BackupAndRestoreSupabase_PublicSchemaOnly_RestoreIsSuccessful(t *testi
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
_, err = supabaseDB.Exec(fmt.Sprintf(`DELETE FROM public.%s`, tableName))
|
||||
assert.NoError(t, err)
|
||||
@@ -410,7 +411,7 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string, cp
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := fmt.Sprintf("restoreddb_%s_cpu%d_%s", pgVersion, cpuCount, uuid.New().String()[:8])
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
|
||||
@@ -527,7 +528,7 @@ func testSchemaSelectionAllSchemasForVersion(t *testing.T, pgVersion string, por
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := fmt.Sprintf("restored_all_schemas_%s_%s", pgVersion, uuid.New().String()[:8])
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
|
||||
@@ -655,7 +656,7 @@ func testBackupRestoreWithExcludeExtensionsForVersion(t *testing.T, pgVersion st
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := fmt.Sprintf("restored_exclude_ext_%s_%s", pgVersion, uuid.New().String()[:8])
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
|
||||
@@ -789,7 +790,7 @@ func testBackupRestoreWithoutExcludeExtensionsForVersion(
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := fmt.Sprintf("restored_with_ext_%s_%s", pgVersion, uuid.New().String()[:8])
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
|
||||
@@ -928,7 +929,7 @@ func testBackupRestoreWithReadOnlyUserForVersion(t *testing.T, pgVersion string,
|
||||
createBackupViaAPI(t, router, updatedDatabase.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, updatedDatabase.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := fmt.Sprintf("restoreddb_readonly_%s", uuid.New().String()[:8])
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
|
||||
@@ -1048,7 +1049,7 @@ func testSchemaSelectionOnlySpecifiedSchemasForVersion(
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := fmt.Sprintf("restored_specific_schemas_%s_%s", pgVersion, uuid.New().String()[:8])
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
|
||||
@@ -1161,7 +1162,7 @@ func testBackupRestoreWithEncryptionForVersion(t *testing.T, pgVersion string, p
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)
|
||||
|
||||
newDBName := fmt.Sprintf("restoreddb_encrypted_%s", uuid.New().String()[:8])
|
||||
@@ -1242,7 +1243,7 @@ func waitForBackupCompletion(
|
||||
databaseID uuid.UUID,
|
||||
token string,
|
||||
timeout time.Duration,
|
||||
) *backups.Backup {
|
||||
) *backups_core.Backup {
|
||||
startTime := time.Now()
|
||||
pollInterval := 500 * time.Millisecond
|
||||
|
||||
@@ -1263,10 +1264,10 @@ func waitForBackupCompletion(
|
||||
|
||||
if len(response.Backups) > 0 {
|
||||
backup := response.Backups[0]
|
||||
if backup.Status == backups.BackupStatusCompleted {
|
||||
if backup.Status == backups_core.BackupStatusCompleted {
|
||||
return backup
|
||||
}
|
||||
if backup.Status == backups.BackupStatusFailed {
|
||||
if backup.Status == backups_core.BackupStatusFailed {
|
||||
failMsg := "unknown error"
|
||||
if backup.FailMessage != nil {
|
||||
failMsg = *backup.FailMessage
|
||||
|
||||
22
backend/internal/features/tests/setup_test.go
Normal file
22
backend/internal/features/tests/setup_test.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
backuperNode := backuping.CreateTestBackuperNode()
|
||||
cancel := backuping.StartBackuperNodeForTest(&testing.T{}, backuperNode)
|
||||
|
||||
exitCode := m.Run()
|
||||
|
||||
backuping.StopBackuperNodeForTest(&testing.T{}, cancel, backuperNode)
|
||||
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
@@ -2,13 +2,12 @@ package users_controllers
|
||||
|
||||
import (
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
)
|
||||
|
||||
var userController = &UserController{
|
||||
users_services.GetUserService(),
|
||||
rate.NewLimiter(rate.Limit(3), 3), // 3 rps with 3 burst
|
||||
cache_utils.NewRateLimiter(cache_utils.GetValkeyClient()),
|
||||
}
|
||||
|
||||
var settingsController = &SettingsController{
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
func Test_AdminLifecycleE2E_CompletesSuccessfully(t *testing.T) {
|
||||
@@ -185,7 +184,6 @@ func createUserTestRouter() *gin.Engine {
|
||||
// Register protected routes with auth middleware
|
||||
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
|
||||
GetUserController().RegisterProtectedRoutes(protected.(*gin.RouterGroup))
|
||||
GetUserController().SetSignInLimiter(rate.NewLimiter(rate.Limit(100), 100))
|
||||
|
||||
// Setup audit log service
|
||||
users_services.GetUserService().SetAuditLogWriter(&AuditLogWriterStub{})
|
||||
|
||||
@@ -3,20 +3,21 @@ package users_controllers
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
user_dto "databasus-backend/internal/features/users/dto"
|
||||
users_errors "databasus-backend/internal/features/users/errors"
|
||||
user_middleware "databasus-backend/internal/features/users/middleware"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type UserController struct {
|
||||
userService *users_services.UserService
|
||||
signinLimiter *rate.Limiter
|
||||
userService *users_services.UserService
|
||||
rateLimiter *cache_utils.RateLimiter
|
||||
}
|
||||
|
||||
func (c *UserController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
@@ -39,10 +40,6 @@ func (c *UserController) RegisterProtectedRoutes(router *gin.RouterGroup) {
|
||||
router.POST("/users/invite", c.InviteUser)
|
||||
}
|
||||
|
||||
func (c *UserController) SetSignInLimiter(limiter *rate.Limiter) {
|
||||
c.signinLimiter = limiter
|
||||
}
|
||||
|
||||
// SignUp
|
||||
// @Summary Register a new user
|
||||
// @Description Register a new user with email and password
|
||||
@@ -81,8 +78,14 @@ func (c *UserController) SignUp(ctx *gin.Context) {
|
||||
// @Failure 429 {object} map[string]string "Rate limit exceeded"
|
||||
// @Router /users/signin [post]
|
||||
func (c *UserController) SignIn(ctx *gin.Context) {
|
||||
// We use rate limiter to prevent brute force attacks
|
||||
if !c.signinLimiter.Allow() {
|
||||
var request user_dto.SignInRequestDTO
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"})
|
||||
return
|
||||
}
|
||||
|
||||
allowed, _ := c.rateLimiter.CheckLimit(request.Email, "signin", 10, 1*time.Minute)
|
||||
if !allowed {
|
||||
ctx.JSON(
|
||||
http.StatusTooManyRequests,
|
||||
gin.H{"error": "Rate limit exceeded. Please try again later."},
|
||||
@@ -90,12 +93,6 @@ func (c *UserController) SignIn(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var request user_dto.SignInRequestDTO
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.userService.SignIn(&request)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
|
||||
@@ -1146,3 +1146,48 @@ func Test_GoogleOAuth_WithInvitedUser_ActivatesUser(t *testing.T) {
|
||||
assert.Equal(t, email, response.Email)
|
||||
assert.False(t, response.IsNewUser)
|
||||
}
|
||||
|
||||
func Test_SignIn_WithExcessiveAttempts_RateLimitEnforced(t *testing.T) {
|
||||
router := createUserTestRouter()
|
||||
email := "ratelimit" + uuid.New().String() + "@example.com"
|
||||
password := "testpassword123"
|
||||
|
||||
// Create a user first
|
||||
signupRequest := users_dto.SignUpRequestDTO{
|
||||
Email: email,
|
||||
Password: password,
|
||||
Name: "Rate Limit Test User",
|
||||
}
|
||||
test_utils.MakePostRequest(t, router, "/api/v1/users/signup", "", signupRequest, http.StatusOK)
|
||||
|
||||
// Make 10 sign-in attempts (should succeed)
|
||||
for range 10 {
|
||||
signinRequest := users_dto.SignInRequestDTO{
|
||||
Email: email,
|
||||
Password: password,
|
||||
}
|
||||
test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/users/signin",
|
||||
"",
|
||||
signinRequest,
|
||||
http.StatusOK,
|
||||
)
|
||||
}
|
||||
|
||||
// 11th attempt should be rate limited
|
||||
signinRequest := users_dto.SignInRequestDTO{
|
||||
Email: email,
|
||||
Password: password,
|
||||
}
|
||||
resp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/users/signin",
|
||||
"",
|
||||
signinRequest,
|
||||
http.StatusTooManyRequests,
|
||||
)
|
||||
assert.Contains(t, string(resp.Body), "Rate limit exceeded")
|
||||
}
|
||||
|
||||
125
backend/internal/util/cache/cache.go
vendored
Normal file
125
backend/internal/util/cache/cache.go
vendored
Normal file
@@ -0,0 +1,125 @@
|
||||
package cache_utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"databasus-backend/internal/config"
|
||||
"sync"
|
||||
|
||||
"github.com/valkey-io/valkey-go"
|
||||
)
|
||||
|
||||
var (
|
||||
once sync.Once
|
||||
valkeyClient valkey.Client
|
||||
)
|
||||
|
||||
func getCache() valkey.Client {
|
||||
once.Do(func() {
|
||||
env := config.GetEnv()
|
||||
|
||||
options := valkey.ClientOption{
|
||||
InitAddress: []string{env.ValkeyHost + ":" + env.ValkeyPort},
|
||||
Password: env.ValkeyPassword,
|
||||
Username: env.ValkeyUsername,
|
||||
}
|
||||
|
||||
if env.ValkeyIsSsl {
|
||||
options.TLSConfig = &tls.Config{
|
||||
ServerName: env.ValkeyHost,
|
||||
}
|
||||
}
|
||||
|
||||
client, err := valkey.NewClient(options)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
valkeyClient = client
|
||||
})
|
||||
|
||||
return valkeyClient
|
||||
}
|
||||
|
||||
func GetValkeyClient() valkey.Client {
|
||||
return getCache()
|
||||
}
|
||||
|
||||
func TestCacheConnection() {
|
||||
// Get Valkey client from cache package
|
||||
client := getCache()
|
||||
|
||||
// Create a simple test cache util for strings
|
||||
cacheUtil := NewCacheUtil[string](client, "test:")
|
||||
|
||||
// Test data
|
||||
testKey := "connection_test"
|
||||
testValue := "valkey_is_working"
|
||||
|
||||
// Test Set operation
|
||||
cacheUtil.Set(testKey, &testValue)
|
||||
|
||||
// Test Get operation
|
||||
retrievedValue := cacheUtil.Get(testKey)
|
||||
|
||||
// Verify the value was retrieved correctly
|
||||
if retrievedValue == nil {
|
||||
panic("Cache test failed: could not retrieve cached value")
|
||||
}
|
||||
|
||||
if *retrievedValue != testValue {
|
||||
panic("Cache test failed: retrieved value does not match expected")
|
||||
}
|
||||
|
||||
// Clean up - remove test key
|
||||
cacheUtil.Invalidate(testKey)
|
||||
|
||||
// Verify cleanup worked
|
||||
cleanupCheck := cacheUtil.Get(testKey)
|
||||
if cleanupCheck != nil {
|
||||
panic("Cache test failed: test key was not properly invalidated")
|
||||
}
|
||||
}
|
||||
|
||||
func ClearAllCache() error {
|
||||
pattern := "*"
|
||||
cursor := uint64(0)
|
||||
batchSize := int64(100)
|
||||
|
||||
cacheUtil := NewCacheUtil[string](getCache(), "")
|
||||
|
||||
for {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), DefaultQueueTimeout)
|
||||
|
||||
result := cacheUtil.client.Do(
|
||||
ctx,
|
||||
cacheUtil.client.B().Scan().Cursor(cursor).Match(pattern).Count(batchSize).Build(),
|
||||
)
|
||||
cancel()
|
||||
|
||||
if result.Error() != nil {
|
||||
return result.Error()
|
||||
}
|
||||
|
||||
scanResult, err := result.AsScanEntry()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(scanResult.Elements) > 0 {
|
||||
delCtx, delCancel := context.WithTimeout(context.Background(), cacheUtil.timeout)
|
||||
cacheUtil.client.Do(
|
||||
delCtx,
|
||||
cacheUtil.client.B().Del().Key(scanResult.Elements...).Build(),
|
||||
)
|
||||
delCancel()
|
||||
}
|
||||
|
||||
cursor = scanResult.Cursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
109
backend/internal/util/cache/pubsub.go
vendored
Normal file
109
backend/internal/util/cache/pubsub.go
vendored
Normal file
@@ -0,0 +1,109 @@
|
||||
package cache_utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
"databasus-backend/internal/util/logger"
|
||||
|
||||
"github.com/valkey-io/valkey-go"
|
||||
)
|
||||
|
||||
type PubSubManager struct {
|
||||
client valkey.Client
|
||||
subscriptions map[string]context.CancelFunc
|
||||
mu sync.RWMutex
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewPubSubManager() *PubSubManager {
|
||||
return &PubSubManager{
|
||||
client: getCache(),
|
||||
subscriptions: make(map[string]context.CancelFunc),
|
||||
logger: logger.GetLogger(),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *PubSubManager) Subscribe(
|
||||
ctx context.Context,
|
||||
channel string,
|
||||
handler func(message string),
|
||||
) error {
|
||||
m.mu.Lock()
|
||||
if _, exists := m.subscriptions[channel]; exists {
|
||||
m.mu.Unlock()
|
||||
return fmt.Errorf("already subscribed to channel: %s", channel)
|
||||
}
|
||||
|
||||
subCtx, cancel := context.WithCancel(ctx)
|
||||
m.subscriptions[channel] = cancel
|
||||
m.mu.Unlock()
|
||||
|
||||
go m.subscriptionLoop(subCtx, channel, handler)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *PubSubManager) Publish(ctx context.Context, channel string, message string) error {
|
||||
cmd := m.client.B().Publish().Channel(channel).Message(message).Build()
|
||||
result := m.client.Do(ctx, cmd)
|
||||
|
||||
if err := result.Error(); err != nil {
|
||||
m.logger.Error("Failed to publish message to Redis", "channel", channel, "error", err)
|
||||
return fmt.Errorf("failed to publish message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *PubSubManager) Close() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for channel, cancel := range m.subscriptions {
|
||||
cancel()
|
||||
delete(m.subscriptions, channel)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *PubSubManager) subscriptionLoop(
|
||||
ctx context.Context,
|
||||
channel string,
|
||||
handler func(message string),
|
||||
) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
m.logger.Error("Panic in subscription handler", "channel", channel, "panic", r)
|
||||
}
|
||||
}()
|
||||
|
||||
m.logger.Info("Starting subscription", "channel", channel)
|
||||
|
||||
err := m.client.Receive(
|
||||
ctx,
|
||||
m.client.B().Subscribe().Channel(channel).Build(),
|
||||
func(msg valkey.PubSubMessage) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
m.logger.Error("Panic in message handler", "channel", channel, "panic", r)
|
||||
}
|
||||
}()
|
||||
|
||||
handler(msg.Message)
|
||||
},
|
||||
)
|
||||
|
||||
if err != nil && ctx.Err() == nil {
|
||||
m.logger.Error("Subscription error", "channel", channel, "error", err)
|
||||
} else if ctx.Err() != nil {
|
||||
m.logger.Info("Subscription cancelled", "channel", channel)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
delete(m.subscriptions, channel)
|
||||
m.mu.Unlock()
|
||||
}
|
||||
85
backend/internal/util/cache/rate_limiter.go
vendored
Normal file
85
backend/internal/util/cache/rate_limiter.go
vendored
Normal file
@@ -0,0 +1,85 @@
|
||||
package cache_utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/valkey-io/valkey-go"
|
||||
)
|
||||
|
||||
type RateLimiter struct {
|
||||
client valkey.Client
|
||||
}
|
||||
|
||||
func NewRateLimiter(client valkey.Client) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RateLimiter) CheckLimit(
|
||||
identifier string,
|
||||
endpoint string,
|
||||
maxRequests int,
|
||||
windowDuration time.Duration,
|
||||
) (bool, error) {
|
||||
requestID := uuid.New().String()
|
||||
keyPrefix := fmt.Sprintf("ratelimit:%s:%s", endpoint, identifier)
|
||||
fullKey := fmt.Sprintf("%s:%s", keyPrefix, requestID)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), DefaultCacheTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Set the key with TTL
|
||||
setCmd := r.client.B().
|
||||
Set().
|
||||
Key(fullKey).
|
||||
Value("1").
|
||||
ExSeconds(int64(windowDuration.Seconds())).
|
||||
Build()
|
||||
if err := r.client.Do(ctx, setCmd).Error(); err != nil {
|
||||
return true, fmt.Errorf("failed to set rate limit key: %w", err)
|
||||
}
|
||||
|
||||
// Count keys matching the pattern
|
||||
count, err := r.countKeys(keyPrefix)
|
||||
if err != nil {
|
||||
return true, fmt.Errorf("failed to count rate limit keys: %w", err)
|
||||
}
|
||||
|
||||
return count <= maxRequests, nil
|
||||
}
|
||||
|
||||
func (r *RateLimiter) countKeys(keyPrefix string) (int, error) {
|
||||
pattern := keyPrefix + ":*"
|
||||
cursor := uint64(0)
|
||||
totalCount := 0
|
||||
|
||||
for {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), DefaultCacheTimeout)
|
||||
|
||||
scanCmd := r.client.B().Scan().Cursor(cursor).Match(pattern).Count(100).Build()
|
||||
result := r.client.Do(ctx, scanCmd)
|
||||
cancel()
|
||||
|
||||
if result.Error() != nil {
|
||||
return 0, result.Error()
|
||||
}
|
||||
|
||||
scanResult, err := result.AsScanEntry()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
totalCount += len(scanResult.Elements)
|
||||
cursor = scanResult.Cursor
|
||||
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return totalCount, nil
|
||||
}
|
||||
76
backend/internal/util/cache/utils.go
vendored
Normal file
76
backend/internal/util/cache/utils.go
vendored
Normal file
@@ -0,0 +1,76 @@
|
||||
package cache_utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/valkey-io/valkey-go"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultCacheTimeout = 10 * time.Second
|
||||
DefaultCacheExpiry = 10 * time.Minute
|
||||
DefaultQueueTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type CacheUtil[T any] struct {
|
||||
client valkey.Client
|
||||
prefix string
|
||||
timeout time.Duration
|
||||
expiry time.Duration
|
||||
}
|
||||
|
||||
func NewCacheUtil[T any](client valkey.Client, prefix string) *CacheUtil[T] {
|
||||
return &CacheUtil[T]{
|
||||
client: client,
|
||||
prefix: prefix,
|
||||
timeout: DefaultCacheTimeout,
|
||||
expiry: DefaultCacheExpiry,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CacheUtil[T]) Get(key string) *T {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
|
||||
defer cancel()
|
||||
|
||||
fullKey := c.prefix + key
|
||||
result := c.client.Do(ctx, c.client.B().Get().Key(fullKey).Build())
|
||||
|
||||
if result.Error() != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := result.AsBytes()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var item T
|
||||
if err := json.Unmarshal(data, &item); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &item
|
||||
}
|
||||
|
||||
func (c *CacheUtil[T]) Set(key string, item *T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
|
||||
defer cancel()
|
||||
|
||||
data, err := json.Marshal(item)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
fullKey := c.prefix + key
|
||||
c.client.Do(ctx, c.client.B().Set().Key(fullKey).Value(string(data)).Ex(c.expiry).Build())
|
||||
}
|
||||
|
||||
func (c *CacheUtil[T]) Invalidate(key string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
|
||||
defer cancel()
|
||||
|
||||
fullKey := c.prefix + key
|
||||
c.client.Do(ctx, c.client.B().Del().Key(fullKey).Build())
|
||||
}
|
||||
@@ -40,7 +40,3 @@ export const GOOGLE_CLIENT_ID =
|
||||
export function getOAuthRedirectUri(): string {
|
||||
return `${window.location.origin}/auth/callback`;
|
||||
}
|
||||
|
||||
export function isOAuthEnabled(): boolean {
|
||||
return IS_CLOUD && (!!GITHUB_CLIENT_ID || !!GOOGLE_CLIENT_ID);
|
||||
}
|
||||
|
||||
@@ -112,6 +112,12 @@ export function EditWebhookNotifierComponent({ notifier, setNotifier, setUnsaved
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{notifier.id && (
|
||||
<div className="mb-1 text-xs text-orange-700">
|
||||
*Saved headers hidden for security reasons
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="w-full max-w-[500px]">
|
||||
{headers.map((header: WebhookHeader, index: number) => (
|
||||
<div key={index} className="mb-1 flex items-center gap-2">
|
||||
@@ -204,11 +210,12 @@ export function EditWebhookNotifierComponent({ notifier, setNotifier, setUnsaved
|
||||
<div className="text-xs font-semibold text-gray-500 dark:text-gray-400">
|
||||
Headers:
|
||||
</div>
|
||||
|
||||
{headers
|
||||
.filter((h) => h.key)
|
||||
.map((h, i) => (
|
||||
<div key={i} className="text-xs">
|
||||
{h.key}: {h.value || '(empty)'}
|
||||
{h.key}: {h.value || '(hidden)'}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
@@ -29,7 +29,7 @@ export function ShowWebhookNotifierComponent({ notifier }: Props) {
|
||||
.filter((h: WebhookHeader) => h.key)
|
||||
.map((h: WebhookHeader, i: number) => (
|
||||
<div key={i} className="text-gray-600">
|
||||
<span className="font-medium">{h.key}:</span> {h.value || '(empty)'}
|
||||
<span className="font-medium">{h.key}:</span> {h.value || '(hidden)'}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
import { GithubOutlined, GoogleOutlined } from '@ant-design/icons';
|
||||
import { Button, message } from 'antd';
|
||||
|
||||
import {
|
||||
GITHUB_CLIENT_ID,
|
||||
GOOGLE_CLIENT_ID,
|
||||
getOAuthRedirectUri,
|
||||
isOAuthEnabled,
|
||||
} from '../../../constants';
|
||||
|
||||
export function OauthComponent() {
|
||||
if (!isOAuthEnabled()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const redirectUri = getOAuthRedirectUri();
|
||||
|
||||
const handleGitHubLogin = () => {
|
||||
if (!GITHUB_CLIENT_ID) {
|
||||
message.error('GitHub OAuth is not configured');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const params = new URLSearchParams({
|
||||
client_id: GITHUB_CLIENT_ID,
|
||||
redirect_uri: redirectUri,
|
||||
state: 'github',
|
||||
scope: 'user:email',
|
||||
});
|
||||
|
||||
const githubAuthUrl = `https://github.com/login/oauth/authorize?${params.toString()}`;
|
||||
|
||||
// Validate URL is properly formed
|
||||
new URL(githubAuthUrl);
|
||||
window.location.href = githubAuthUrl;
|
||||
} catch (error) {
|
||||
message.error('Invalid OAuth configuration');
|
||||
console.error('GitHub OAuth URL error:', error);
|
||||
}
|
||||
};
|
||||
|
||||
const handleGoogleLogin = () => {
|
||||
if (!GOOGLE_CLIENT_ID) {
|
||||
message.error('Google OAuth is not configured');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const params = new URLSearchParams({
|
||||
client_id: GOOGLE_CLIENT_ID,
|
||||
redirect_uri: redirectUri,
|
||||
response_type: 'code',
|
||||
scope: 'openid email profile',
|
||||
state: 'google',
|
||||
});
|
||||
|
||||
const googleAuthUrl = `https://accounts.google.com/o/oauth2/v2/auth?${params.toString()}`;
|
||||
|
||||
// Validate URL is properly formed
|
||||
new URL(googleAuthUrl);
|
||||
window.location.href = googleAuthUrl;
|
||||
} catch (error) {
|
||||
message.error('Invalid OAuth configuration');
|
||||
console.error('Google OAuth URL error:', error);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="mt-4">
|
||||
<div className="space-y-2">
|
||||
{GITHUB_CLIENT_ID && (
|
||||
<Button
|
||||
icon={<GithubOutlined />}
|
||||
onClick={handleGitHubLogin}
|
||||
className="w-full"
|
||||
size="large"
|
||||
>
|
||||
Continue with GitHub
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{GOOGLE_CLIENT_ID && (
|
||||
<Button
|
||||
icon={<GoogleOutlined />}
|
||||
onClick={handleGoogleLogin}
|
||||
className="w-full"
|
||||
size="large"
|
||||
>
|
||||
Continue with Google
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -2,11 +2,12 @@ import { EyeInvisibleOutlined, EyeTwoTone } from '@ant-design/icons';
|
||||
import { Button, Input } from 'antd';
|
||||
import { type JSX, useState } from 'react';
|
||||
|
||||
import { IS_CLOUD } from '../../../constants';
|
||||
import { GITHUB_CLIENT_ID, GOOGLE_CLIENT_ID } from '../../../constants';
|
||||
import { userApi } from '../../../entity/users';
|
||||
import { StringUtils } from '../../../shared/lib';
|
||||
import { FormValidator } from '../../../shared/lib/FormValidator';
|
||||
import { OauthComponent } from './OauthComponent';
|
||||
import { GithubOAuthComponent } from './oauth/GithubOAuthComponent';
|
||||
import { GoogleOAuthComponent } from './oauth/GoogleOAuthComponent';
|
||||
|
||||
interface SignInComponentProps {
|
||||
onSwitchToSignUp?: () => void;
|
||||
@@ -67,9 +68,14 @@ export function SignInComponent({ onSwitchToSignUp }: SignInComponentProps): JSX
|
||||
<div className="w-full max-w-[300px]">
|
||||
<div className="mb-5 text-center text-2xl font-bold">Sign in</div>
|
||||
|
||||
<OauthComponent />
|
||||
<div className="mt-4">
|
||||
<div className="space-y-2">
|
||||
<GithubOAuthComponent />
|
||||
<GoogleOAuthComponent />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{IS_CLOUD && (
|
||||
{(GOOGLE_CLIENT_ID || GITHUB_CLIENT_ID) && (
|
||||
<div className="relative my-6">
|
||||
<div className="absolute inset-0 flex items-center">
|
||||
<div className="w-full border-t border-gray-300"></div>
|
||||
|
||||
@@ -2,11 +2,12 @@ import { EyeInvisibleOutlined, EyeTwoTone } from '@ant-design/icons';
|
||||
import { App, Button, Input } from 'antd';
|
||||
import { type JSX, useState } from 'react';
|
||||
|
||||
import { IS_CLOUD } from '../../../constants';
|
||||
import { GITHUB_CLIENT_ID, GOOGLE_CLIENT_ID } from '../../../constants';
|
||||
import { userApi } from '../../../entity/users';
|
||||
import { StringUtils } from '../../../shared/lib';
|
||||
import { FormValidator } from '../../../shared/lib/FormValidator';
|
||||
import { OauthComponent } from './OauthComponent';
|
||||
import { GithubOAuthComponent } from './oauth/GithubOAuthComponent';
|
||||
import { GoogleOAuthComponent } from './oauth/GoogleOAuthComponent';
|
||||
|
||||
interface SignUpComponentProps {
|
||||
onSwitchToSignIn?: () => void;
|
||||
@@ -98,9 +99,14 @@ export function SignUpComponent({ onSwitchToSignIn }: SignUpComponentProps): JSX
|
||||
<div className="w-full max-w-[300px]">
|
||||
<div className="mb-5 text-center text-2xl font-bold">Sign up</div>
|
||||
|
||||
<OauthComponent />
|
||||
<div className="mt-4">
|
||||
<div className="space-y-2">
|
||||
<GithubOAuthComponent />
|
||||
<GoogleOAuthComponent />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{IS_CLOUD && (
|
||||
{(GOOGLE_CLIENT_ID || GITHUB_CLIENT_ID) && (
|
||||
<div className="relative my-6">
|
||||
<div className="absolute inset-0 flex items-center">
|
||||
<div className="w-full border-t border-gray-300"></div>
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
import { GithubOutlined } from '@ant-design/icons';
|
||||
import { Button, message } from 'antd';
|
||||
|
||||
import { GITHUB_CLIENT_ID, getOAuthRedirectUri } from '../../../../constants';
|
||||
|
||||
export function GithubOAuthComponent() {
|
||||
if (!GITHUB_CLIENT_ID) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const redirectUri = getOAuthRedirectUri();
|
||||
|
||||
const handleGitHubLogin = () => {
|
||||
try {
|
||||
const params = new URLSearchParams({
|
||||
client_id: GITHUB_CLIENT_ID,
|
||||
redirect_uri: redirectUri,
|
||||
state: 'github',
|
||||
scope: 'user:email',
|
||||
});
|
||||
|
||||
const githubAuthUrl = `https://github.com/login/oauth/authorize?${params.toString()}`;
|
||||
|
||||
// Validate URL is properly formed
|
||||
new URL(githubAuthUrl);
|
||||
window.location.href = githubAuthUrl;
|
||||
} catch (error) {
|
||||
message.error('Invalid OAuth configuration');
|
||||
console.error('GitHub OAuth URL error:', error);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Button icon={<GithubOutlined />} onClick={handleGitHubLogin} className="w-full" size="large">
|
||||
Continue with GitHub
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
import { GoogleOutlined } from '@ant-design/icons';
|
||||
import { Button, message } from 'antd';
|
||||
|
||||
import { GOOGLE_CLIENT_ID, getOAuthRedirectUri } from '../../../../constants';
|
||||
|
||||
export function GoogleOAuthComponent() {
|
||||
if (!GOOGLE_CLIENT_ID) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const redirectUri = getOAuthRedirectUri();
|
||||
|
||||
const handleGoogleLogin = () => {
|
||||
try {
|
||||
const params = new URLSearchParams({
|
||||
client_id: GOOGLE_CLIENT_ID,
|
||||
redirect_uri: redirectUri,
|
||||
response_type: 'code',
|
||||
scope: 'openid email profile',
|
||||
state: 'google',
|
||||
});
|
||||
|
||||
const googleAuthUrl = `https://accounts.google.com/o/oauth2/v2/auth?${params.toString()}`;
|
||||
|
||||
// Validate URL is properly formed
|
||||
new URL(googleAuthUrl);
|
||||
window.location.href = googleAuthUrl;
|
||||
} catch (error) {
|
||||
message.error('Invalid OAuth configuration');
|
||||
console.error('Google OAuth URL error:', error);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Button icon={<GoogleOutlined />} onClick={handleGoogleLogin} className="w-full" size="large">
|
||||
Continue with Google
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
Reference in New Issue
Block a user