Compare commits

...

49 Commits

Author SHA1 Message Date
Rostislav Dugin
bc870b3f8e Merge pull request #261 from databasus/develop
FIX (webhook): Update webhook tests to not expect URL to be encrypted
2026-01-14 09:43:26 +03:00
Rostislav Dugin
15383c59eb FIX (webhook): Update webhook tests to not expect URL to be encrypted 2026-01-14 09:42:25 +03:00
Rostislav Dugin
d14c223a65 Merge pull request #259 from databasus/develop
Develop
2026-01-14 09:10:28 +03:00
Rostislav Dugin
2c0a294027 FIX (webhook): Do not encypt webhook URL, keep encyption for headers only 2026-01-14 09:09:00 +03:00
Rostislav Dugin
5d851d73bd FIX (mysql \ mariadb): Decrease strictness of SELECT check for health check 2026-01-14 08:39:27 +03:00
Rostislav Dugin
699913c251 FIX (postgresql): Filter TEMP table SELECT checks 2026-01-14 07:42:29 +03:00
Rostislav Dugin
a2e3f30a6d Merge pull request #258 from databasus/develop
FEATURE (backups): Add support of multinode Databasus setup
2026-01-14 07:34:06 +03:00
Rostislav Dugin
80f1174ecd FEATURE (backups): Add support of multinode Databasus setup 2026-01-14 07:32:13 +03:00
Rostislav Dugin
a47f8d5e2c Merge pull request #253 from databasus/develop
FIX (permissions check): Check permissions only in schemas selected f…
2026-01-12 14:23:24 +03:00
Rostislav Dugin
54b9e67656 FIX (permissions check): Check permissions only in schemas selected for backup 2026-01-12 14:22:12 +03:00
Rostislav Dugin
3782846872 Merge pull request #251 from databasus/develop
FIX (tidy): Run go mod tidy
2026-01-12 11:32:25 +03:00
Rostislav Dugin
245a81897f FIX (tidy): Run go mod tidy 2026-01-12 11:31:52 +03:00
Rostislav Dugin
5cbc0773b6 Merge pull request #250 from databasus/develop
FEATURE (backups): Move backups cancellation to Valkey pub\sub
2026-01-12 11:26:29 +03:00
Rostislav Dugin
997fc01442 FEATURE (backups): Move backups cancellation to Valkey pub\sub 2026-01-12 11:24:25 +03:00
Rostislav Dugin
6d0ae32d0c Merge pull request #240 from databasus/develop
FIX (oauth): Enable GitHub and Google OAuth
2026-01-10 20:15:43 +03:00
Rostislav Dugin
011985d723 FIX (oauth): Enable GitHub and Google OAuth 2026-01-10 19:19:37 +03:00
Rostislav Dugin
d677ee61de Merge pull request #239 from databasus/develop
FIX (mariadb): --skip-ssl-verify-server-cert for mariadb / mysql
2026-01-10 18:34:58 +03:00
Rostislav Dugin
c6b8f6e87a Merge pull request #237 from wzzrd/bugfix/disable_mariadb_mysql_ssl_verify
--skip-ssl-verify-server-cert for mariadb
2026-01-10 18:33:45 +03:00
Maxim Burgerhout
2bb5f93d00 --skip-ssl-verify-server-cert for mariadb / mysql
This change adds the --skip-ssl-verify-server-cert flag to mariadb
database connections for both backups and restores. This errors when
trying to verify certificates during those procedures.
2026-01-10 15:50:09 +01:00
Rostislav Dugin
b91c150300 Merge pull request #236 from databasus/develop
Develop
2026-01-10 15:19:19 +03:00
Rostislav Dugin
12b119ce40 FIX (readme): Update readme 2026-01-10 15:16:25 +03:00
Rostislav Dugin
7c6f0ab4ba FIX (mysql\mariadb): Use custom TLS handler to skip verification instead of build-in 2026-01-10 15:13:47 +03:00
Rostislav Dugin
6d2db4b298 Merge pull request #232 from databasus/develop
Develop
2026-01-09 11:12:27 +03:00
Rostislav Dugin
6397423298 FIX (temp folder): Ensure permissions 0700 for temp folders to meet PG requirements for .pgpass ownership 2026-01-09 11:10:50 +03:00
Rostislav Dugin
3470aae8e3 FIX (mysql\mariadb): Remove PROCESS permission check before backup, because it is not mandatory for backup 2026-01-09 11:02:14 +03:00
Rostislav Dugin
184fbcdb2c Merge pull request #230 from databasus/develop
FIX (temp): Use Databasus temp folder instead of system over PG backup
2026-01-08 20:41:45 +03:00
Rostislav Dugin
2d897dd722 FIX (temp): Use Databasus temp folder instead of system over PG backup 2026-01-08 20:40:22 +03:00
Rostislav Dugin
cba40afd00 Merge pull request #228 from databasus/develop
FIX (backend): Fix formatting
2026-01-08 17:11:43 +03:00
Rostislav Dugin
7aea012aeb FIX (backend): Fix formatting 2026-01-08 17:10:47 +03:00
Rostislav Dugin
6d5534deaa Merge pull request #227 from databasus/develop
Develop
2026-01-08 16:55:12 +03:00
Rostislav Dugin
c04bd54683 FIX (download): Add streamable download of backups 2026-01-08 15:55:52 +03:00
Rostislav Dugin
1c3f16b372 FIX (google drive): Fix UI after new local redirect PR 2026-01-08 12:22:47 +03:00
Rostislav Dugin
ed08da56a6 FIX (cicd): Get rid of CITATION auto generate 2026-01-08 11:35:55 +03:00
Rostislav Dugin
c53e84b48d FIX (devex): Fix Linux tools installation script 2026-01-08 11:34:35 +03:00
Rostislav Dugin
dbfeb9e27f merge develop 2026-01-08 11:33:34 +03:00
Rostislav Dugin
02e86ffb3b FIX (devex): Fix Linux tools installation script 2026-01-08 11:10:56 +03:00
github-actions[bot]
207382116c Update CITATION.cff to v2.21.0 2026-01-05 18:38:28 +00:00
Rostislav Dugin
a91ee50e31 Merge pull request #221 from databasus/develop
Develop
2026-01-05 21:08:50 +03:00
Rostislav Dugin
7e5562b115 FEATURE (mysql): Add automatic detection of allowed privileges to backup proper DB items 2026-01-05 21:07:53 +03:00
Rostislav Dugin
3ef51c4d68 FEATURE (databases): Imrove check for required permissions to backup, check for read-only user and extend DBs models tests 2026-01-05 21:07:53 +03:00
github-actions[bot]
e47e513460 Update CITATION.cff to v2.20.3 2026-01-04 21:28:24 +00:00
Rostislav Dugin
226a6c06e6 Merge pull request #216 from databasus/develop
FIX (readonly user): Improve complexity of readonly user passwords to…
2026-01-05 00:07:37 +03:00
Rostislav Dugin
615fd9d574 FIX (readonly user): Improve complexity of readonly user passwords to pass Google Cloud requirements 2026-01-05 00:06:24 +03:00
github-actions[bot]
e9fcf20cdf Update CITATION.cff to v2.20.2 2026-01-04 20:14:38 +00:00
Rostislav Dugin
7649f4acfd Merge pull request #214 from databasus/develop
FIX (databases): Add timeout for deletion in case of storage stuck
2026-01-04 22:54:13 +03:00
Rostislav Dugin
7e4c3bcc19 FIX (databases): Add timeout for deletion in case of storage stuck 2026-01-04 22:51:11 +03:00
Rostislav Dugin
f2aecc0427 Merge pull request #212 from databasus/develop
FIX (mariadb): Add events exclusion for MariaDB
2026-01-04 22:16:24 +03:00
Rostislav Dugin
3ce7da319f FIX (mariadb): Add events exclusion for MariaDB 2026-01-04 22:15:31 +03:00
github-actions[bot]
096098f660 Update CITATION.cff to v2.20.1 2026-01-04 18:13:01 +00:00
134 changed files with 9387 additions and 2787 deletions

View File

@@ -17,7 +17,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.24.4"
go-version: "1.24.9"
- name: Cache Go modules
uses: actions/cache@v4
@@ -134,7 +134,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.24.4"
go-version: "1.24.9"
- name: Cache Go modules
uses: actions/cache@v4
@@ -221,6 +221,12 @@ jobs:
TEST_MONGODB_60_PORT=27060
TEST_MONGODB_70_PORT=27070
TEST_MONGODB_82_PORT=27082
# Valkey (cache)
VALKEY_HOST=localhost
VALKEY_PORT=6379
VALKEY_USERNAME=
VALKEY_PASSWORD=
VALKEY_IS_SSL=false
EOF
- name: Start test containers
@@ -233,6 +239,11 @@ jobs:
# Wait for main dev database
timeout 60 bash -c 'until docker exec dev-db pg_isready -h localhost -p 5437 -U postgres; do sleep 2; done'
# Wait for Valkey (cache)
echo "Waiting for Valkey..."
timeout 60 bash -c 'until docker exec dev-valkey valkey-cli ping 2>/dev/null | grep -q PONG; do sleep 2; done'
echo "Valkey is ready!"
# Wait for test databases
timeout 60 bash -c 'until nc -z localhost 5000; do sleep 2; done'
timeout 60 bash -c 'until nc -z localhost 5001; do sleep 2; done'
@@ -672,17 +683,6 @@ jobs:
echo EOF
} >> $GITHUB_OUTPUT
- name: Update CITATION.cff version
run: |
VERSION="${{ needs.determine-version.outputs.new_version }}"
sed -i "s/^version: .*/version: ${VERSION}/" CITATION.cff
sed -i "s/^date-released: .*/date-released: \"$(date +%Y-%m-%d)\"/" CITATION.cff
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
git add CITATION.cff
git commit -m "Update CITATION.cff to v${VERSION}" || true
git push || true
- name: Create GitHub Release
uses: actions/create-release@v1
env:

View File

@@ -6,14 +6,14 @@ repos:
hooks:
- id: frontend-format
name: Frontend Format (Prettier)
entry: powershell -Command "cd frontend; npm run format"
entry: bash -c "cd frontend && npm run format"
language: system
files: ^frontend/.*\.(ts|tsx|js|jsx|json|css|md)$
pass_filenames: false
- id: frontend-lint
name: Frontend Lint (ESLint)
entry: powershell -Command "cd frontend; npm run lint"
entry: bash -c "cd frontend && npm run lint"
language: system
files: ^frontend/.*\.(ts|tsx|js|jsx)$
pass_filenames: false
@@ -23,7 +23,7 @@ repos:
hooks:
- id: backend-format-and-lint
name: Backend Format & Lint (golangci-lint)
entry: powershell -Command "cd backend; golangci-lint fmt; golangci-lint run"
entry: bash -c "cd backend && golangci-lint fmt ./internal/... ./cmd/... && golangci-lint run ./internal/... ./cmd/..."
language: system
files: ^backend/.*\.go$
pass_filenames: false
pass_filenames: false

View File

@@ -32,5 +32,5 @@ keywords:
- mongodb
- mariadb
license: Apache-2.0
version: 2.20.0
date-released: "2026-01-04"
version: 2.21.0
date-released: "2026-01-05"

View File

@@ -22,7 +22,7 @@ RUN npm run build
# ========= BUILD BACKEND =========
# Backend build stage
FROM --platform=$BUILDPLATFORM golang:1.24.4 AS backend-build
FROM --platform=$BUILDPLATFORM golang:1.24.9 AS backend-build
# Make TARGET args available early so tools built here match the final image arch
ARG TARGETOS
@@ -123,6 +123,15 @@ RUN wget -qO- https://www.postgresql.org/media/keys/ACCC4CF8.asc | apt-key add -
apt-get install -y --no-install-recommends postgresql-17 && \
rm -rf /var/lib/apt/lists/*
# Install Valkey server from debian repository
# Valkey is only accessible internally (localhost) - not exposed outside container
RUN wget -O /usr/share/keyrings/greensec.github.io-valkey-debian.key https://greensec.github.io/valkey-debian/public.key && \
echo "deb [signed-by=/usr/share/keyrings/greensec.github.io-valkey-debian.key] https://greensec.github.io/valkey-debian/repo $(lsb_release -cs) main" \
> /etc/apt/sources.list.d/valkey-debian.list && \
apt-get update && \
apt-get install -y --no-install-recommends valkey && \
rm -rf /var/lib/apt/lists/*
# ========= Install rclone =========
RUN apt-get update && \
apt-get install -y --no-install-recommends rclone && \
@@ -245,7 +254,34 @@ PG_BIN="/usr/lib/postgresql/17/bin"
# Ensure proper ownership of data directory
echo "Setting up data directory permissions..."
mkdir -p /databasus-data/pgdata
mkdir -p /databasus-data/temp
mkdir -p /databasus-data/backups
chown -R postgres:postgres /databasus-data
chmod 700 /databasus-data/temp
# ========= Start Valkey (internal cache) =========
echo "Configuring Valkey cache..."
cat > /tmp/valkey.conf << 'VALKEY_CONFIG'
port 6379
bind 127.0.0.1
protected-mode yes
save ""
maxmemory 256mb
maxmemory-policy allkeys-lru
VALKEY_CONFIG
echo "Starting Valkey..."
valkey-server /tmp/valkey.conf &
VALKEY_PID=\$!
echo "Waiting for Valkey to be ready..."
for i in {1..30}; do
if valkey-cli ping >/dev/null 2>&1; then
echo "Valkey is ready!"
break
fi
sleep 1
done
# Initialize PostgreSQL if not already initialized
if [ ! -s "/databasus-data/pgdata/PG_VERSION" ]; then

View File

@@ -2,7 +2,7 @@
<img src="assets/logo.svg" alt="Databasus Logo" width="250"/>
<h3>Backup tool for PostgreSQL, MySQL and MongoDB</h3>
<p>Databasus is a free, open source and self-hosted tool to backup databases. Make backups with different storages (S3, Google Drive, FTP, etc.) and notifications about progress (Slack, Discord, Telegram, etc.). Previously known as Postgresus (see migration guide).</p>
<p>Databasus is a free, open source and self-hosted tool to backup databases (with focus on PostgreSQL). Make backups with different storages (S3, Google Drive, FTP, etc.) and notifications about progress (Slack, Discord, Telegram, etc.). Previously known as Postgresus (see migration guide).</p>
<!-- Badges -->
[![PostgreSQL](https://img.shields.io/badge/PostgreSQL-336791?logo=postgresql&logoColor=white)](https://www.postgresql.org/)

View File

@@ -11,6 +11,12 @@ DATABASE_URL=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable
GOOSE_DRIVER=postgres
GOOSE_DBSTRING=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable
GOOSE_MIGRATION_DIR=./migrations
# valkey
VALKEY_HOST=127.0.0.1
VALKEY_PORT=6379
VALKEY_USERNAME=
VALKEY_PASSWORD=
VALKEY_IS_SSL=false
# testing
# to get Google Drive env variables: add storage in UI and copy data from added storage here
TEST_GOOGLE_DRIVE_CLIENT_ID=

View File

@@ -10,4 +10,10 @@ DATABASE_URL=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disab
# migrations
GOOSE_DRIVER=postgres
GOOSE_DBSTRING=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disable
GOOSE_MIGRATION_DIR=./migrations
GOOSE_MIGRATION_DIR=./migrations
# valkey
VALKEY_HOST=127.0.0.1
VALKEY_PORT=6379
VALKEY_USERNAME=
VALKEY_PASSWORD=
VALKEY_IS_SSL=false

3
backend/.gitignore vendored
View File

@@ -17,4 +17,5 @@ ui/build/*
pgdata-for-restore/
temp/
cmd.exe
temp/
temp/
valkey-data/

View File

@@ -2,10 +2,10 @@ run:
go run cmd/main.go
test:
go test -p=1 -count=1 -failfast -timeout 10m .\internal\...
go test -p=1 -count=1 -failfast -timeout 15m ./internal/...
lint:
golangci-lint fmt && golangci-lint run
golangci-lint fmt ./cmd/... ./internal/... && golangci-lint run ./cmd/... ./internal/...
migration-create:
goose create $(name) sql

View File

@@ -15,6 +15,9 @@ import (
"databasus-backend/internal/config"
"databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/features/backups/backups/backuping"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_download "databasus-backend/internal/features/backups/backups/download"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/disk"
@@ -29,6 +32,7 @@ import (
users_middleware "databasus-backend/internal/features/users/middleware"
users_services "databasus-backend/internal/features/users/services"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
cache_utils "databasus-backend/internal/util/cache"
env_utils "databasus-backend/internal/util/env"
files_utils "databasus-backend/internal/util/files"
"databasus-backend/internal/util/logger"
@@ -52,7 +56,21 @@ import (
func main() {
log := logger.GetLogger()
runMigrations(log)
cache_utils.TestCacheConnection()
if config.GetEnv().IsPrimaryNode {
err := cache_utils.ClearAllCache()
if err != nil {
log.Error("Failed to clear cache", "error", err)
os.Exit(1)
}
}
if config.GetEnv().IsPrimaryNode {
runMigrations(log)
} else {
log.Info("Skipping migrations (IS_PRIMARY_NODE is false)")
}
// create directories that used for backups and restore
err := files_utils.EnsureDirectories([]string{
@@ -96,7 +114,9 @@ func main() {
enableCors(ginApp)
setUpRoutes(ginApp)
setUpDependencies()
runBackgroundTasks(log)
mountFrontend(ginApp)
startServerWithGracefulShutdown(log, ginApp)
@@ -183,6 +203,7 @@ func setUpRoutes(r *gin.Engine) {
userController := users_controllers.GetUserController()
userController.RegisterRoutes(v1)
system_healthcheck.GetHealthcheckController().RegisterRoutes(v1)
backups.GetBackupController().RegisterPublicRoutes(v1)
// Setup auth middleware
userService := users_services.GetUserService()
@@ -218,31 +239,64 @@ func setUpDependencies() {
notifiers.SetupDependencies()
storages.SetupDependencies()
backups_config.SetupDependencies()
backups_cancellation.SetupDependencies()
}
func runBackgroundTasks(log *slog.Logger) {
log.Info("Preparing to run background tasks...")
err := files_utils.CleanFolder(config.GetEnv().TempFolder)
if err != nil {
log.Error("Failed to clean temp folder", "error", err)
// Create context that will be cancelled on shutdown
ctx, cancel := context.WithCancel(context.Background())
// Set up signal handling for graceful shutdown
quit := make(chan os.Signal, 1)
signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
go func() {
<-quit
log.Info("Shutdown signal received, cancelling all background tasks")
cancel()
}()
if config.GetEnv().IsPrimaryNode {
log.Info("Starting primary node background tasks...")
err := files_utils.CleanFolder(config.GetEnv().TempFolder)
if err != nil {
log.Error("Failed to clean temp folder", "error", err)
}
go runWithPanicLogging(log, "backup background service", func() {
backuping.GetBackupsScheduler().Run(ctx)
})
go runWithPanicLogging(log, "restore background service", func() {
restores.GetRestoreBackgroundService().Run(ctx)
})
go runWithPanicLogging(log, "healthcheck attempt background service", func() {
healthcheck_attempt.GetHealthcheckAttemptBackgroundService().Run(ctx)
})
go runWithPanicLogging(log, "audit log cleanup background service", func() {
audit_logs.GetAuditLogBackgroundService().Run(ctx)
})
go runWithPanicLogging(log, "download token cleanup background service", func() {
backups_download.GetDownloadTokenBackgroundService().Run(ctx)
})
} else {
log.Info("Skipping primary node tasks as not primary node")
}
go runWithPanicLogging(log, "backup background service", func() {
backups.GetBackupBackgroundService().Run()
})
if config.GetEnv().IsBackupNode {
log.Info("Starting backup node background tasks...")
go runWithPanicLogging(log, "restore background service", func() {
restores.GetRestoreBackgroundService().Run()
})
go runWithPanicLogging(log, "healthcheck attempt background service", func() {
healthcheck_attempt.GetHealthcheckAttemptBackgroundService().Run()
})
go runWithPanicLogging(log, "audit log cleanup background service", func() {
audit_logs.GetAuditLogBackgroundService().Run()
})
go runWithPanicLogging(log, "backup node", func() {
backuping.GetBackuperNode().Run(ctx)
})
} else {
log.Info("Skipping backup node tasks as not backup node")
}
}
func runWithPanicLogging(log *slog.Logger, serviceName string, fn func()) {
@@ -285,16 +339,13 @@ func generateSwaggerDocs(log *slog.Logger) {
func runMigrations(log *slog.Logger) {
log.Info("Running database migrations...")
cmd := exec.Command("goose", "up")
cmd := exec.Command("goose", "-dir", "./migrations", "up")
cmd.Env = append(
os.Environ(),
"GOOSE_DRIVER=postgres",
"GOOSE_DBSTRING="+config.GetEnv().DatabaseDsn,
)
// Set the working directory to where migrations are located
cmd.Dir = "./migrations"
output, err := cmd.CombinedOutput()
if err != nil {
log.Error("Failed to run migrations", "error", err, "output", string(output))

View File

@@ -19,6 +19,21 @@ services:
command: -p 5437
shm_size: 10gb
# Valkey for caching
dev-valkey:
image: valkey/valkey:9.0.1-alpine
ports:
- "${VALKEY_PORT:-6379}:6379"
volumes:
- ./valkey-data:/data
container_name: dev-valkey
healthcheck:
test: ["CMD", "valkey-cli", "ping"]
interval: 10s
timeout: 5s
retries: 5
start_period: 20s
# Test MinIO container
test-minio:
image: minio/minio:latest

View File

@@ -1,6 +1,6 @@
module databasus-backend
go 1.24.4
go 1.24.9
require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0
@@ -25,6 +25,7 @@ require (
github.com/swaggo/files v1.0.1
github.com/swaggo/gin-swagger v1.6.0
github.com/swaggo/swag v1.16.4
github.com/valkey-io/valkey-go v1.0.70
go.mongodb.org/mongo-driver v1.17.6
golang.org/x/crypto v0.46.0
golang.org/x/time v0.14.0
@@ -269,7 +270,7 @@ require (
go.opentelemetry.io/otel/metric v1.38.0 // indirect
go.opentelemetry.io/otel/trace v1.38.0 // indirect
golang.org/x/arch v0.17.0 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/net v0.48.0 // indirect
golang.org/x/oauth2 v0.33.0
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.39.0 // indirect

View File

@@ -539,8 +539,8 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
github.com/onsi/ginkgo/v2 v2.17.3 h1:oJcvKpIb7/8uLpDDtnQuf18xVnwKp8DTD7DQ6gTd/MU=
github.com/onsi/ginkgo/v2 v2.17.3/go.mod h1:nP2DPOQoNsQmsVyv5rDA8JkXQoCs6goXIvr/PRJ1eCc=
github.com/onsi/gomega v1.37.0 h1:CdEG8g0S133B4OswTDC/5XPSzE1OeP29QOioj2PID2Y=
github.com/onsi/gomega v1.37.0/go.mod h1:8D9+Txp43QWKhM24yyOBEdpkzN8FvJyAwecBgsU4KU0=
github.com/onsi/gomega v1.38.3 h1:eTX+W6dobAYfFeGC2PV6RwXRu/MyT+cQguijutvkpSM=
github.com/onsi/gomega v1.38.3/go.mod h1:ZCU1pkQcXDO5Sl9/VVEGlDyp+zm0m1cmeG5TOzLgdh4=
github.com/oracle/oci-go-sdk/v65 v65.104.0 h1:l9awEvzWvxmYhy/97A0hZ87pa7BncYXmcO/S8+rvgK0=
github.com/oracle/oci-go-sdk/v65 v65.104.0/go.mod h1:oB8jFGVc/7/zJ+DbleE8MzGHjhs2ioCz5stRTdZdIcY=
github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg=
@@ -660,6 +660,8 @@ github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY=
github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
github.com/unknwon/goconfig v1.0.0 h1:rS7O+CmUdli1T+oDm7fYj1MwqNWtEJfNj+FqcUHML8U=
github.com/unknwon/goconfig v1.0.0/go.mod h1:qu2ZQ/wcC/if2u32263HTVC39PeOQRSmidQk3DuDFQ8=
github.com/valkey-io/valkey-go v1.0.70 h1:mjYNT8qiazxDAJ0QNQ8twWT/YFOkOoRd40ERV2mB49Y=
github.com/valkey-io/valkey-go v1.0.70/go.mod h1:VGhZ6fs68Qrn2+OhH+6waZH27bjpgQOiLyUQyXuYK5k=
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM=
@@ -720,6 +722,8 @@ go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/arch v0.17.0 h1:4O3dfLzd+lQewptAHqjewQZQDyEdejz3VwgeYwkZneU=
golang.org/x/arch v0.17.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
@@ -818,8 +822,8 @@ golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=

View File

@@ -9,6 +9,7 @@ import (
"strings"
"sync"
"github.com/google/uuid"
"github.com/ilyakaznacheev/cleanenv"
"github.com/joho/godotenv"
)
@@ -29,6 +30,12 @@ type EnvVariables struct {
MariadbInstallDir string `env:"MARIADB_INSTALL_DIR"`
MongodbInstallDir string `env:"MONGODB_INSTALL_DIR"`
NodeID string
IsManyNodesMode bool `env:"IS_MANY_NODES_MODE"`
IsPrimaryNode bool `env:"IS_PRIMARY_NODE"`
IsBackupNode bool `env:"IS_BACKUP_NODE"`
NodeNetworkThroughputMBs int `env:"NODE_NETWORK_THROUGHPUT_MBPS"`
DataFolder string
TempFolder string
SecretKeyPath string
@@ -79,6 +86,13 @@ type EnvVariables struct {
TestMongodb70Port string `env:"TEST_MONGODB_70_PORT"`
TestMongodb82Port string `env:"TEST_MONGODB_82_PORT"`
// Valkey
ValkeyHost string `env:"VALKEY_HOST" required:"true"`
ValkeyPort string `env:"VALKEY_PORT" required:"true"`
ValkeyUsername string `env:"VALKEY_USERNAME"`
ValkeyPassword string `env:"VALKEY_PASSWORD"`
ValkeyIsSsl bool `env:"VALKEY_IS_SSL" required:"true"`
// oauth
GitHubClientID string `env:"GITHUB_CLIENT_ID"`
GitHubClientSecret string `env:"GITHUB_CLIENT_SECRET"`
@@ -189,6 +203,26 @@ func loadEnvVariables() {
env.MongodbInstallDir = filepath.Join(backendRoot, "tools", "mongodb")
tools.VerifyMongodbInstallation(log, env.EnvMode, env.MongodbInstallDir)
env.NodeID = uuid.New().String()
if env.NodeNetworkThroughputMBs == 0 {
env.NodeNetworkThroughputMBs = 125 // 1 Gbit/s
}
if !env.IsManyNodesMode {
env.IsPrimaryNode = true
env.IsBackupNode = true
}
// Valkey
if env.ValkeyHost == "" {
log.Error("VALKEY_HOST is empty")
os.Exit(1)
}
if env.ValkeyPort == "" {
log.Error("VALKEY_PORT is empty")
os.Exit(1)
}
// Store the data and temp folders one level below the root
// (projectRoot/databasus-data -> /databasus-data)
env.DataFolder = filepath.Join(filepath.Dir(backendRoot), "databasus-data", "backups")

View File

@@ -1,7 +1,7 @@
package audit_logs
import (
"databasus-backend/internal/config"
"context"
"log/slog"
"time"
)
@@ -11,23 +11,25 @@ type AuditLogBackgroundService struct {
logger *slog.Logger
}
func (s *AuditLogBackgroundService) Run() {
func (s *AuditLogBackgroundService) Run(ctx context.Context) {
s.logger.Info("Starting audit log cleanup background service")
if config.IsShouldShutdown() {
if ctx.Err() != nil {
return
}
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for {
if config.IsShouldShutdown() {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.cleanOldAuditLogs(); err != nil {
s.logger.Error("Failed to clean old audit logs", "error", err)
}
}
if err := s.cleanOldAuditLogs(); err != nil {
s.logger.Error("Failed to clean old audit logs", "error", err)
}
time.Sleep(1 * time.Hour)
}
}

View File

@@ -1,254 +0,0 @@
package backups
import (
"databasus-backend/internal/config"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/storages"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/period"
"log/slog"
"time"
)
type BackupBackgroundService struct {
backupService *BackupService
backupRepository *BackupRepository
backupConfigService *backups_config.BackupConfigService
storageService *storages.StorageService
lastBackupTime time.Time
logger *slog.Logger
}
func (s *BackupBackgroundService) Run() {
s.lastBackupTime = time.Now().UTC()
if err := s.failBackupsInProgress(); err != nil {
s.logger.Error("Failed to fail backups in progress", "error", err)
panic(err)
}
if config.IsShouldShutdown() {
return
}
for {
if config.IsShouldShutdown() {
return
}
if err := s.cleanOldBackups(); err != nil {
s.logger.Error("Failed to clean old backups", "error", err)
}
if err := s.runPendingBackups(); err != nil {
s.logger.Error("Failed to run pending backups", "error", err)
}
s.lastBackupTime = time.Now().UTC()
time.Sleep(1 * time.Minute)
}
}
func (s *BackupBackgroundService) IsBackupsWorkerRunning() bool {
// if last backup time is more than 5 minutes ago, return false
return s.lastBackupTime.After(time.Now().UTC().Add(-5 * time.Minute))
}
func (s *BackupBackgroundService) failBackupsInProgress() error {
backupsInProgress, err := s.backupRepository.FindByStatus(BackupStatusInProgress)
if err != nil {
return err
}
for _, backup := range backupsInProgress {
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(backup.DatabaseID)
if err != nil {
s.logger.Error("Failed to get backup config by database ID", "error", err)
continue
}
failMessage := "Backup failed due to application restart"
backup.FailMessage = &failMessage
backup.Status = BackupStatusFailed
backup.BackupSizeMb = 0
s.backupService.SendBackupNotification(
backupConfig,
backup,
backups_config.NotificationBackupFailed,
&failMessage,
)
if err := s.backupRepository.Save(backup); err != nil {
return err
}
}
return nil
}
func (s *BackupBackgroundService) cleanOldBackups() error {
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
backupStorePeriod := backupConfig.StorePeriod
if backupStorePeriod == period.PeriodForever {
continue
}
storeDuration := backupStorePeriod.ToDuration()
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
oldBackups, err := s.backupRepository.FindBackupsBeforeDate(
backupConfig.DatabaseID,
dateBeforeBackupsShouldBeDeleted,
)
if err != nil {
s.logger.Error(
"Failed to find old backups for database",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
continue
}
for _, backup := range oldBackups {
storage, err := s.storageService.GetStorageByID(backup.StorageID)
if err != nil {
s.logger.Error(
"Failed to get storage by ID",
"storageId",
backup.StorageID,
"error",
err,
)
continue
}
encryptor := encryption.GetFieldEncryptor()
err = storage.DeleteFile(encryptor, backup.ID)
if err != nil {
s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
}
if err := s.backupRepository.DeleteByID(backup.ID); err != nil {
s.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
continue
}
s.logger.Info(
"Deleted old backup",
"backupId",
backup.ID,
"databaseId",
backupConfig.DatabaseID,
)
}
}
return nil
}
func (s *BackupBackgroundService) runPendingBackups() error {
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
if backupConfig.BackupInterval == nil {
continue
}
lastBackup, err := s.backupRepository.FindLastByDatabaseID(backupConfig.DatabaseID)
if err != nil {
s.logger.Error(
"Failed to get last backup for database",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
continue
}
var lastBackupTime *time.Time
if lastBackup != nil {
lastBackupTime = &lastBackup.CreatedAt
}
remainedBackupTryCount := s.GetRemainedBackupTryCount(lastBackup)
if backupConfig.BackupInterval.ShouldTriggerBackup(time.Now().UTC(), lastBackupTime) ||
remainedBackupTryCount > 0 {
s.logger.Info(
"Triggering scheduled backup",
"databaseId",
backupConfig.DatabaseID,
"intervalType",
backupConfig.BackupInterval.Interval,
)
go s.backupService.MakeBackup(backupConfig.DatabaseID, remainedBackupTryCount == 1)
s.logger.Info(
"Successfully triggered scheduled backup",
"databaseId",
backupConfig.DatabaseID,
)
}
}
return nil
}
// GetRemainedBackupTryCount returns the number of remaining backup tries for a given backup.
// If the backup is not failed or the backup config does not allow retries, it returns 0.
// If the backup is failed and the backup config allows retries, it returns the number of remaining tries.
// If the backup is failed and the backup config does not allow retries, it returns 0.
func (s *BackupBackgroundService) GetRemainedBackupTryCount(lastBackup *Backup) int {
if lastBackup == nil {
return 0
}
if lastBackup.Status != BackupStatusFailed {
return 0
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(lastBackup.DatabaseID)
if err != nil {
s.logger.Error("Failed to get backup config by database ID", "error", err)
return 0
}
if !backupConfig.IsRetryIfFailed {
return 0
}
maxFailedTriesCount := backupConfig.MaxFailedTriesCount
lastBackups, err := s.backupRepository.FindByDatabaseIDWithLimit(
lastBackup.DatabaseID,
maxFailedTriesCount,
)
if err != nil {
s.logger.Error("Failed to find last backups by database ID", "error", err)
return 0
}
lastFailedBackups := make([]*Backup, 0)
for _, backup := range lastBackups {
if backup.Status == BackupStatusFailed {
lastFailedBackups = append(lastFailedBackups, backup)
}
}
return maxFailedTriesCount - len(lastFailedBackups)
}

View File

@@ -1,389 +0,0 @@
package backups
import (
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/util/period"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
// setup data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups for the database
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// add old backup
backupRepository.Save(&Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusCompleted,
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
})
GetBackupBackgroundService().runPendingBackups()
// Wait for backup to complete (runs in goroutine)
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
// assertions
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 2)
}
func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
// setup data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups for the database
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// add recent backup (1 hour ago)
backupRepository.Save(&Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusCompleted,
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
})
GetBackupBackgroundService().runPendingBackups()
time.Sleep(100 * time.Millisecond)
// assertions
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1) // Should still be 1 backup, no new backup created
}
func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T) {
// setup data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups for the database with retries disabled
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
backupConfig.IsRetryIfFailed = false
backupConfig.MaxFailedTriesCount = 0
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// add failed backup
failMessage := "backup failed"
backupRepository.Save(&Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusFailed,
FailMessage: &failMessage,
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
})
GetBackupBackgroundService().runPendingBackups()
time.Sleep(100 * time.Millisecond)
// assertions
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1) // Should still be 1 backup, no retry attempted
}
func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
// setup data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups for the database with retries enabled
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
backupConfig.IsRetryIfFailed = true
backupConfig.MaxFailedTriesCount = 3
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// add failed backup
failMessage := "backup failed"
backupRepository.Save(&Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusFailed,
FailMessage: &failMessage,
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
})
GetBackupBackgroundService().runPendingBackups()
// Wait for backup to complete (runs in goroutine)
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
// assertions
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 2) // Should have 2 backups, retry was attempted
}
func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *testing.T) {
// setup data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups for the database with retries enabled
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
backupConfig.IsRetryIfFailed = true
backupConfig.MaxFailedTriesCount = 3
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
failMessage := "backup failed"
for i := 0; i < 3; i++ {
backupRepository.Save(&Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusFailed,
FailMessage: &failMessage,
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
})
}
GetBackupBackgroundService().runPendingBackups()
time.Sleep(100 * time.Millisecond)
// assertions
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 3) // Should have 3 backups, not more than max
}
func Test_MakeBackgroundBackupWhenBakupsDisabled_BackupSkipped(t *testing.T) {
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = false
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// add old backup that would trigger new backup if enabled
backupRepository.Save(&Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusCompleted,
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
})
GetBackupBackgroundService().runPendingBackups()
time.Sleep(100 * time.Millisecond)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
}

View File

@@ -1,60 +0,0 @@
package backups
import (
"context"
"sync"
"github.com/google/uuid"
)
type BackupContextManager struct {
mu sync.RWMutex
cancelFuncs map[uuid.UUID]context.CancelFunc
cancelledBackups map[uuid.UUID]bool
}
func NewBackupContextManager() *BackupContextManager {
return &BackupContextManager{
cancelFuncs: make(map[uuid.UUID]context.CancelFunc),
cancelledBackups: make(map[uuid.UUID]bool),
}
}
func (m *BackupContextManager) RegisterBackup(backupID uuid.UUID, cancelFunc context.CancelFunc) {
m.mu.Lock()
defer m.mu.Unlock()
m.cancelFuncs[backupID] = cancelFunc
delete(m.cancelledBackups, backupID)
}
func (m *BackupContextManager) CancelBackup(backupID uuid.UUID) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.cancelledBackups[backupID] {
return nil
}
cancelFunc, exists := m.cancelFuncs[backupID]
if exists {
cancelFunc()
delete(m.cancelFuncs, backupID)
}
m.cancelledBackups[backupID] = true
return nil
}
func (m *BackupContextManager) IsCancelled(backupID uuid.UUID) bool {
m.mu.RLock()
defer m.mu.RUnlock()
return m.cancelledBackups[backupID]
}
func (m *BackupContextManager) UnregisterBackup(backupID uuid.UUID) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.cancelFuncs, backupID)
delete(m.cancelledBackups, backupID)
}

View File

@@ -0,0 +1,344 @@
package backuping
import (
"context"
"databasus-backend/internal/config"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/storages"
workspaces_services "databasus-backend/internal/features/workspaces/services"
util_encryption "databasus-backend/internal/util/encryption"
"errors"
"fmt"
"log/slog"
"slices"
"strings"
"time"
"github.com/google/uuid"
)
type BackuperNode struct {
databaseService *databases.DatabaseService
fieldEncryptor util_encryption.FieldEncryptor
workspaceService *workspaces_services.WorkspaceService
backupRepository *backups_core.BackupRepository
backupConfigService *backups_config.BackupConfigService
storageService *storages.StorageService
notificationSender backups_core.NotificationSender
backupCancelManager *backups_cancellation.BackupCancelManager
nodesRegistry *BackupNodesRegistry
logger *slog.Logger
createBackupUseCase backups_core.CreateBackupUsecase
nodeID uuid.UUID
lastHeartbeat time.Time
}
func (n *BackuperNode) Run(ctx context.Context) {
n.lastHeartbeat = time.Now().UTC()
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
backupNode := BackupNode{
ID: n.nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: time.Now().UTC(),
}
if err := n.nodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
n.logger.Error("Failed to register node in registry", "error", err)
panic(err)
}
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
n.MakeBackup(backupID, isCallNotifier)
if err := n.nodesRegistry.PublishBackupCompletion(n.nodeID.String(), backupID); err != nil {
n.logger.Error(
"Failed to publish backup completion",
"error",
err,
"backupID",
backupID,
)
}
}
if err := n.nodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID.String(), backupHandler); err != nil {
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
panic(err)
}
defer func() {
if err := n.nodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
}
}()
ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
for {
select {
case <-ctx.Done():
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
if err := n.nodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
n.logger.Error("Failed to unregister node from registry", "error", err)
}
return
case <-ticker.C:
n.sendHeartbeat(&backupNode)
}
}
}
func (n *BackuperNode) IsBackuperRunning() bool {
return n.lastHeartbeat.After(time.Now().UTC().Add(-5 * time.Minute))
}
func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
backup, err := n.backupRepository.FindByID(backupID)
if err != nil {
n.logger.Error("Failed to get backup by ID", "backupId", backupID, "error", err)
return
}
databaseID := backup.DatabaseID
database, err := n.databaseService.GetDatabaseByID(databaseID)
if err != nil {
n.logger.Error("Failed to get database by ID", "databaseId", databaseID, "error", err)
return
}
backupConfig, err := n.backupConfigService.GetBackupConfigByDbId(databaseID)
if err != nil {
n.logger.Error("Failed to get backup config by database ID", "error", err)
return
}
if backupConfig.StorageID == nil {
n.logger.Error("Backup config storage ID is not defined")
return
}
storage, err := n.storageService.GetStorageByID(*backupConfig.StorageID)
if err != nil {
n.logger.Error("Failed to get storage by ID", "error", err)
return
}
start := time.Now().UTC()
backupProgressListener := func(
completedMBs float64,
) {
backup.BackupSizeMb = completedMBs
backup.BackupDurationMs = time.Since(start).Milliseconds()
if err := n.backupRepository.Save(backup); err != nil {
n.logger.Error("Failed to update backup progress", "error", err)
}
}
ctx, cancel := context.WithCancel(context.Background())
n.backupCancelManager.RegisterBackup(backup.ID, cancel)
defer n.backupCancelManager.UnregisterBackup(backup.ID)
backupMetadata, err := n.createBackupUseCase.Execute(
ctx,
backup.ID,
backupConfig,
database,
storage,
backupProgressListener,
)
if err != nil {
errMsg := err.Error()
// Check if backup was cancelled (not due to shutdown)
isCancelled := strings.Contains(errMsg, "backup cancelled") ||
strings.Contains(errMsg, "context canceled") ||
errors.Is(err, context.Canceled)
isShutdown := strings.Contains(errMsg, "shutdown")
if isCancelled && !isShutdown {
backup.Status = backups_core.BackupStatusCanceled
backup.BackupDurationMs = time.Since(start).Milliseconds()
backup.BackupSizeMb = 0
if err := n.backupRepository.Save(backup); err != nil {
n.logger.Error("Failed to save cancelled backup", "error", err)
}
// Delete partial backup from storage
storage, storageErr := n.storageService.GetStorageByID(backup.StorageID)
if storageErr == nil {
if deleteErr := storage.DeleteFile(n.fieldEncryptor, backup.ID); deleteErr != nil {
n.logger.Error(
"Failed to delete partial backup file",
"backupId",
backup.ID,
"error",
deleteErr,
)
}
}
return
}
backup.FailMessage = &errMsg
backup.Status = backups_core.BackupStatusFailed
backup.BackupDurationMs = time.Since(start).Milliseconds()
backup.BackupSizeMb = 0
if updateErr := n.databaseService.SetBackupError(databaseID, errMsg); updateErr != nil {
n.logger.Error(
"Failed to update database last backup time",
"databaseId",
databaseID,
"error",
updateErr,
)
}
if err := n.backupRepository.Save(backup); err != nil {
n.logger.Error("Failed to save backup", "error", err)
}
n.SendBackupNotification(
backupConfig,
backup,
backups_config.NotificationBackupFailed,
&errMsg,
)
return
}
backup.Status = backups_core.BackupStatusCompleted
backup.BackupDurationMs = time.Since(start).Milliseconds()
// Update backup with encryption metadata if provided
if backupMetadata != nil {
backup.EncryptionSalt = backupMetadata.EncryptionSalt
backup.EncryptionIV = backupMetadata.EncryptionIV
backup.Encryption = backupMetadata.Encryption
}
if err := n.backupRepository.Save(backup); err != nil {
n.logger.Error("Failed to save backup", "error", err)
return
}
// Update database last backup time
now := time.Now().UTC()
if updateErr := n.databaseService.SetLastBackupTime(databaseID, now); updateErr != nil {
n.logger.Error(
"Failed to update database last backup time",
"databaseId",
databaseID,
"error",
updateErr,
)
}
if backup.Status != backups_core.BackupStatusCompleted && !isCallNotifier {
return
}
n.SendBackupNotification(
backupConfig,
backup,
backups_config.NotificationBackupSuccess,
nil,
)
}
func (n *BackuperNode) SendBackupNotification(
backupConfig *backups_config.BackupConfig,
backup *backups_core.Backup,
notificationType backups_config.BackupNotificationType,
errorMessage *string,
) {
database, err := n.databaseService.GetDatabaseByID(backupConfig.DatabaseID)
if err != nil {
return
}
workspace, err := n.workspaceService.GetWorkspaceByID(*database.WorkspaceID)
if err != nil {
return
}
for _, notifier := range database.Notifiers {
if !slices.Contains(
backupConfig.SendNotificationsOn,
notificationType,
) {
continue
}
title := ""
switch notificationType {
case backups_config.NotificationBackupFailed:
title = fmt.Sprintf(
"❌ Backup failed for database \"%s\" (workspace \"%s\")",
database.Name,
workspace.Name,
)
case backups_config.NotificationBackupSuccess:
title = fmt.Sprintf(
"✅ Backup completed for database \"%s\" (workspace \"%s\")",
database.Name,
workspace.Name,
)
}
message := ""
if errorMessage != nil {
message = *errorMessage
} else {
// Format size conditionally
var sizeStr string
if backup.BackupSizeMb < 1024 {
sizeStr = fmt.Sprintf("%.2f MB", backup.BackupSizeMb)
} else {
sizeGB := backup.BackupSizeMb / 1024
sizeStr = fmt.Sprintf("%.2f GB", sizeGB)
}
// Format duration as "0m 0s 0ms"
totalMs := backup.BackupDurationMs
minutes := totalMs / (1000 * 60)
seconds := (totalMs % (1000 * 60)) / 1000
durationStr := fmt.Sprintf("%dm %ds", minutes, seconds)
message = fmt.Sprintf(
"Backup completed successfully in %s.\nCompressed backup size: %s",
durationStr,
sizeStr,
)
}
n.notificationSender.SendNotification(
&notifier,
title,
message,
)
}
}
func (n *BackuperNode) sendHeartbeat(backupNode *BackupNode) {
n.lastHeartbeat = time.Now().UTC()
backupNode.LastHeartbeat = time.Now().UTC()
if err := n.nodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *backupNode); err != nil {
n.logger.Error("Failed to send heartbeat", "error", err)
}
}

View File

@@ -1,4 +1,4 @@
package backups
package backuping
import (
"context"
@@ -8,17 +8,15 @@ import (
"time"
common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_services "databasus-backend/internal/features/workspaces/services"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
cache_utils "databasus-backend/internal/util/cache"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
@@ -26,6 +24,7 @@ import (
)
func Test_BackupExecuted_NotificationSent(t *testing.T) {
cache_utils.ClearAllCache()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
@@ -50,22 +49,19 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
t.Run("BackupFailed_FailNotificationSent", func(t *testing.T) {
mockNotificationSender := &MockNotificationSender{}
backupService := &BackupService{
databases.GetDatabaseService(),
storages.GetStorageService(),
backupRepository,
notifiers.GetNotifierService(),
mockNotificationSender,
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
&CreateFailedBackupUsecase{},
logger.GetLogger(),
[]BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
nil,
NewBackupContextManager(),
backuperNode := CreateTestBackuperNode()
backuperNode.notificationSender = mockNotificationSender
backuperNode.createBackupUseCase = &CreateFailedBackupUsecase{}
// Create a backup record directly that will be looked up by MakeBackup
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err := backupRepository.Save(backup)
assert.NoError(t, err)
// Set up expectations
mockNotificationSender.On("SendNotification",
@@ -78,7 +74,7 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
}),
).Once()
backupService.MakeBackup(database.ID, true)
backuperNode.MakeBackup(backup.ID, true)
// Verify all expectations were met
mockNotificationSender.AssertExpectations(t)
@@ -86,6 +82,19 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
t.Run("BackupSuccess_SuccessNotificationSent", func(t *testing.T) {
mockNotificationSender := &MockNotificationSender{}
backuperNode := CreateTestBackuperNode()
backuperNode.notificationSender = mockNotificationSender
backuperNode.createBackupUseCase = &CreateSuccessBackupUsecase{}
// Create a backup record directly that will be looked up by MakeBackup
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err := backupRepository.Save(backup)
assert.NoError(t, err)
// Set up expectations
mockNotificationSender.On("SendNotification",
@@ -98,24 +107,7 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
}),
).Once()
backupService := &BackupService{
databases.GetDatabaseService(),
storages.GetStorageService(),
backupRepository,
notifiers.GetNotifierService(),
mockNotificationSender,
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
&CreateSuccessBackupUsecase{},
logger.GetLogger(),
[]BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
nil,
NewBackupContextManager(),
}
backupService.MakeBackup(database.ID, true)
backuperNode.MakeBackup(backup.ID, true)
// Verify all expectations were met
mockNotificationSender.AssertExpectations(t)
@@ -123,22 +115,19 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
t.Run("BackupSuccess_VerifyNotificationContent", func(t *testing.T) {
mockNotificationSender := &MockNotificationSender{}
backupService := &BackupService{
databases.GetDatabaseService(),
storages.GetStorageService(),
backupRepository,
notifiers.GetNotifierService(),
mockNotificationSender,
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
&CreateSuccessBackupUsecase{},
logger.GetLogger(),
[]BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
nil,
NewBackupContextManager(),
backuperNode := CreateTestBackuperNode()
backuperNode.notificationSender = mockNotificationSender
backuperNode.createBackupUseCase = &CreateSuccessBackupUsecase{}
// Create a backup record directly that will be looked up by MakeBackup
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err := backupRepository.Save(backup)
assert.NoError(t, err)
// capture arguments
var capturedNotifier *notifiers.Notifier
@@ -155,7 +144,7 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
capturedMessage = args.Get(2).(string)
}).Once()
backupService.MakeBackup(database.ID, true)
backuperNode.MakeBackup(backup.ID, true)
// Verify expectations were met
mockNotificationSender.AssertExpectations(t)

View File

@@ -0,0 +1,77 @@
package backuping
import (
"databasus-backend/internal/config"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/usecases"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
workspaces_services "databasus-backend/internal/features/workspaces/services"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
"time"
"github.com/google/uuid"
)
var backupRepository = &backups_core.BackupRepository{}
var backupCancelManager = backups_cancellation.GetBackupCancelManager()
var nodesRegistry = &BackupNodesRegistry{
cache_utils.GetValkeyClient(),
logger.GetLogger(),
cache_utils.DefaultCacheTimeout,
cache_utils.NewPubSubManager(),
cache_utils.NewPubSubManager(),
}
func getNodeID() uuid.UUID {
nodeIDStr := config.GetEnv().NodeID
nodeID, err := uuid.Parse(nodeIDStr)
if err != nil {
logger.GetLogger().Error("Failed to parse node ID from config", "error", err)
panic(err)
}
return nodeID
}
var backuperNode = &BackuperNode{
databases.GetDatabaseService(),
encryption.GetFieldEncryptor(),
workspaces_services.GetWorkspaceService(),
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
notifiers.GetNotifierService(),
backupCancelManager,
nodesRegistry,
logger.GetLogger(),
usecases.GetCreateBackupUsecase(),
getNodeID(),
time.Time{},
}
var backupsScheduler = &BackupsScheduler{
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
backupCancelManager,
nodesRegistry,
time.Now().UTC(),
logger.GetLogger(),
make(map[uuid.UUID]BackupToNodeRelation),
backuperNode,
}
func GetBackupsScheduler() *BackupsScheduler {
return backupsScheduler
}
func GetBackuperNode() *BackuperNode {
return backuperNode
}

View File

@@ -0,0 +1,34 @@
package backuping
import (
"time"
"github.com/google/uuid"
)
type BackupNode struct {
ID uuid.UUID `json:"id"`
ThroughputMBs int `json:"throughputMBs"`
LastHeartbeat time.Time `json:"lastHeartbeat"`
}
type BackupNodeStats struct {
ID uuid.UUID `json:"id"`
ActiveBackups int `json:"activeBackups"`
}
type BackupSubmitMessage struct {
NodeID string `json:"nodeId"`
BackupID string `json:"backupId"`
IsCallNotifier bool `json:"isCallNotifier"`
}
type BackupCompletionMessage struct {
NodeID string `json:"nodeId"`
BackupID string `json:"backupId"`
}
type BackupToNodeRelation struct {
NodeID uuid.UUID `json:"nodeId"`
BackupsIDs []uuid.UUID `json:"backupsIds"`
}

View File

@@ -1,4 +1,4 @@
package backups
package backuping
import (
"databasus-backend/internal/features/notifiers"

View File

@@ -0,0 +1,448 @@
package backuping
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"time"
cache_utils "databasus-backend/internal/util/cache"
"github.com/google/uuid"
"github.com/valkey-io/valkey-go"
)
const (
nodeInfoKeyPrefix = "node:"
nodeInfoKeySuffix = ":info"
nodeActiveBackupsPrefix = "node:"
nodeActiveBackupsSuffix = ":active_backups"
backupSubmitChannel = "backup:submit"
backupCompletionChannel = "backup:completion"
)
// BackupNodesRegistry helps to sync backups scheduler and backup nodes.
// Features:
// - Track node availability and load level
// - Assign from scheduler to node backups needed to be processed
// - Notify scheduler from node about backup completion
type BackupNodesRegistry struct {
client valkey.Client
logger *slog.Logger
timeout time.Duration
pubsubBackups *cache_utils.PubSubManager
pubsubCompletions *cache_utils.PubSubManager
}
func (r *BackupNodesRegistry) GetAvailableNodes() ([]BackupNode, error) {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
var allKeys []string
cursor := uint64(0)
pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix
for {
result := r.client.Do(
ctx,
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
)
if result.Error() != nil {
return nil, fmt.Errorf("failed to scan node keys: %w", result.Error())
}
scanResult, err := result.AsScanEntry()
if err != nil {
return nil, fmt.Errorf("failed to parse scan result: %w", err)
}
allKeys = append(allKeys, scanResult.Elements...)
cursor = scanResult.Cursor
if cursor == 0 {
break
}
}
if len(allKeys) == 0 {
return []BackupNode{}, nil
}
keyDataMap, err := r.pipelineGetKeys(allKeys)
if err != nil {
return nil, fmt.Errorf("failed to pipeline get node keys: %w", err)
}
var nodes []BackupNode
for key, data := range keyDataMap {
var node BackupNode
if err := json.Unmarshal(data, &node); err != nil {
r.logger.Warn("Failed to unmarshal node data", "key", key, "error", err)
continue
}
nodes = append(nodes, node)
}
return nodes, nil
}
func (r *BackupNodesRegistry) GetBackupNodesStats() ([]BackupNodeStats, error) {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
var allKeys []string
cursor := uint64(0)
pattern := nodeActiveBackupsPrefix + "*" + nodeActiveBackupsSuffix
for {
result := r.client.Do(
ctx,
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(100).Build(),
)
if result.Error() != nil {
return nil, fmt.Errorf("failed to scan active backups keys: %w", result.Error())
}
scanResult, err := result.AsScanEntry()
if err != nil {
return nil, fmt.Errorf("failed to parse scan result: %w", err)
}
allKeys = append(allKeys, scanResult.Elements...)
cursor = scanResult.Cursor
if cursor == 0 {
break
}
}
if len(allKeys) == 0 {
return []BackupNodeStats{}, nil
}
keyDataMap, err := r.pipelineGetKeys(allKeys)
if err != nil {
return nil, fmt.Errorf("failed to pipeline get active backups keys: %w", err)
}
var stats []BackupNodeStats
for key, data := range keyDataMap {
nodeID := r.extractNodeIDFromKey(key, nodeActiveBackupsPrefix, nodeActiveBackupsSuffix)
count, err := r.parseIntFromBytes(data)
if err != nil {
r.logger.Warn("Failed to parse active backups count", "key", key, "error", err)
continue
}
stat := BackupNodeStats{
ID: nodeID,
ActiveBackups: int(count),
}
stats = append(stats, stat)
}
return stats, nil
}
func (r *BackupNodesRegistry) IncrementBackupsInProgress(nodeID string) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID, nodeActiveBackupsSuffix)
result := r.client.Do(ctx, r.client.B().Incr().Key(key).Build())
if result.Error() != nil {
return fmt.Errorf(
"failed to increment backups in progress for node %s: %w",
nodeID,
result.Error(),
)
}
return nil
}
func (r *BackupNodesRegistry) DecrementBackupsInProgress(nodeID string) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID, nodeActiveBackupsSuffix)
result := r.client.Do(ctx, r.client.B().Decr().Key(key).Build())
if result.Error() != nil {
return fmt.Errorf(
"failed to decrement backups in progress for node %s: %w",
nodeID,
result.Error(),
)
}
newValue, err := result.AsInt64()
if err != nil {
return fmt.Errorf("failed to parse decremented value for node %s: %w", nodeID, err)
}
if newValue < 0 {
setCtx, setCancel := context.WithTimeout(context.Background(), r.timeout)
r.client.Do(setCtx, r.client.B().Set().Key(key).Value("0").Build())
setCancel()
r.logger.Warn("Active backups counter went below 0, reset to 0", "nodeID", nodeID)
}
return nil
}
func (r *BackupNodesRegistry) HearthbeatNodeInRegistry(now time.Time, backupNode BackupNode) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
backupNode.LastHeartbeat = now
data, err := json.Marshal(backupNode)
if err != nil {
return fmt.Errorf("failed to marshal backup node: %w", err)
}
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, backupNode.ID.String(), nodeInfoKeySuffix)
result := r.client.Do(
ctx,
r.client.B().Set().Key(key).Value(string(data)).Build(),
)
if result.Error() != nil {
return fmt.Errorf("failed to register node %s: %w", backupNode.ID, result.Error())
}
return nil
}
func (r *BackupNodesRegistry) UnregisterNodeFromRegistry(backupNode BackupNode) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, backupNode.ID.String(), nodeInfoKeySuffix)
counterKey := fmt.Sprintf(
"%s%s%s",
nodeActiveBackupsPrefix,
backupNode.ID.String(),
nodeActiveBackupsSuffix,
)
result := r.client.Do(
ctx,
r.client.B().Del().Key(infoKey, counterKey).Build(),
)
if result.Error() != nil {
return fmt.Errorf("failed to unregister node %s: %w", backupNode.ID, result.Error())
}
r.logger.Info("Unregistered node from registry", "nodeID", backupNode.ID)
return nil
}
func (r *BackupNodesRegistry) AssignBackupToNode(
targetNodeID string,
backupID uuid.UUID,
isCallNotifier bool,
) error {
ctx := context.Background()
message := BackupSubmitMessage{
NodeID: targetNodeID,
BackupID: backupID.String(),
IsCallNotifier: isCallNotifier,
}
messageJSON, err := json.Marshal(message)
if err != nil {
return fmt.Errorf("failed to marshal backup submit message: %w", err)
}
err = r.pubsubBackups.Publish(ctx, backupSubmitChannel, string(messageJSON))
if err != nil {
return fmt.Errorf("failed to publish backup submit message: %w", err)
}
return nil
}
func (r *BackupNodesRegistry) SubscribeNodeForBackupsAssignment(
nodeID string,
handler func(backupID uuid.UUID, isCallNotifier bool),
) error {
ctx := context.Background()
wrappedHandler := func(message string) {
var msg BackupSubmitMessage
if err := json.Unmarshal([]byte(message), &msg); err != nil {
r.logger.Warn("Failed to unmarshal backup submit message", "error", err)
return
}
if msg.NodeID != nodeID {
return
}
backupID, err := uuid.Parse(msg.BackupID)
if err != nil {
r.logger.Warn(
"Failed to parse backup ID from message",
"backupId",
msg.BackupID,
"error",
err,
)
return
}
handler(backupID, msg.IsCallNotifier)
}
err := r.pubsubBackups.Subscribe(ctx, backupSubmitChannel, wrappedHandler)
if err != nil {
return fmt.Errorf("failed to subscribe to backup submit channel: %w", err)
}
r.logger.Info("Subscribed to backup submit channel", "nodeID", nodeID)
return nil
}
func (r *BackupNodesRegistry) UnsubscribeNodeForBackupsAssignments() error {
err := r.pubsubBackups.Close()
if err != nil {
return fmt.Errorf("failed to unsubscribe from backup submit channel: %w", err)
}
r.logger.Info("Unsubscribed from backup submit channel")
return nil
}
func (r *BackupNodesRegistry) PublishBackupCompletion(nodeID string, backupID uuid.UUID) error {
ctx := context.Background()
message := BackupCompletionMessage{
NodeID: nodeID,
BackupID: backupID.String(),
}
messageJSON, err := json.Marshal(message)
if err != nil {
return fmt.Errorf("failed to marshal backup completion message: %w", err)
}
err = r.pubsubCompletions.Publish(ctx, backupCompletionChannel, string(messageJSON))
if err != nil {
return fmt.Errorf("failed to publish backup completion message: %w", err)
}
return nil
}
func (r *BackupNodesRegistry) SubscribeForBackupsCompletions(
handler func(nodeID string, backupID uuid.UUID),
) error {
ctx := context.Background()
wrappedHandler := func(message string) {
var msg BackupCompletionMessage
if err := json.Unmarshal([]byte(message), &msg); err != nil {
r.logger.Warn("Failed to unmarshal backup completion message", "error", err)
return
}
backupID, err := uuid.Parse(msg.BackupID)
if err != nil {
r.logger.Warn(
"Failed to parse backup ID from completion message",
"backupId",
msg.BackupID,
"error",
err,
)
return
}
handler(msg.NodeID, backupID)
}
err := r.pubsubCompletions.Subscribe(ctx, backupCompletionChannel, wrappedHandler)
if err != nil {
return fmt.Errorf("failed to subscribe to backup completion channel: %w", err)
}
r.logger.Info("Subscribed to backup completion channel")
return nil
}
func (r *BackupNodesRegistry) UnsubscribeForBackupsCompletions() error {
err := r.pubsubCompletions.Close()
if err != nil {
return fmt.Errorf("failed to unsubscribe from backup completion channel: %w", err)
}
r.logger.Info("Unsubscribed from backup completion channel")
return nil
}
func (r *BackupNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uuid.UUID {
nodeIDStr := strings.TrimPrefix(key, prefix)
nodeIDStr = strings.TrimSuffix(nodeIDStr, suffix)
nodeID, err := uuid.Parse(nodeIDStr)
if err != nil {
r.logger.Warn("Failed to parse node ID from key", "key", key, "error", err)
return uuid.Nil
}
return nodeID
}
func (r *BackupNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, error) {
if len(keys) == 0 {
return make(map[string][]byte), nil
}
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
commands := make([]valkey.Completed, 0, len(keys))
for _, key := range keys {
commands = append(commands, r.client.B().Get().Key(key).Build())
}
results := r.client.DoMulti(ctx, commands...)
keyDataMap := make(map[string][]byte, len(keys))
for i, result := range results {
if result.Error() != nil {
r.logger.Warn("Failed to get key in pipeline", "key", keys[i], "error", result.Error())
continue
}
data, err := result.AsBytes()
if err != nil {
r.logger.Warn("Failed to parse key data in pipeline", "key", keys[i], "error", err)
continue
}
keyDataMap[keys[i]] = data
}
return keyDataMap, nil
}
func (r *BackupNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
str := string(data)
var count int64
_, err := fmt.Sscanf(str, "%d", &count)
if err != nil {
return 0, fmt.Errorf("failed to parse integer from bytes: %w", err)
}
return count, nil
}

View File

@@ -0,0 +1,904 @@
package backuping
import (
"context"
"testing"
"time"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/logger"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_HearthbeatNodeInRegistry_RegistersNodeWithTTL(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
defer cleanupTestNode(registry, node)
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
assert.NoError(t, err)
nodes, err := registry.GetAvailableNodes()
assert.NoError(t, err)
assert.Len(t, nodes, 1)
assert.Equal(t, node.ID, nodes[0].ID)
assert.Equal(t, node.ThroughputMBs, nodes[0].ThroughputMBs)
}
func Test_UnregisterNodeFromRegistry_RemovesNodeAndCounter(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node.ID.String())
assert.NoError(t, err)
err = registry.UnregisterNodeFromRegistry(node)
assert.NoError(t, err)
nodes, err := registry.GetAvailableNodes()
assert.NoError(t, err)
assert.Empty(t, nodes)
stats, err := registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.Empty(t, stats)
}
func Test_GetAvailableNodes_ReturnsAllRegisteredNodes(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node1 := createTestBackupNode()
node2 := createTestBackupNode()
node3 := createTestBackupNode()
defer cleanupTestNode(registry, node1)
defer cleanupTestNode(registry, node2)
defer cleanupTestNode(registry, node3)
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
assert.NoError(t, err)
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
assert.NoError(t, err)
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
assert.NoError(t, err)
nodes, err := registry.GetAvailableNodes()
assert.NoError(t, err)
assert.Len(t, nodes, 3)
nodeIDs := make(map[uuid.UUID]bool)
for _, node := range nodes {
nodeIDs[node.ID] = true
}
assert.True(t, nodeIDs[node1.ID])
assert.True(t, nodeIDs[node2.ID])
assert.True(t, nodeIDs[node3.ID])
}
func Test_GetAvailableNodes_WhenNoNodesExist_ReturnsEmptySlice(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
nodes, err := registry.GetAvailableNodes()
assert.NoError(t, err)
assert.NotNil(t, nodes)
assert.Empty(t, nodes)
}
func Test_IncrementBackupsInProgress_IncrementsCounter(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
defer cleanupTestNode(registry, node)
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node.ID.String())
assert.NoError(t, err)
stats, err := registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.Len(t, stats, 1)
assert.Equal(t, node.ID, stats[0].ID)
assert.Equal(t, 1, stats[0].ActiveBackups)
err = registry.IncrementBackupsInProgress(node.ID.String())
assert.NoError(t, err)
stats, err = registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.Len(t, stats, 1)
assert.Equal(t, 2, stats[0].ActiveBackups)
}
func Test_DecrementBackupsInProgress_DecrementsCounter(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
defer cleanupTestNode(registry, node)
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node.ID.String())
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node.ID.String())
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node.ID.String())
assert.NoError(t, err)
stats, err := registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.Equal(t, 3, stats[0].ActiveBackups)
err = registry.DecrementBackupsInProgress(node.ID.String())
assert.NoError(t, err)
stats, err = registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.Equal(t, 2, stats[0].ActiveBackups)
err = registry.DecrementBackupsInProgress(node.ID.String())
assert.NoError(t, err)
stats, err = registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.Equal(t, 1, stats[0].ActiveBackups)
}
func Test_DecrementBackupsInProgress_WhenNegative_ResetsToZero(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
defer cleanupTestNode(registry, node)
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
assert.NoError(t, err)
err = registry.DecrementBackupsInProgress(node.ID.String())
assert.NoError(t, err)
stats, err := registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.Len(t, stats, 1)
assert.Equal(t, 0, stats[0].ActiveBackups)
}
func Test_GetBackupNodesStats_ReturnsStatsForAllNodes(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node1 := createTestBackupNode()
node2 := createTestBackupNode()
node3 := createTestBackupNode()
defer cleanupTestNode(registry, node1)
defer cleanupTestNode(registry, node2)
defer cleanupTestNode(registry, node3)
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
assert.NoError(t, err)
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
assert.NoError(t, err)
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node1.ID.String())
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node2.ID.String())
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node2.ID.String())
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node3.ID.String())
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node3.ID.String())
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node3.ID.String())
assert.NoError(t, err)
stats, err := registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.Len(t, stats, 3)
statsMap := make(map[uuid.UUID]int)
for _, stat := range stats {
statsMap[stat.ID] = stat.ActiveBackups
}
assert.Equal(t, 1, statsMap[node1.ID])
assert.Equal(t, 2, statsMap[node2.ID])
assert.Equal(t, 3, statsMap[node3.ID])
}
func Test_GetBackupNodesStats_WhenNoStats_ReturnsEmptySlice(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
stats, err := registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.NotNil(t, stats)
assert.Empty(t, stats)
}
func Test_MultipleNodes_RegisteredAndQueriedCorrectly(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node1 := createTestBackupNode()
node1.ThroughputMBs = 50
node2 := createTestBackupNode()
node2.ThroughputMBs = 100
node3 := createTestBackupNode()
node3.ThroughputMBs = 150
defer cleanupTestNode(registry, node1)
defer cleanupTestNode(registry, node2)
defer cleanupTestNode(registry, node3)
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
assert.NoError(t, err)
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
assert.NoError(t, err)
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
assert.NoError(t, err)
nodes, err := registry.GetAvailableNodes()
assert.NoError(t, err)
assert.Len(t, nodes, 3)
nodeMap := make(map[uuid.UUID]BackupNode)
for _, node := range nodes {
nodeMap[node.ID] = node
}
assert.Equal(t, 50, nodeMap[node1.ID].ThroughputMBs)
assert.Equal(t, 100, nodeMap[node2.ID].ThroughputMBs)
assert.Equal(t, 150, nodeMap[node3.ID].ThroughputMBs)
}
func Test_BackupCounters_TrackedSeparatelyPerNode(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node1 := createTestBackupNode()
node2 := createTestBackupNode()
defer cleanupTestNode(registry, node1)
defer cleanupTestNode(registry, node2)
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
assert.NoError(t, err)
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node1.ID.String())
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node1.ID.String())
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node2.ID.String())
assert.NoError(t, err)
stats, err := registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.Len(t, stats, 2)
statsMap := make(map[uuid.UUID]int)
for _, stat := range stats {
statsMap[stat.ID] = stat.ActiveBackups
}
assert.Equal(t, 2, statsMap[node1.ID])
assert.Equal(t, 1, statsMap[node2.ID])
err = registry.DecrementBackupsInProgress(node1.ID.String())
assert.NoError(t, err)
stats, err = registry.GetBackupNodesStats()
assert.NoError(t, err)
statsMap = make(map[uuid.UUID]int)
for _, stat := range stats {
statsMap[stat.ID] = stat.ActiveBackups
}
assert.Equal(t, 1, statsMap[node1.ID])
assert.Equal(t, 1, statsMap[node2.ID])
}
func Test_GetAvailableNodes_SkipsInvalidJsonData(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
defer cleanupTestNode(registry, node)
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
assert.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
defer cancel()
invalidKey := nodeInfoKeyPrefix + uuid.New().String() + nodeInfoKeySuffix
registry.client.Do(
ctx,
registry.client.B().Set().Key(invalidKey).Value("invalid json data").Build(),
)
defer func() {
cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), registry.timeout)
defer cleanupCancel()
registry.client.Do(cleanupCtx, registry.client.B().Del().Key(invalidKey).Build())
}()
nodes, err := registry.GetAvailableNodes()
assert.NoError(t, err)
assert.Len(t, nodes, 1)
assert.Equal(t, node.ID, nodes[0].ID)
}
func Test_PipelineGetKeys_HandlesEmptyKeysList(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
keyDataMap, err := registry.pipelineGetKeys([]string{})
assert.NoError(t, err)
assert.NotNil(t, keyDataMap)
assert.Empty(t, keyDataMap)
}
func Test_HearthbeatNodeInRegistry_UpdatesLastHeartbeat(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
originalHeartbeat := node.LastHeartbeat
defer cleanupTestNode(registry, node)
time.Sleep(10 * time.Millisecond)
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
assert.NoError(t, err)
nodes, err := registry.GetAvailableNodes()
assert.NoError(t, err)
assert.Len(t, nodes, 1)
assert.True(t, nodes[0].LastHeartbeat.After(originalHeartbeat))
}
func createTestRegistry() *BackupNodesRegistry {
return &BackupNodesRegistry{
cache_utils.GetValkeyClient(),
logger.GetLogger(),
cache_utils.DefaultCacheTimeout,
cache_utils.NewPubSubManager(),
cache_utils.NewPubSubManager(),
}
}
func createTestBackupNode() BackupNode {
return BackupNode{
ID: uuid.New(),
ThroughputMBs: 100,
LastHeartbeat: time.Now().UTC(),
}
}
func cleanupTestNode(registry *BackupNodesRegistry, node BackupNode) {
registry.UnregisterNodeFromRegistry(node)
}
func Test_AssignBackupTonode_PublishesJsonMessageToChannel(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
backupID := uuid.New()
err := registry.AssignBackupToNode(node.ID.String(), backupID, true)
assert.NoError(t, err)
}
func Test_SubscribeNodeForBackupsAssignment_ReceivesSubmittedBackupsForMatchingNode(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
backupID := uuid.New()
defer registry.UnsubscribeNodeForBackupsAssignments()
receivedBackupID := make(chan uuid.UUID, 1)
handler := func(id uuid.UUID, isCallNotifier bool) {
receivedBackupID <- id
}
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.AssignBackupToNode(node.ID.String(), backupID, true)
assert.NoError(t, err)
select {
case received := <-receivedBackupID:
assert.Equal(t, backupID, received)
case <-time.After(2 * time.Second):
t.Fatal("Timeout waiting for backup message")
}
}
func Test_SubscribeNodeForBackupsAssignment_FiltersOutBackupsForDifferentNode(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node1 := createTestBackupNode()
node2 := createTestBackupNode()
backupID := uuid.New()
defer registry.UnsubscribeNodeForBackupsAssignments()
receivedBackupID := make(chan uuid.UUID, 1)
handler := func(id uuid.UUID, isCallNotifier bool) {
receivedBackupID <- id
}
err := registry.SubscribeNodeForBackupsAssignment(node1.ID.String(), handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.AssignBackupToNode(node2.ID.String(), backupID, false)
assert.NoError(t, err)
select {
case <-receivedBackupID:
t.Fatal("Should not receive backup for different node")
case <-time.After(500 * time.Millisecond):
}
}
func Test_SubscribeNodeForBackupsAssignment_ParsesJsonAndBackupIdCorrectly(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
backupID1 := uuid.New()
backupID2 := uuid.New()
defer registry.UnsubscribeNodeForBackupsAssignments()
receivedBackups := make(chan uuid.UUID, 2)
handler := func(id uuid.UUID, isCallNotifier bool) {
receivedBackups <- id
}
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.AssignBackupToNode(node.ID.String(), backupID1, true)
assert.NoError(t, err)
err = registry.AssignBackupToNode(node.ID.String(), backupID2, false)
assert.NoError(t, err)
received1 := <-receivedBackups
received2 := <-receivedBackups
receivedIDs := []uuid.UUID{received1, received2}
assert.Contains(t, receivedIDs, backupID1)
assert.Contains(t, receivedIDs, backupID2)
}
func Test_SubscribeNodeForBackupsAssignment_HandlesInvalidJson(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
defer registry.UnsubscribeNodeForBackupsAssignments()
receivedBackupID := make(chan uuid.UUID, 1)
handler := func(id uuid.UUID, isCallNotifier bool) {
receivedBackupID <- id
}
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
ctx := context.Background()
err = registry.pubsubBackups.Publish(ctx, "backup:submit", "invalid json")
assert.NoError(t, err)
select {
case <-receivedBackupID:
t.Fatal("Should not receive backup for invalid JSON")
case <-time.After(500 * time.Millisecond):
}
}
func Test_UnsubscribeNodeForBackupsAssignments_StopsReceivingMessages(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
backupID1 := uuid.New()
backupID2 := uuid.New()
receivedBackupID := make(chan uuid.UUID, 2)
handler := func(id uuid.UUID, isCallNotifier bool) {
receivedBackupID <- id
}
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.AssignBackupToNode(node.ID.String(), backupID1, true)
assert.NoError(t, err)
received := <-receivedBackupID
assert.Equal(t, backupID1, received)
err = registry.UnsubscribeNodeForBackupsAssignments()
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.AssignBackupToNode(node.ID.String(), backupID2, false)
assert.NoError(t, err)
select {
case <-receivedBackupID:
t.Fatal("Should not receive backup after unsubscribe")
case <-time.After(500 * time.Millisecond):
}
}
func Test_SubscribeNodeForBackupsAssignment_WhenAlreadySubscribed_ReturnsError(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
defer registry.UnsubscribeNodeForBackupsAssignments()
handler := func(id uuid.UUID, isCallNotifier bool) {}
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
assert.NoError(t, err)
err = registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
assert.Error(t, err)
assert.Contains(t, err.Error(), "already subscribed")
}
func Test_MultipleNodes_EachReceivesOnlyTheirBackups(t *testing.T) {
cache_utils.ClearAllCache()
registry1 := createTestRegistry()
registry2 := createTestRegistry()
registry3 := createTestRegistry()
node1 := createTestBackupNode()
node2 := createTestBackupNode()
node3 := createTestBackupNode()
backupID1 := uuid.New()
backupID2 := uuid.New()
backupID3 := uuid.New()
defer registry1.UnsubscribeNodeForBackupsAssignments()
defer registry2.UnsubscribeNodeForBackupsAssignments()
defer registry3.UnsubscribeNodeForBackupsAssignments()
receivedBackups1 := make(chan uuid.UUID, 3)
receivedBackups2 := make(chan uuid.UUID, 3)
receivedBackups3 := make(chan uuid.UUID, 3)
handler1 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups1 <- id }
handler2 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups2 <- id }
handler3 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups3 <- id }
err := registry1.SubscribeNodeForBackupsAssignment(node1.ID.String(), handler1)
assert.NoError(t, err)
err = registry2.SubscribeNodeForBackupsAssignment(node2.ID.String(), handler2)
assert.NoError(t, err)
err = registry3.SubscribeNodeForBackupsAssignment(node3.ID.String(), handler3)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
submitRegistry := createTestRegistry()
err = submitRegistry.AssignBackupToNode(node1.ID.String(), backupID1, true)
assert.NoError(t, err)
err = submitRegistry.AssignBackupToNode(node2.ID.String(), backupID2, false)
assert.NoError(t, err)
err = submitRegistry.AssignBackupToNode(node3.ID.String(), backupID3, true)
assert.NoError(t, err)
select {
case received := <-receivedBackups1:
assert.Equal(t, backupID1, received)
case <-time.After(2 * time.Second):
t.Fatal("Node 1 timeout waiting for backup message")
}
select {
case received := <-receivedBackups2:
assert.Equal(t, backupID2, received)
case <-time.After(2 * time.Second):
t.Fatal("Node 2 timeout waiting for backup message")
}
select {
case received := <-receivedBackups3:
assert.Equal(t, backupID3, received)
case <-time.After(2 * time.Second):
t.Fatal("Node 3 timeout waiting for backup message")
}
select {
case <-receivedBackups1:
t.Fatal("Node 1 should not receive additional backups")
case <-time.After(300 * time.Millisecond):
}
select {
case <-receivedBackups2:
t.Fatal("Node 2 should not receive additional backups")
case <-time.After(300 * time.Millisecond):
}
select {
case <-receivedBackups3:
t.Fatal("Node 3 should not receive additional backups")
case <-time.After(300 * time.Millisecond):
}
}
func Test_PublishBackupCompletion_PublishesMessageToChannel(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
backupID := uuid.New()
err := registry.PublishBackupCompletion(node.ID.String(), backupID)
assert.NoError(t, err)
}
func Test_SubscribeForBackupsCompletions_ReceivesCompletedBackups(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
backupID := uuid.New()
defer registry.UnsubscribeForBackupsCompletions()
receivedBackupID := make(chan uuid.UUID, 1)
receivedNodeID := make(chan string, 1)
handler := func(nodeID string, backupID uuid.UUID) {
receivedNodeID <- nodeID
receivedBackupID <- backupID
}
err := registry.SubscribeForBackupsCompletions(handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.PublishBackupCompletion(node.ID.String(), backupID)
assert.NoError(t, err)
select {
case receivedNode := <-receivedNodeID:
assert.Equal(t, node.ID.String(), receivedNode)
case <-time.After(2 * time.Second):
t.Fatal("Timeout waiting for node ID")
}
select {
case received := <-receivedBackupID:
assert.Equal(t, backupID, received)
case <-time.After(2 * time.Second):
t.Fatal("Timeout waiting for backup completion message")
}
}
func Test_SubscribeForBackupsCompletions_ParsesJsonCorrectly(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
backupID1 := uuid.New()
backupID2 := uuid.New()
defer registry.UnsubscribeForBackupsCompletions()
receivedBackups := make(chan uuid.UUID, 2)
handler := func(nodeID string, backupID uuid.UUID) {
receivedBackups <- backupID
}
err := registry.SubscribeForBackupsCompletions(handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.PublishBackupCompletion(node.ID.String(), backupID1)
assert.NoError(t, err)
err = registry.PublishBackupCompletion(node.ID.String(), backupID2)
assert.NoError(t, err)
received1 := <-receivedBackups
received2 := <-receivedBackups
receivedIDs := []uuid.UUID{received1, received2}
assert.Contains(t, receivedIDs, backupID1)
assert.Contains(t, receivedIDs, backupID2)
}
func Test_SubscribeForBackupsCompletions_HandlesInvalidJson(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
defer registry.UnsubscribeForBackupsCompletions()
receivedBackupID := make(chan uuid.UUID, 1)
handler := func(nodeID string, backupID uuid.UUID) {
receivedBackupID <- backupID
}
err := registry.SubscribeForBackupsCompletions(handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
ctx := context.Background()
err = registry.pubsubCompletions.Publish(ctx, "backup:completion", "invalid json")
assert.NoError(t, err)
select {
case <-receivedBackupID:
t.Fatal("Should not receive backup for invalid JSON")
case <-time.After(500 * time.Millisecond):
}
}
func Test_UnsubscribeForBackupsCompletions_StopsReceivingMessages(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
backupID1 := uuid.New()
backupID2 := uuid.New()
receivedBackupID := make(chan uuid.UUID, 2)
handler := func(nodeID string, backupID uuid.UUID) {
receivedBackupID <- backupID
}
err := registry.SubscribeForBackupsCompletions(handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.PublishBackupCompletion(node.ID.String(), backupID1)
assert.NoError(t, err)
received := <-receivedBackupID
assert.Equal(t, backupID1, received)
err = registry.UnsubscribeForBackupsCompletions()
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.PublishBackupCompletion(node.ID.String(), backupID2)
assert.NoError(t, err)
select {
case <-receivedBackupID:
t.Fatal("Should not receive backup after unsubscribe")
case <-time.After(500 * time.Millisecond):
}
}
func Test_SubscribeForBackupsCompletions_WhenAlreadySubscribed_ReturnsError(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
defer registry.UnsubscribeForBackupsCompletions()
handler := func(nodeID string, backupID uuid.UUID) {}
err := registry.SubscribeForBackupsCompletions(handler)
assert.NoError(t, err)
err = registry.SubscribeForBackupsCompletions(handler)
assert.Error(t, err)
assert.Contains(t, err.Error(), "already subscribed")
}
func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) {
cache_utils.ClearAllCache()
registry1 := createTestRegistry()
registry2 := createTestRegistry()
registry3 := createTestRegistry()
node1 := createTestBackupNode()
node2 := createTestBackupNode()
node3 := createTestBackupNode()
backupID1 := uuid.New()
backupID2 := uuid.New()
backupID3 := uuid.New()
defer registry1.UnsubscribeForBackupsCompletions()
defer registry2.UnsubscribeForBackupsCompletions()
defer registry3.UnsubscribeForBackupsCompletions()
receivedBackups1 := make(chan uuid.UUID, 3)
receivedBackups2 := make(chan uuid.UUID, 3)
receivedBackups3 := make(chan uuid.UUID, 3)
handler1 := func(nodeID string, backupID uuid.UUID) { receivedBackups1 <- backupID }
handler2 := func(nodeID string, backupID uuid.UUID) { receivedBackups2 <- backupID }
handler3 := func(nodeID string, backupID uuid.UUID) { receivedBackups3 <- backupID }
err := registry1.SubscribeForBackupsCompletions(handler1)
assert.NoError(t, err)
err = registry2.SubscribeForBackupsCompletions(handler2)
assert.NoError(t, err)
err = registry3.SubscribeForBackupsCompletions(handler3)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
publishRegistry := createTestRegistry()
err = publishRegistry.PublishBackupCompletion(node1.ID.String(), backupID1)
assert.NoError(t, err)
err = publishRegistry.PublishBackupCompletion(node2.ID.String(), backupID2)
assert.NoError(t, err)
err = publishRegistry.PublishBackupCompletion(node3.ID.String(), backupID3)
assert.NoError(t, err)
receivedAll1 := []uuid.UUID{}
receivedAll2 := []uuid.UUID{}
receivedAll3 := []uuid.UUID{}
for i := 0; i < 3; i++ {
select {
case received := <-receivedBackups1:
receivedAll1 = append(receivedAll1, received)
case <-time.After(2 * time.Second):
t.Fatal("Subscriber 1 timeout waiting for completion message")
}
}
for i := 0; i < 3; i++ {
select {
case received := <-receivedBackups2:
receivedAll2 = append(receivedAll2, received)
case <-time.After(2 * time.Second):
t.Fatal("Subscriber 2 timeout waiting for completion message")
}
}
for i := 0; i < 3; i++ {
select {
case received := <-receivedBackups3:
receivedAll3 = append(receivedAll3, received)
case <-time.After(2 * time.Second):
t.Fatal("Subscriber 3 timeout waiting for completion message")
}
}
assert.Contains(t, receivedAll1, backupID1)
assert.Contains(t, receivedAll1, backupID2)
assert.Contains(t, receivedAll1, backupID3)
assert.Contains(t, receivedAll2, backupID1)
assert.Contains(t, receivedAll2, backupID2)
assert.Contains(t, receivedAll2, backupID3)
assert.Contains(t, receivedAll3, backupID1)
assert.Contains(t, receivedAll3, backupID2)
assert.Contains(t, receivedAll3, backupID3)
}

View File

@@ -0,0 +1,600 @@
package backuping
import (
"context"
"databasus-backend/internal/config"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/storages"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/period"
"fmt"
"log/slog"
"time"
"github.com/google/uuid"
)
type BackupsScheduler struct {
backupRepository *backups_core.BackupRepository
backupConfigService *backups_config.BackupConfigService
storageService *storages.StorageService
backupCancelManager *backups_cancellation.BackupCancelManager
nodesRegistry *BackupNodesRegistry
lastBackupTime time.Time
logger *slog.Logger
backupToNodeRelations map[uuid.UUID]BackupToNodeRelation
backuperNode *BackuperNode
}
func (s *BackupsScheduler) Run(ctx context.Context) {
s.lastBackupTime = time.Now().UTC()
if config.GetEnv().IsManyNodesMode {
// wait other nodes to start
time.Sleep(1 * time.Minute)
}
if err := s.failBackupsInProgress(); err != nil {
s.logger.Error("Failed to fail backups in progress", "error", err)
panic(err)
}
if err := s.nodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted); err != nil {
s.logger.Error("Failed to subscribe to backup completions", "error", err)
panic(err)
}
defer func() {
if err := s.nodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
}
}()
if ctx.Err() != nil {
return
}
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.cleanOldBackups(); err != nil {
s.logger.Error("Failed to clean old backups", "error", err)
}
if err := s.checkDeadNodesAndFailBackups(); err != nil {
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
}
if err := s.runPendingBackups(); err != nil {
s.logger.Error("Failed to run pending backups", "error", err)
}
s.lastBackupTime = time.Now().UTC()
}
}
}
func (s *BackupsScheduler) IsSchedulerRunning() bool {
// if last backup time is more than 5 minutes ago, return false
return s.lastBackupTime.After(time.Now().UTC().Add(-5 * time.Minute))
}
func (s *BackupsScheduler) failBackupsInProgress() error {
backupsInProgress, err := s.backupRepository.FindByStatus(backups_core.BackupStatusInProgress)
if err != nil {
return err
}
fmt.Println("Backups in progress", len(backupsInProgress))
for _, backup := range backupsInProgress {
if err := s.backupCancelManager.CancelBackup(backup.ID); err != nil {
s.logger.Error(
"Failed to cancel backup via context manager",
"backupId",
backup.ID,
"error",
err,
)
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(backup.DatabaseID)
if err != nil {
s.logger.Error("Failed to get backup config by database ID", "error", err)
continue
}
failMessage := "Backup failed due to application restart"
backup.FailMessage = &failMessage
backup.Status = backups_core.BackupStatusFailed
backup.BackupSizeMb = 0
s.backuperNode.SendBackupNotification(
backupConfig,
backup,
backups_config.NotificationBackupFailed,
&failMessage,
)
if err := s.backupRepository.Save(backup); err != nil {
return err
}
}
return nil
}
func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool) {
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(databaseID)
if err != nil {
s.logger.Error("Failed to get backup config by database ID", "error", err)
return
}
if backupConfig.StorageID == nil {
s.logger.Error("Backup config storage ID is nil", "databaseId", databaseID)
return
}
leastBusyNodeID, err := s.calculateLeastBusyNode()
if err != nil {
s.logger.Error(
"Failed to calculate least busy node",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
return
}
backup := &backups_core.Backup{
DatabaseID: backupConfig.DatabaseID,
StorageID: *backupConfig.StorageID,
Status: backups_core.BackupStatusInProgress,
BackupSizeMb: 0,
CreatedAt: time.Now().UTC(),
}
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error(
"Failed to save backup",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
return
}
if err := s.nodesRegistry.IncrementBackupsInProgress(leastBusyNodeID.String()); err != nil {
s.logger.Error(
"Failed to increment backups in progress",
"nodeId",
leastBusyNodeID,
"backupId",
backup.ID,
"error",
err,
)
return
}
if err := s.nodesRegistry.AssignBackupToNode(leastBusyNodeID.String(), backup.ID, isCallNotifier); err != nil {
s.logger.Error(
"Failed to submit backup",
"nodeId",
leastBusyNodeID,
"backupId",
backup.ID,
"error",
err,
)
if decrementErr := s.nodesRegistry.DecrementBackupsInProgress(leastBusyNodeID.String()); decrementErr != nil {
s.logger.Error(
"Failed to decrement backups in progress after submit failure",
"nodeId",
leastBusyNodeID,
"error",
decrementErr,
)
}
return
}
if relation, exists := s.backupToNodeRelations[*leastBusyNodeID]; exists {
relation.BackupsIDs = append(relation.BackupsIDs, backup.ID)
s.backupToNodeRelations[*leastBusyNodeID] = relation
} else {
s.backupToNodeRelations[*leastBusyNodeID] = BackupToNodeRelation{
NodeID: *leastBusyNodeID,
BackupsIDs: []uuid.UUID{backup.ID},
}
}
s.logger.Info(
"Successfully triggered scheduled backup",
"databaseId",
backupConfig.DatabaseID,
"backupId",
backup.ID,
"nodeId",
leastBusyNodeID,
)
}
// GetRemainedBackupTryCount returns the number of remaining backup tries for a given backup.
// If the backup is not failed or the backup config does not allow retries, it returns 0.
// If the backup is failed and the backup config allows retries, it returns the number of remaining tries.
// If the backup is failed and the backup config does not allow retries, it returns 0.
func (s *BackupsScheduler) GetRemainedBackupTryCount(lastBackup *backups_core.Backup) int {
if lastBackup == nil {
return 0
}
if lastBackup.Status != backups_core.BackupStatusFailed {
return 0
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(lastBackup.DatabaseID)
if err != nil {
s.logger.Error("Failed to get backup config by database ID", "error", err)
return 0
}
if !backupConfig.IsRetryIfFailed {
return 0
}
maxFailedTriesCount := backupConfig.MaxFailedTriesCount
lastBackups, err := s.backupRepository.FindByDatabaseIDWithLimit(
lastBackup.DatabaseID,
maxFailedTriesCount,
)
if err != nil {
s.logger.Error("Failed to find last backups by database ID", "error", err)
return 0
}
lastFailedBackups := make([]*backups_core.Backup, 0)
for _, backup := range lastBackups {
if backup.Status == backups_core.BackupStatusFailed {
lastFailedBackups = append(lastFailedBackups, backup)
}
}
return maxFailedTriesCount - len(lastFailedBackups)
}
func (s *BackupsScheduler) cleanOldBackups() error {
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
backupStorePeriod := backupConfig.StorePeriod
if backupStorePeriod == period.PeriodForever {
continue
}
storeDuration := backupStorePeriod.ToDuration()
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
oldBackups, err := s.backupRepository.FindBackupsBeforeDate(
backupConfig.DatabaseID,
dateBeforeBackupsShouldBeDeleted,
)
if err != nil {
s.logger.Error(
"Failed to find old backups for database",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
continue
}
for _, backup := range oldBackups {
storage, err := s.storageService.GetStorageByID(backup.StorageID)
if err != nil {
s.logger.Error(
"Failed to get storage by ID",
"storageId",
backup.StorageID,
"error",
err,
)
continue
}
encryptor := encryption.GetFieldEncryptor()
err = storage.DeleteFile(encryptor, backup.ID)
if err != nil {
s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
}
if err := s.backupRepository.DeleteByID(backup.ID); err != nil {
s.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
continue
}
s.logger.Info(
"Deleted old backup",
"backupId",
backup.ID,
"databaseId",
backupConfig.DatabaseID,
)
}
}
return nil
}
func (s *BackupsScheduler) runPendingBackups() error {
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
if backupConfig.BackupInterval == nil {
continue
}
lastBackup, err := s.backupRepository.FindLastByDatabaseID(backupConfig.DatabaseID)
if err != nil {
s.logger.Error(
"Failed to get last backup for database",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
continue
}
var lastBackupTime *time.Time
if lastBackup != nil {
lastBackupTime = &lastBackup.CreatedAt
}
remainedBackupTryCount := s.GetRemainedBackupTryCount(lastBackup)
if backupConfig.BackupInterval.ShouldTriggerBackup(time.Now().UTC(), lastBackupTime) ||
remainedBackupTryCount > 0 {
s.logger.Info(
"Triggering scheduled backup",
"databaseId",
backupConfig.DatabaseID,
"intervalType",
backupConfig.BackupInterval.Interval,
)
s.StartBackup(backupConfig.DatabaseID, remainedBackupTryCount == 1)
continue
}
}
return nil
}
func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
nodes, err := s.nodesRegistry.GetAvailableNodes()
if err != nil {
return nil, fmt.Errorf("failed to get available nodes: %w", err)
}
if len(nodes) == 0 {
return nil, fmt.Errorf("no nodes available")
}
stats, err := s.nodesRegistry.GetBackupNodesStats()
if err != nil {
return nil, fmt.Errorf("failed to get backup nodes stats: %w", err)
}
statsMap := make(map[uuid.UUID]int)
for _, stat := range stats {
statsMap[stat.ID] = stat.ActiveBackups
}
var bestNode *BackupNode
var bestScore float64 = -1
now := time.Now().UTC()
for i := range nodes {
node := &nodes[i]
if now.Sub(node.LastHeartbeat) > 2*time.Minute {
continue
}
activeBackups := statsMap[node.ID]
var score float64
if node.ThroughputMBs > 0 {
score = float64(activeBackups) / float64(node.ThroughputMBs)
} else {
score = float64(activeBackups) * 1000
}
if bestNode == nil || score < bestScore {
bestNode = node
bestScore = score
}
}
if bestNode == nil {
return nil, fmt.Errorf("no suitable nodes available")
}
return &bestNode.ID, nil
}
func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUID) {
nodeID, err := uuid.Parse(nodeIDStr)
if err != nil {
s.logger.Error(
"Failed to parse node ID from completion message",
"nodeId",
nodeIDStr,
"error",
err,
)
return
}
relation, exists := s.backupToNodeRelations[nodeID]
if !exists {
s.logger.Warn(
"Received completion for unknown node",
"nodeId",
nodeID,
"backupId",
backupID,
)
return
}
newBackupIDs := make([]uuid.UUID, 0)
found := false
for _, id := range relation.BackupsIDs {
if id == backupID {
found = true
continue
}
newBackupIDs = append(newBackupIDs, id)
}
if !found {
s.logger.Warn(
"Backup not found in node's backup list",
"nodeId",
nodeID,
"backupId",
backupID,
)
return
}
if len(newBackupIDs) == 0 {
delete(s.backupToNodeRelations, nodeID)
} else {
relation.BackupsIDs = newBackupIDs
s.backupToNodeRelations[nodeID] = relation
}
if err := s.nodesRegistry.DecrementBackupsInProgress(nodeIDStr); err != nil {
s.logger.Error(
"Failed to decrement backups in progress",
"nodeId",
nodeID,
"backupId",
backupID,
"error",
err,
)
}
}
func (s *BackupsScheduler) checkDeadNodesAndFailBackups() error {
nodes, err := s.nodesRegistry.GetAvailableNodes()
if err != nil {
return fmt.Errorf("failed to get available nodes: %w", err)
}
aliveNodeIDs := make(map[uuid.UUID]bool)
now := time.Now().UTC()
for _, node := range nodes {
if now.Sub(node.LastHeartbeat) <= 2*time.Minute {
aliveNodeIDs[node.ID] = true
}
}
for nodeID, relation := range s.backupToNodeRelations {
if aliveNodeIDs[nodeID] {
continue
}
s.logger.Warn(
"Node is dead, failing its backups",
"nodeId",
nodeID,
"backupCount",
len(relation.BackupsIDs),
)
for _, backupID := range relation.BackupsIDs {
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
s.logger.Error(
"Failed to find backup for dead node",
"nodeId",
nodeID,
"backupId",
backupID,
"error",
err,
)
continue
}
failMessage := "Backup failed due to node unavailability"
backup.FailMessage = &failMessage
backup.Status = backups_core.BackupStatusFailed
backup.BackupSizeMb = 0
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error(
"Failed to save failed backup for dead node",
"nodeId",
nodeID,
"backupId",
backupID,
"error",
err,
)
continue
}
if err := s.nodesRegistry.DecrementBackupsInProgress(nodeID.String()); err != nil {
s.logger.Error(
"Failed to decrement backups in progress for dead node",
"nodeId",
nodeID,
"backupId",
backupID,
"error",
err,
)
}
s.logger.Info(
"Failed backup due to dead node",
"nodeId",
nodeID,
"backupId",
backupID,
)
}
delete(s.backupToNodeRelations, nodeID)
}
return nil
}

View File

@@ -0,0 +1,709 @@
package backuping
import (
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/period"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_RunPendingBackups_WhenLastBackupWasYesterday_CreatesNewBackup(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
// setup data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups for the database
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// add old backup
backupRepository.Save(&backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
})
GetBackupsScheduler().runPendingBackups()
// Wait for backup to complete (runs in goroutine)
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
// assertions
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 2)
// Wait for any cleanup operations to complete before defer cleanup runs
time.Sleep(200 * time.Millisecond)
}
func Test_RunPendingBackups_WhenLastBackupWasRecentlyCompleted_SkipsBackup(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
// setup data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups for the database
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// add recent backup (1 hour ago)
backupRepository.Save(&backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
})
GetBackupsScheduler().runPendingBackups()
time.Sleep(100 * time.Millisecond)
// assertions
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1) // Should still be 1 backup, no new backup created
// Wait for any cleanup operations to complete before defer cleanup runs
time.Sleep(200 * time.Millisecond)
}
func Test_RunPendingBackups_WhenLastBackupFailedAndRetriesDisabled_SkipsBackup(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
// setup data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups for the database with retries disabled
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
backupConfig.IsRetryIfFailed = false
backupConfig.MaxFailedTriesCount = 0
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// add failed backup
failMessage := "backup failed"
backupRepository.Save(&backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusFailed,
FailMessage: &failMessage,
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
})
GetBackupsScheduler().runPendingBackups()
time.Sleep(100 * time.Millisecond)
// assertions
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1) // Should still be 1 backup, no retry attempted
// Wait for any cleanup operations to complete before defer cleanup runs
time.Sleep(200 * time.Millisecond)
}
func Test_RunPendingBackups_WhenLastBackupFailedAndRetriesEnabled_CreatesNewBackup(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
// setup data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups for the database with retries enabled
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
backupConfig.IsRetryIfFailed = true
backupConfig.MaxFailedTriesCount = 3
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// add failed backup
failMessage := "backup failed"
backupRepository.Save(&backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusFailed,
FailMessage: &failMessage,
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
})
GetBackupsScheduler().runPendingBackups()
// Wait for backup to complete (runs in goroutine)
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
// assertions
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 2) // Should have 2 backups, retry was attempted
// Wait for any cleanup operations to complete before defer cleanup runs
time.Sleep(200 * time.Millisecond)
}
func Test_RunPendingBackups_WhenFailedBackupsExceedMaxRetries_SkipsBackup(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
// setup data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups for the database with retries enabled
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
backupConfig.IsRetryIfFailed = true
backupConfig.MaxFailedTriesCount = 3
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
failMessage := "backup failed"
for range 3 {
backupRepository.Save(&backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusFailed,
FailMessage: &failMessage,
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
})
}
GetBackupsScheduler().runPendingBackups()
time.Sleep(100 * time.Millisecond)
// assertions
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 3) // Should have 3 backups, not more than max
// Wait for any cleanup operations to complete before defer cleanup runs
time.Sleep(200 * time.Millisecond)
}
func Test_RunPendingBackups_WhenBackupsDisabled_SkipsBackup(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = false
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// add old backup that would trigger new backup if enabled
backupRepository.Save(&backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
})
GetBackupsScheduler().runPendingBackups()
time.Sleep(100 * time.Millisecond)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
// Wait for any cleanup operations to complete before defer cleanup runs
time.Sleep(200 * time.Millisecond)
}
func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegistry(t *testing.T) {
cache_utils.ClearAllCache()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Register mock node without subscribing to backups (simulates node crash after registration)
mockNodeID := uuid.New()
err = CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
assert.NoError(t, err)
// Scheduler assigns backup to mock node
GetBackupsScheduler().StartBackup(database.ID, false)
time.Sleep(100 * time.Millisecond)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status)
// Verify Valkey counter was incremented when backup was assigned
stats, err := nodesRegistry.GetBackupNodesStats()
assert.NoError(t, err)
foundStat := false
for _, stat := range stats {
if stat.ID == mockNodeID {
assert.Equal(t, 1, stat.ActiveBackups)
foundStat = true
break
}
}
assert.True(t, foundStat, "Node stats should be present")
// Simulate node death by setting heartbeat older than 2-minute threshold
oldHeartbeat := time.Now().UTC().Add(-3 * time.Minute)
err = UpdateNodeHeartbeatDirectly(mockNodeID, 100, oldHeartbeat)
assert.NoError(t, err)
// Trigger dead node detection
err = GetBackupsScheduler().checkDeadNodesAndFailBackups()
assert.NoError(t, err)
// Verify backup was failed with appropriate error message
backups, err = backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
assert.Equal(t, backups_core.BackupStatusFailed, backups[0].Status)
assert.NotNil(t, backups[0].FailMessage)
assert.Contains(t, *backups[0].FailMessage, "node unavailability")
// Verify Valkey counter was decremented after backup failed
stats, err = nodesRegistry.GetBackupNodesStats()
assert.NoError(t, err)
for _, stat := range stats {
if stat.ID == mockNodeID {
assert.Equal(t, 0, stat.ActiveBackups)
}
}
// Node info should still exist in registry (not removed by checkDeadNodesAndFailBackups)
node, err := GetNodeFromRegistry(mockNodeID)
assert.NoError(t, err)
assert.NotNil(t, node)
assert.Equal(t, mockNodeID, node.ID)
time.Sleep(200 * time.Millisecond)
}
func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
t.Run("Nodes with same throughput", func(t *testing.T) {
cache_utils.ClearAllCache()
node1ID := uuid.New()
node2ID := uuid.New()
node3ID := uuid.New()
now := time.Now().UTC()
err := CreateMockNodeInRegistry(node1ID, 100, now)
assert.NoError(t, err)
err = CreateMockNodeInRegistry(node2ID, 100, now)
assert.NoError(t, err)
err = CreateMockNodeInRegistry(node3ID, 100, now)
assert.NoError(t, err)
for range 5 {
err = nodesRegistry.IncrementBackupsInProgress(node1ID.String())
assert.NoError(t, err)
}
for range 2 {
err = nodesRegistry.IncrementBackupsInProgress(node2ID.String())
assert.NoError(t, err)
}
for range 8 {
err = nodesRegistry.IncrementBackupsInProgress(node3ID.String())
assert.NoError(t, err)
}
leastBusyNodeID, err := GetBackupsScheduler().calculateLeastBusyNode()
assert.NoError(t, err)
assert.NotNil(t, leastBusyNodeID)
assert.Equal(t, node2ID, *leastBusyNodeID)
})
t.Run("Nodes with different throughput", func(t *testing.T) {
cache_utils.ClearAllCache()
node100MBsID := uuid.New()
node50MBsID := uuid.New()
now := time.Now().UTC()
err := CreateMockNodeInRegistry(node100MBsID, 100, now)
assert.NoError(t, err)
err = CreateMockNodeInRegistry(node50MBsID, 50, now)
assert.NoError(t, err)
for range 10 {
err = nodesRegistry.IncrementBackupsInProgress(node100MBsID.String())
assert.NoError(t, err)
}
err = nodesRegistry.IncrementBackupsInProgress(node50MBsID.String())
assert.NoError(t, err)
leastBusyNodeID, err := GetBackupsScheduler().calculateLeastBusyNode()
assert.NoError(t, err)
assert.NotNil(t, leastBusyNodeID)
assert.Equal(t, node50MBsID, *leastBusyNodeID)
})
}
func Test_FailBackupsInProgress_WhenSchedulerStarts_CancelsBackupsAndUpdatesStatus(t *testing.T) {
cache_utils.ClearAllCache()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Create two in-progress backups that should be failed on scheduler restart
backup1 := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
BackupSizeMb: 10.5,
CreatedAt: time.Now().UTC().Add(-30 * time.Minute),
}
err = backupRepository.Save(backup1)
assert.NoError(t, err)
backup2 := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
BackupSizeMb: 5.2,
CreatedAt: time.Now().UTC().Add(-15 * time.Minute),
}
err = backupRepository.Save(backup2)
assert.NoError(t, err)
// Create a completed backup to verify it's not affected by failBackupsInProgress
completedBackup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 20.0,
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
}
err = backupRepository.Save(completedBackup)
assert.NoError(t, err)
// Trigger the scheduler's failBackupsInProgress logic
// This should cancel in-progress backups and mark them as failed
err = GetBackupsScheduler().failBackupsInProgress()
assert.NoError(t, err)
// Verify all backups exist and were processed correctly
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 3)
var failedCount int
var completedCount int
for _, backup := range backups {
switch backup.Status {
case backups_core.BackupStatusFailed:
failedCount++
// Verify fail message indicates application restart
assert.NotNil(t, backup.FailMessage)
assert.Equal(t, "Backup failed due to application restart", *backup.FailMessage)
// Verify backup size was reset to 0
assert.Equal(t, float64(0), backup.BackupSizeMb)
case backups_core.BackupStatusCompleted:
completedCount++
}
}
// Verify correct number of backups in each state
assert.Equal(t, 2, failedCount)
assert.Equal(t, 1, completedCount)
time.Sleep(200 * time.Millisecond)
}

View File

@@ -0,0 +1,206 @@
package backuping
import (
"context"
"fmt"
"testing"
"time"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/usecases"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_services "databasus-backend/internal/features/workspaces/services"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
func CreateTestRouter() *gin.Engine {
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
databases.GetDatabaseController(),
backups_config.GetBackupConfigController(),
)
return router
}
func CreateTestBackuperNode() *BackuperNode {
return &BackuperNode{
databases.GetDatabaseService(),
encryption.GetFieldEncryptor(),
workspaces_services.GetWorkspaceService(),
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
notifiers.GetNotifierService(),
backupCancelManager,
nodesRegistry,
logger.GetLogger(),
usecases.GetCreateBackupUsecase(),
uuid.New(),
time.Time{},
}
}
// WaitForBackupCompletion waits for a new backup to be created and completed (or failed)
// for the given database. It checks for backups with count greater than expectedInitialCount.
func WaitForBackupCompletion(
t *testing.T,
databaseID uuid.UUID,
expectedInitialCount int,
timeout time.Duration,
) {
deadline := time.Now().UTC().Add(timeout)
for time.Now().UTC().Before(deadline) {
backups, err := backupRepository.FindByDatabaseID(databaseID)
if err != nil {
t.Logf("WaitForBackupCompletion: error finding backups: %v", err)
time.Sleep(50 * time.Millisecond)
continue
}
t.Logf(
"WaitForBackupCompletion: found %d backups (expected > %d)",
len(backups),
expectedInitialCount,
)
if len(backups) > expectedInitialCount {
// Check if the newest backup has completed or failed
newestBackup := backups[0]
t.Logf("WaitForBackupCompletion: newest backup status: %s", newestBackup.Status)
if newestBackup.Status == backups_core.BackupStatusCompleted ||
newestBackup.Status == backups_core.BackupStatusFailed ||
newestBackup.Status == backups_core.BackupStatusCanceled {
t.Logf(
"WaitForBackupCompletion: backup finished with status %s",
newestBackup.Status,
)
return
}
}
time.Sleep(50 * time.Millisecond)
}
t.Logf("WaitForBackupCompletion: timeout waiting for backup to complete")
}
// StartBackuperNodeForTest starts a BackuperNode in a goroutine for testing.
// The node registers itself in the registry and subscribes to backup assignments.
// Returns a context cancel function that should be deferred to stop the node.
func StartBackuperNodeForTest(t *testing.T, backuperNode *BackuperNode) context.CancelFunc {
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
backuperNode.Run(ctx)
close(done)
}()
// Poll registry for node presence instead of fixed sleep
deadline := time.Now().UTC().Add(5 * time.Second)
for time.Now().UTC().Before(deadline) {
nodes, err := nodesRegistry.GetAvailableNodes()
if err == nil {
for _, node := range nodes {
if node.ID == backuperNode.nodeID {
t.Logf("BackuperNode registered in registry: %s", backuperNode.nodeID)
return func() {
cancel()
select {
case <-done:
t.Log("BackuperNode stopped gracefully")
case <-time.After(2 * time.Second):
t.Log("BackuperNode stop timeout")
}
}
}
}
}
time.Sleep(50 * time.Millisecond)
}
t.Fatalf("BackuperNode failed to register in registry within timeout")
return nil
}
// StopBackuperNodeForTest stops the BackuperNode by canceling its context.
// It waits for the node to unregister from the registry.
func StopBackuperNodeForTest(t *testing.T, cancel context.CancelFunc, backuperNode *BackuperNode) {
cancel()
// Wait for node to unregister from registry
deadline := time.Now().UTC().Add(2 * time.Second)
for time.Now().UTC().Before(deadline) {
nodes, err := nodesRegistry.GetAvailableNodes()
if err == nil {
found := false
for _, node := range nodes {
if node.ID == backuperNode.nodeID {
found = true
break
}
}
if !found {
t.Logf("BackuperNode unregistered from registry: %s", backuperNode.nodeID)
return
}
}
time.Sleep(50 * time.Millisecond)
}
t.Logf("BackuperNode stop completed for %s", backuperNode.nodeID)
}
func CreateMockNodeInRegistry(nodeID uuid.UUID, throughputMBs int, lastHeartbeat time.Time) error {
backupNode := BackupNode{
ID: nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: lastHeartbeat,
}
return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
}
func UpdateNodeHeartbeatDirectly(
nodeID uuid.UUID,
throughputMBs int,
lastHeartbeat time.Time,
) error {
backupNode := BackupNode{
ID: nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: lastHeartbeat,
}
return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
}
func GetNodeFromRegistry(nodeID uuid.UUID) (*BackupNode, error) {
nodes, err := nodesRegistry.GetAvailableNodes()
if err != nil {
return nil, err
}
for _, node := range nodes {
if node.ID == nodeID {
return &node, nil
}
}
return nil, fmt.Errorf("node not found")
}

View File

@@ -0,0 +1,75 @@
package backups_cancellation
import (
"context"
cache_utils "databasus-backend/internal/util/cache"
"log/slog"
"sync"
"github.com/google/uuid"
)
const backupCancelChannel = "backup:cancel"
type BackupCancelManager struct {
mu sync.RWMutex
cancelFuncs map[uuid.UUID]context.CancelFunc
pubsub *cache_utils.PubSubManager
logger *slog.Logger
}
func (m *BackupCancelManager) StartSubscription() {
ctx := context.Background()
handler := func(message string) {
backupID, err := uuid.Parse(message)
if err != nil {
m.logger.Error("Invalid backup ID in cancel message", "message", message, "error", err)
return
}
m.mu.Lock()
defer m.mu.Unlock()
cancelFunc, exists := m.cancelFuncs[backupID]
if exists {
cancelFunc()
delete(m.cancelFuncs, backupID)
m.logger.Info("Cancelled backup via Pub/Sub", "backupID", backupID)
}
}
err := m.pubsub.Subscribe(ctx, backupCancelChannel, handler)
if err != nil {
m.logger.Error("Failed to subscribe to backup cancel channel", "error", err)
} else {
m.logger.Info("Successfully subscribed to backup cancel channel")
}
}
func (m *BackupCancelManager) RegisterBackup(backupID uuid.UUID, cancelFunc context.CancelFunc) {
m.mu.Lock()
defer m.mu.Unlock()
m.cancelFuncs[backupID] = cancelFunc
m.logger.Debug("Registered backup", "backupID", backupID)
}
func (m *BackupCancelManager) CancelBackup(backupID uuid.UUID) error {
ctx := context.Background()
err := m.pubsub.Publish(ctx, backupCancelChannel, backupID.String())
if err != nil {
m.logger.Error("Failed to publish cancel message", "backupID", backupID, "error", err)
return err
}
m.logger.Info("Published backup cancel message", "backupID", backupID)
return nil
}
func (m *BackupCancelManager) UnregisterBackup(backupID uuid.UUID) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.cancelFuncs, backupID)
m.logger.Debug("Unregistered backup", "backupID", backupID)
}

View File

@@ -0,0 +1,200 @@
package backups_cancellation
import (
"context"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_RegisterBackup_BackupRegisteredSuccessfully(t *testing.T) {
manager := backupCancelManager
backupID := uuid.New()
_, cancel := context.WithCancel(context.Background())
defer cancel()
manager.RegisterBackup(backupID, cancel)
manager.mu.RLock()
_, exists := manager.cancelFuncs[backupID]
manager.mu.RUnlock()
assert.True(t, exists, "Backup should be registered")
}
func Test_UnregisterBackup_BackupUnregisteredSuccessfully(t *testing.T) {
manager := backupCancelManager
backupID := uuid.New()
_, cancel := context.WithCancel(context.Background())
defer cancel()
manager.RegisterBackup(backupID, cancel)
manager.UnregisterBackup(backupID)
manager.mu.RLock()
_, exists := manager.cancelFuncs[backupID]
manager.mu.RUnlock()
assert.False(t, exists, "Backup should be unregistered")
}
func Test_CancelBackup_OnSameInstance_BackupCancelledViaPubSub(t *testing.T) {
manager := backupCancelManager
backupID := uuid.New()
ctx, cancel := context.WithCancel(context.Background())
cancelled := false
var mu sync.Mutex
wrappedCancel := func() {
mu.Lock()
cancelled = true
mu.Unlock()
cancel()
}
manager.RegisterBackup(backupID, wrappedCancel)
manager.StartSubscription()
time.Sleep(100 * time.Millisecond)
err := manager.CancelBackup(backupID)
assert.NoError(t, err, "Cancel should not return error")
time.Sleep(500 * time.Millisecond)
mu.Lock()
wasCancelled := cancelled
mu.Unlock()
assert.True(t, wasCancelled, "Cancel function should have been called")
assert.Error(t, ctx.Err(), "Context should be cancelled")
}
func Test_CancelBackup_FromDifferentInstance_BackupCancelledOnRunningInstance(t *testing.T) {
manager1 := backupCancelManager
manager2 := backupCancelManager
backupID := uuid.New()
ctx, cancel := context.WithCancel(context.Background())
cancelled := false
var mu sync.Mutex
wrappedCancel := func() {
mu.Lock()
cancelled = true
mu.Unlock()
cancel()
}
manager1.RegisterBackup(backupID, wrappedCancel)
manager1.StartSubscription()
manager2.StartSubscription()
time.Sleep(100 * time.Millisecond)
err := manager2.CancelBackup(backupID)
assert.NoError(t, err, "Cancel should not return error")
time.Sleep(500 * time.Millisecond)
mu.Lock()
wasCancelled := cancelled
mu.Unlock()
assert.True(t, wasCancelled, "Cancel function should have been called on instance 1")
assert.Error(t, ctx.Err(), "Context should be cancelled")
}
func Test_CancelBackup_WhenBackupDoesNotExist_NoErrorReturned(t *testing.T) {
manager := backupCancelManager
manager.StartSubscription()
time.Sleep(100 * time.Millisecond)
nonExistentID := uuid.New()
err := manager.CancelBackup(nonExistentID)
assert.NoError(t, err, "Cancelling non-existent backup should not error")
}
func Test_CancelBackup_WithMultipleBackups_AllBackupsCancelled(t *testing.T) {
manager := backupCancelManager
numBackups := 5
backupIDs := make([]uuid.UUID, numBackups)
contexts := make([]context.Context, numBackups)
cancels := make([]context.CancelFunc, numBackups)
cancelledFlags := make([]bool, numBackups)
var mu sync.Mutex
for i := 0; i < numBackups; i++ {
backupIDs[i] = uuid.New()
contexts[i], cancels[i] = context.WithCancel(context.Background())
idx := i
wrappedCancel := func() {
mu.Lock()
cancelledFlags[idx] = true
mu.Unlock()
cancels[idx]()
}
manager.RegisterBackup(backupIDs[i], wrappedCancel)
}
manager.StartSubscription()
time.Sleep(100 * time.Millisecond)
for i := 0; i < numBackups; i++ {
err := manager.CancelBackup(backupIDs[i])
assert.NoError(t, err, "Cancel should not return error")
}
time.Sleep(1 * time.Second)
mu.Lock()
for i := 0; i < numBackups; i++ {
assert.True(t, cancelledFlags[i], "Backup %d should be cancelled", i)
assert.Error(t, contexts[i].Err(), "Context %d should be cancelled", i)
}
mu.Unlock()
}
func Test_CancelBackup_AfterUnregister_BackupNotCancelled(t *testing.T) {
manager := backupCancelManager
backupID := uuid.New()
_, cancel := context.WithCancel(context.Background())
defer cancel()
cancelled := false
var mu sync.Mutex
wrappedCancel := func() {
mu.Lock()
cancelled = true
mu.Unlock()
cancel()
}
manager.RegisterBackup(backupID, wrappedCancel)
manager.StartSubscription()
time.Sleep(100 * time.Millisecond)
manager.UnregisterBackup(backupID)
err := manager.CancelBackup(backupID)
assert.NoError(t, err, "Cancel should not return error")
time.Sleep(500 * time.Millisecond)
mu.Lock()
wasCancelled := cancelled
mu.Unlock()
assert.False(t, wasCancelled, "Cancel function should not be called after unregister")
}

View File

@@ -0,0 +1,25 @@
package backups_cancellation
import (
"context"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/logger"
"sync"
"github.com/google/uuid"
)
var backupCancelManager = &BackupCancelManager{
sync.RWMutex{},
make(map[uuid.UUID]context.CancelFunc),
cache_utils.NewPubSubManager(),
logger.GetLogger(),
}
func GetBackupCancelManager() *BackupCancelManager {
return backupCancelManager
}
func SetupDependencies() {
backupCancelManager.StartSubscription()
}

View File

@@ -1,6 +1,7 @@
package backups
import (
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/databases"
users_middleware "databasus-backend/internal/features/users/middleware"
"fmt"
@@ -18,11 +19,17 @@ type BackupController struct {
func (c *BackupController) RegisterRoutes(router *gin.RouterGroup) {
router.GET("/backups", c.GetBackups)
router.POST("/backups", c.MakeBackup)
router.GET("/backups/:id/file", c.GetFile)
router.POST("/backups/:id/download-token", c.GenerateDownloadToken)
router.DELETE("/backups/:id", c.DeleteBackup)
router.POST("/backups/:id/cancel", c.CancelBackup)
}
// RegisterPublicRoutes registers routes that don't require Bearer authentication
// (they have their own authentication mechanisms like download tokens)
func (c *BackupController) RegisterPublicRoutes(router *gin.RouterGroup) {
router.GET("/backups/:id/file", c.GetFile)
}
// GetBackups
// @Summary Get backups for a database
// @Description Get paginated backups for the specified database
@@ -159,17 +166,16 @@ func (c *BackupController) CancelBackup(ctx *gin.Context) {
ctx.Status(http.StatusNoContent)
}
// GetFile
// @Summary Download a backup file
// @Description Download the backup file for the specified backup
// GenerateDownloadToken
// @Summary Generate short-lived download token
// @Description Generate a token for downloading a backup file (valid for 5 minutes)
// @Tags backups
// @Param id path string true "Backup ID"
// @Success 200 {file} file
// @Success 200 {object} backups_download.GenerateDownloadTokenResponse
// @Failure 400
// @Failure 401
// @Failure 500
// @Router /backups/{id}/file [get]
func (c *BackupController) GetFile(ctx *gin.Context) {
// @Router /backups/{id}/download-token [post]
func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
@@ -182,7 +188,56 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
return
}
fileReader, backup, database, err := c.backupService.GetBackupFile(user, id)
response, err := c.backupService.GenerateDownloadToken(user, id)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, response)
}
// GetFile
// @Summary Download a backup file
// @Description Download the backup file for the specified backup using a download token
// @Tags backups
// @Param id path string true "Backup ID"
// @Param token query string true "Download token"
// @Success 200 {file} file
// @Failure 400
// @Failure 401
// @Failure 500
// @Router /backups/{id}/file [get]
func (c *BackupController) GetFile(ctx *gin.Context) {
token := ctx.Query("token")
if token == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "download token is required"})
return
}
// Get backup ID from URL
backupIDParam := ctx.Param("id")
backupID, err := uuid.Parse(backupIDParam)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
return
}
downloadToken, err := c.backupService.ValidateDownloadToken(token)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"})
return
}
// Verify token is for the requested backup
if downloadToken.BackupID != backupID {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"})
return
}
fileReader, backup, database, err := c.backupService.GetBackupFileWithoutAuth(
downloadToken.BackupID,
)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -195,6 +250,12 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
filename := c.generateBackupFilename(backup, database)
// Set Content-Length for progress tracking
if backup.BackupSizeMb > 0 {
sizeBytes := int64(backup.BackupSizeMb * 1024 * 1024)
ctx.Header("Content-Length", fmt.Sprintf("%d", sizeBytes))
}
ctx.Header("Content-Type", "application/octet-stream")
ctx.Header(
"Content-Disposition",
@@ -203,9 +264,12 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
_, err = io.Copy(ctx.Writer, fileReader)
if err != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "failed to stream file"})
fmt.Printf("Error streaming file: %v\n", err)
return
}
// Write audit log after successful download
c.backupService.WriteAuditLogForDownload(downloadToken.UserID, backup, database)
}
type MakeBackupRequest struct {
@@ -213,7 +277,7 @@ type MakeBackupRequest struct {
}
func (c *BackupController) generateBackupFilename(
backup *Backup,
backup *backups_core.Backup,
database *databases.Database,
) string {
// Format timestamp as YYYY-MM-DD_HH-mm-ss

View File

@@ -7,6 +7,7 @@ import (
"io"
"log/slog"
"net/http"
"strconv"
"strings"
"testing"
"time"
@@ -15,7 +16,10 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/postgresql"
@@ -87,7 +91,13 @@ func Test_GetBackups_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
}
} else {
@@ -179,7 +189,13 @@ func Test_CreateBackup_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
}
} else {
@@ -309,7 +325,13 @@ func Test_DeleteBackup_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
}
} else {
@@ -378,7 +400,7 @@ func Test_DeleteBackup_AuditLogWritten(t *testing.T) {
assert.True(t, found, "Audit log for backup deletion not found")
}
func Test_DownloadBackup_PermissionsEnforced(t *testing.T) {
func Test_GenerateDownloadToken_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
workspaceRole *users_enums.WorkspaceRole
@@ -387,28 +409,28 @@ func Test_DownloadBackup_PermissionsEnforced(t *testing.T) {
expectedStatusCode int
}{
{
name: "workspace viewer can download backup",
name: "workspace viewer can generate token",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace member can download backup",
name: "workspace member can generate token",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "non-member cannot download backup",
name: "non-member cannot generate token",
workspaceRole: nil,
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusBadRequest,
},
{
name: "global admin can download backup",
name: "global admin can generate token",
workspaceRole: nil,
isGlobalAdmin: true,
expectSuccess: true,
@@ -433,7 +455,13 @@ func Test_DownloadBackup_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
}
} else {
@@ -441,21 +469,244 @@ func Test_DownloadBackup_PermissionsEnforced(t *testing.T) {
testUserToken = nonMember.Token
}
testResp := test_utils.MakeGetRequest(
testResp := test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file", backup.ID.String()),
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
"Bearer "+testUserToken,
nil,
tt.expectedStatusCode,
)
if !tt.expectSuccess {
if tt.expectSuccess {
var response backups_download.GenerateDownloadTokenResponse
err := json.Unmarshal(testResp.Body, &response)
assert.NoError(t, err)
assert.NotEmpty(t, response.Token)
assert.NotEmpty(t, response.Filename)
assert.Equal(t, backup.ID, response.BackupID)
} else {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
})
}
}
func Test_DownloadBackup_WithValidToken_Success(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
// Generate download token
var tokenResponse backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
"Bearer "+owner.Token,
nil,
http.StatusOK,
&tokenResponse,
)
// Download with token
testResp := test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), tokenResponse.Token),
"",
http.StatusOK,
)
// Verify response
contentDisposition := testResp.Headers.Get("Content-Disposition")
assert.Contains(t, contentDisposition, "attachment")
assert.Contains(t, contentDisposition, tokenResponse.Filename)
}
func Test_DownloadBackup_WithoutToken_Unauthorized(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
// Try to download without token
testResp := test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file", backup.ID.String()),
"",
http.StatusUnauthorized,
)
assert.Contains(t, string(testResp.Body), "download token is required")
}
func Test_DownloadBackup_WithInvalidToken_Unauthorized(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
// Try to download with invalid token
testResp := test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), "invalid-token-xyz"),
"",
http.StatusUnauthorized,
)
assert.Contains(t, string(testResp.Body), "invalid or expired download token")
}
func Test_DownloadBackup_WithExpiredToken_Unauthorized(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
// Get user for token generation
userService := users_services.GetUserService()
user, err := userService.GetUserFromToken(owner.Token)
assert.NoError(t, err)
// Create an expired token directly in the database
expiredToken := createExpiredDownloadToken(backup.ID, user.ID)
// Try to download with expired token
testResp := test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), expiredToken),
"",
http.StatusUnauthorized,
)
assert.Contains(t, string(testResp.Body), "invalid or expired download token")
// Verify audit log was NOT created for failed download
time.Sleep(100 * time.Millisecond)
auditLogService := audit_logs.GetAuditLogService()
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
workspace.ID,
&audit_logs.GetAuditLogsRequest{
Limit: 100,
Offset: 0,
},
)
assert.NoError(t, err)
found := false
for _, log := range auditLogs.AuditLogs {
if strings.Contains(log.Message, "Backup file downloaded") &&
strings.Contains(log.Message, database.Name) {
found = true
break
}
}
assert.False(t, found, "Audit log should NOT be created for failed download with expired token")
}
func Test_DownloadBackup_TokenUsedOnce_CannotReuseToken(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
// Generate download token
var tokenResponse backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
"Bearer "+owner.Token,
nil,
http.StatusOK,
&tokenResponse,
)
// Download with token (first time - should succeed)
test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), tokenResponse.Token),
"",
http.StatusOK,
)
// Try to download again with same token (should fail)
testResp := test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), tokenResponse.Token),
"",
http.StatusUnauthorized,
)
assert.Contains(t, string(testResp.Body), "invalid or expired download token")
}
func Test_DownloadBackup_WithDifferentBackupToken_Unauthorized(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database1 := createTestDatabase("Database 1", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
configService := backups_config.GetBackupConfigService()
config1, err := configService.GetBackupConfigByDbId(database1.ID)
assert.NoError(t, err)
config1.IsBackupsEnabled = true
config1.StorageID = &storage.ID
config1.Storage = storage
_, err = configService.SaveBackupConfig(config1)
assert.NoError(t, err)
backup1 := createTestBackup(database1, owner)
database2 := createTestDatabase("Database 2", workspace.ID, owner.Token, router)
config2, err := configService.GetBackupConfigByDbId(database2.ID)
assert.NoError(t, err)
config2.IsBackupsEnabled = true
config2.StorageID = &storage.ID
config2.Storage = storage
_, err = configService.SaveBackupConfig(config2)
assert.NoError(t, err)
backup2 := createTestBackup(database2, owner)
// Generate token for backup1
var tokenResponse backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/download-token", backup1.ID.String()),
"Bearer "+owner.Token,
nil,
http.StatusOK,
&tokenResponse,
)
// Try to use backup1's token to download backup2 (should fail)
testResp := test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup2.ID.String(), tokenResponse.Token),
"",
http.StatusUnauthorized,
)
assert.Contains(t, string(testResp.Body), "invalid or expired download token")
}
func Test_DownloadBackup_AuditLogWritten(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -463,11 +714,24 @@ func Test_DownloadBackup_AuditLogWritten(t *testing.T) {
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
// Generate download token
var tokenResponse backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
"Bearer "+owner.Token,
nil,
http.StatusOK,
&tokenResponse,
)
// Download with token
test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file", backup.ID.String()),
"Bearer "+owner.Token,
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), tokenResponse.Token),
"",
http.StatusOK,
)
@@ -542,11 +806,28 @@ func Test_DownloadBackup_ProperFilenameForPostgreSQL(t *testing.T) {
backup := createTestBackup(database, owner)
// Generate download token
var tokenResponse backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
"Bearer "+owner.Token,
nil,
http.StatusOK,
&tokenResponse,
)
// Download with token
resp := test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file", backup.ID.String()),
"Bearer "+owner.Token,
fmt.Sprintf(
"/api/v1/backups/%s/file?token=%s",
backup.ID.String(),
tokenResponse.Token,
),
"",
http.StatusOK,
)
@@ -617,22 +898,22 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
_, err = configService.SaveBackupConfig(config)
assert.NoError(t, err)
backup := &Backup{
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusInProgress,
Status: backups_core.BackupStatusInProgress,
BackupSizeMb: 0,
BackupDurationMs: 0,
CreatedAt: time.Now().UTC(),
}
repo := &BackupRepository{}
repo := &backups_core.BackupRepository{}
err = repo.Save(backup)
assert.NoError(t, err)
// Register a cancellable context for the backup
GetBackupService().backupContextManager.RegisterBackup(backup.ID, func() {})
GetBackupService().backupCancelManager.RegisterBackup(backup.ID, func() {})
resp := test_utils.MakePostRequest(
t,
@@ -679,7 +960,13 @@ func createTestDatabase(
token string,
router *gin.Engine,
) *databases.Database {
testDbName := "test_db"
env := config.GetEnv()
port, err := strconv.Atoi(env.TestPostgres16Port)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
}
testDbName := "testdb"
request := databases.Database{
Name: name,
WorkspaceID: &workspaceID,
@@ -687,9 +974,9 @@ func createTestDatabase(
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
},
@@ -752,7 +1039,7 @@ func createTestDatabaseWithBackups(
workspace *workspaces_models.Workspace,
owner *users_dto.SignInResponseDTO,
router *gin.Engine,
) (*databases.Database, *Backup) {
) (*databases.Database, *backups_core.Backup) {
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
@@ -778,7 +1065,7 @@ func createTestDatabaseWithBackups(
func createTestBackup(
database *databases.Database,
owner *users_dto.SignInResponseDTO,
) *Backup {
) *backups_core.Backup {
userService := users_services.GetUserService()
user, err := userService.GetUserFromToken(owner.Token)
if err != nil {
@@ -790,17 +1077,17 @@ func createTestBackup(
panic("No storage found for workspace")
}
backup := &Backup{
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storages[0].ID,
Status: BackupStatusCompleted,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10.5,
BackupDurationMs: 1000,
CreatedAt: time.Now().UTC(),
}
repo := &BackupRepository{}
repo := &backups_core.BackupRepository{}
if err := repo.Save(backup); err != nil {
panic(err)
}
@@ -809,9 +1096,38 @@ func createTestBackup(
dummyContent := []byte("dummy backup content for testing")
reader := strings.NewReader(string(dummyContent))
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
if err := storages[0].SaveFile(context.Background(), encryption.GetFieldEncryptor(), logger, backup.ID, reader); err != nil {
if err := storages[0].SaveFile(
context.Background(),
encryption.GetFieldEncryptor(),
logger,
backup.ID,
reader,
); err != nil {
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
}
return backup
}
func createExpiredDownloadToken(backupID, userID uuid.UUID) string {
tokenService := GetBackupService().downloadTokenService
token, err := tokenService.Generate(backupID, userID)
if err != nil {
panic(fmt.Sprintf("Failed to generate download token: %v", err))
}
// Manually update the token to be expired
repo := &backups_download.DownloadTokenRepository{}
downloadToken, err := repo.FindByToken(token)
if err != nil || downloadToken == nil {
panic(fmt.Sprintf("Failed to find generated token: %v", err))
}
// Set expiration to 10 minutes ago
downloadToken.ExpiresAt = time.Now().UTC().Add(-10 * time.Minute)
if err := repo.Update(downloadToken); err != nil {
panic(fmt.Sprintf("Failed to update token expiration: %v", err))
}
return token
}

View File

@@ -1,4 +1,4 @@
package backups
package backups_core
type BackupStatus string

View File

@@ -1,4 +1,4 @@
package backups
package backups_core
import (
"context"

View File

@@ -1,4 +1,4 @@
package backups
package backups_core
import (
backups_config "databasus-backend/internal/features/backups/config"

View File

@@ -1,4 +1,4 @@
package backups
package backups_core
import (
"databasus-backend/internal/storage"

View File

@@ -1,9 +1,11 @@
package backups
import (
"time"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/backuping"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
"databasus-backend/internal/features/backups/backups/usecases"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -15,47 +17,31 @@ import (
"databasus-backend/internal/util/logger"
)
var backupRepository = &BackupRepository{}
var backupRepository = &backups_core.BackupRepository{}
var backupContextManager = NewBackupContextManager()
var backupCancelManager = backups_cancellation.GetBackupCancelManager()
var backupService = &BackupService{
databases.GetDatabaseService(),
storages.GetStorageService(),
backupRepository,
notifiers.GetNotifierService(),
notifiers.GetNotifierService(),
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
usecases.GetCreateBackupUsecase(),
logger.GetLogger(),
[]BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
backupContextManager,
}
var backupBackgroundService = &BackupBackgroundService{
backupService,
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
time.Now().UTC(),
logger.GetLogger(),
databaseService: databases.GetDatabaseService(),
storageService: storages.GetStorageService(),
backupRepository: backupRepository,
notifierService: notifiers.GetNotifierService(),
notificationSender: notifiers.GetNotifierService(),
backupConfigService: backups_config.GetBackupConfigService(),
secretKeyService: encryption_secrets.GetSecretKeyService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
createBackupUseCase: usecases.GetCreateBackupUsecase(),
logger: logger.GetLogger(),
backupRemoveListeners: []backups_core.BackupRemoveListener{},
workspaceService: workspaces_services.GetWorkspaceService(),
auditLogService: audit_logs.GetAuditLogService(),
backupCancelManager: backupCancelManager,
downloadTokenService: backups_download.GetDownloadTokenService(),
backupSchedulerService: backuping.GetBackupsScheduler(),
}
var backupController = &BackupController{
backupService,
}
func SetupDependencies() {
backups_config.
GetBackupConfigService().
SetDatabaseStorageChangeListener(backupService)
databases.GetDatabaseService().AddDbRemoveListener(backupService)
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
backupService: backupService,
}
func GetBackupService() *BackupService {
@@ -66,6 +52,11 @@ func GetBackupController() *BackupController {
return backupController
}
func GetBackupBackgroundService() *BackupBackgroundService {
return backupBackgroundService
func SetupDependencies() {
backups_config.
GetBackupConfigService().
SetDatabaseStorageChangeListener(backupService)
databases.GetDatabaseService().AddDbRemoveListener(backupService)
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
}

View File

@@ -0,0 +1,34 @@
package backups_download
import (
"context"
"log/slog"
"time"
)
type DownloadTokenBackgroundService struct {
downloadTokenService *DownloadTokenService
logger *slog.Logger
}
func (s *DownloadTokenBackgroundService) Run(ctx context.Context) {
s.logger.Info("Starting download token cleanup background service")
if ctx.Err() != nil {
return
}
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
s.logger.Error("Failed to clean expired download tokens", "error", err)
}
}
}
}

View File

@@ -0,0 +1,25 @@
package backups_download
import (
"databasus-backend/internal/util/logger"
)
var downloadTokenRepository = &DownloadTokenRepository{}
var downloadTokenService = &DownloadTokenService{
downloadTokenRepository,
logger.GetLogger(),
}
var downloadTokenBackgroundService = &DownloadTokenBackgroundService{
downloadTokenService,
logger.GetLogger(),
}
func GetDownloadTokenService() *DownloadTokenService {
return downloadTokenService
}
func GetDownloadTokenBackgroundService() *DownloadTokenBackgroundService {
return downloadTokenBackgroundService
}

View File

@@ -0,0 +1,9 @@
package backups_download
import "github.com/google/uuid"
type GenerateDownloadTokenResponse struct {
Token string `json:"token"`
Filename string `json:"filename"`
BackupID uuid.UUID `json:"backupId"`
}

View File

@@ -0,0 +1,21 @@
package backups_download
import (
"time"
"github.com/google/uuid"
)
type DownloadToken struct {
ID uuid.UUID `json:"id" gorm:"column:id;primaryKey"`
Token string `json:"token" gorm:"column:token;uniqueIndex;not null"`
BackupID uuid.UUID `json:"backupId" gorm:"column:backup_id;not null"`
UserID uuid.UUID `json:"userId" gorm:"column:user_id;not null"`
ExpiresAt time.Time `json:"expiresAt" gorm:"column:expires_at;not null"`
Used bool `json:"used" gorm:"column:used;not null;default:false"`
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at;not null"`
}
func (DownloadToken) TableName() string {
return "download_tokens"
}

View File

@@ -0,0 +1,60 @@
package backups_download
import (
"crypto/rand"
"databasus-backend/internal/storage"
"encoding/base64"
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
type DownloadTokenRepository struct{}
func (r *DownloadTokenRepository) Create(token *DownloadToken) error {
if token.ID == uuid.Nil {
token.ID = uuid.New()
}
if token.CreatedAt.IsZero() {
token.CreatedAt = time.Now().UTC()
}
return storage.GetDb().Create(token).Error
}
func (r *DownloadTokenRepository) FindByToken(token string) (*DownloadToken, error) {
var downloadToken DownloadToken
err := storage.GetDb().
Where("token = ?", token).
First(&downloadToken).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, err
}
return &downloadToken, nil
}
func (r *DownloadTokenRepository) Update(token *DownloadToken) error {
return storage.GetDb().Save(token).Error
}
func (r *DownloadTokenRepository) DeleteExpired(before time.Time) error {
return storage.GetDb().
Where("expires_at < ?", before).
Delete(&DownloadToken{}).Error
}
func GenerateSecureToken() string {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
panic("failed to generate secure random token: " + err.Error())
}
return base64.URLEncoding.EncodeToString(b)
}

View File

@@ -0,0 +1,69 @@
package backups_download
import (
"errors"
"log/slog"
"time"
"github.com/google/uuid"
)
type DownloadTokenService struct {
repository *DownloadTokenRepository
logger *slog.Logger
}
func (s *DownloadTokenService) Generate(backupID, userID uuid.UUID) (string, error) {
token := GenerateSecureToken()
downloadToken := &DownloadToken{
Token: token,
BackupID: backupID,
UserID: userID,
ExpiresAt: time.Now().UTC().Add(5 * time.Minute),
Used: false,
}
if err := s.repository.Create(downloadToken); err != nil {
return "", err
}
s.logger.Info("Generated download token", "backupId", backupID, "userId", userID)
return token, nil
}
func (s *DownloadTokenService) ValidateAndConsume(token string) (*DownloadToken, error) {
dt, err := s.repository.FindByToken(token)
if err != nil {
return nil, err
}
if dt == nil {
return nil, errors.New("invalid token")
}
if dt.Used {
return nil, errors.New("token already used")
}
if time.Now().UTC().After(dt.ExpiresAt) {
return nil, errors.New("token expired")
}
dt.Used = true
if err := s.repository.Update(dt); err != nil {
s.logger.Error("Failed to mark token as used", "error", err)
}
s.logger.Info("Token validated and consumed", "backupId", dt.BackupID)
return dt, nil
}
func (s *DownloadTokenService) CleanExpiredTokens() error {
now := time.Now().UTC()
if err := s.repository.DeleteExpired(now); err != nil {
return err
}
s.logger.Debug("Cleaned expired download tokens")
return nil
}

View File

@@ -1,6 +1,7 @@
package backups
import (
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/encryption"
"io"
)
@@ -12,17 +13,17 @@ type GetBackupsRequest struct {
}
type GetBackupsResponse struct {
Backups []*Backup `json:"backups"`
Total int64 `json:"total"`
Limit int `json:"limit"`
Offset int `json:"offset"`
Backups []*backups_core.Backup `json:"backups"`
Total int64 `json:"total"`
Limit int `json:"limit"`
Offset int `json:"offset"`
}
type decryptionReaderCloser struct {
type DecryptionReaderCloser struct {
*encryption.DecryptionReader
baseReader io.ReadCloser
BaseReader io.ReadCloser
}
func (r *decryptionReaderCloser) Close() error {
return r.baseReader.Close()
func (r *DecryptionReaderCloser) Close() error {
return r.BaseReader.Close()
}

View File

@@ -1,17 +1,17 @@
package backups
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"log/slog"
"slices"
"strings"
"time"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/backuping"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
"databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -28,25 +28,27 @@ import (
type BackupService struct {
databaseService *databases.DatabaseService
storageService *storages.StorageService
backupRepository *BackupRepository
backupRepository *backups_core.BackupRepository
notifierService *notifiers.NotifierService
notificationSender NotificationSender
notificationSender backups_core.NotificationSender
backupConfigService *backups_config.BackupConfigService
secretKeyService *encryption_secrets.SecretKeyService
fieldEncryptor util_encryption.FieldEncryptor
createBackupUseCase CreateBackupUsecase
createBackupUseCase backups_core.CreateBackupUsecase
logger *slog.Logger
backupRemoveListeners []BackupRemoveListener
backupRemoveListeners []backups_core.BackupRemoveListener
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
backupContextManager *BackupContextManager
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
backupCancelManager *backups_cancellation.BackupCancelManager
downloadTokenService *backups_download.DownloadTokenService
backupSchedulerService *backuping.BackupsScheduler
}
func (s *BackupService) AddBackupRemoveListener(listener BackupRemoveListener) {
func (s *BackupService) AddBackupRemoveListener(listener backups_core.BackupRemoveListener) {
s.backupRemoveListeners = append(s.backupRemoveListeners, listener)
}
@@ -89,7 +91,7 @@ func (s *BackupService) MakeBackupWithAuth(
return errors.New("insufficient permissions to create backup for this database")
}
go s.MakeBackup(databaseID, true)
s.backupSchedulerService.StartBackup(databaseID, true)
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Backup manually initiated for database: %s", database.Name),
@@ -173,7 +175,7 @@ func (s *BackupService) DeleteBackup(
return errors.New("insufficient permissions to delete backup for this database")
}
if backup.Status == BackupStatusInProgress {
if backup.Status == backups_core.BackupStatusInProgress {
return errors.New("backup is in progress")
}
@@ -190,260 +192,7 @@ func (s *BackupService) DeleteBackup(
return s.deleteBackup(backup)
}
func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
database, err := s.databaseService.GetDatabaseByID(databaseID)
if err != nil {
s.logger.Error("Failed to get database by ID", "error", err)
return
}
lastBackup, err := s.backupRepository.FindLastByDatabaseID(databaseID)
if err != nil {
s.logger.Error("Failed to find last backup by database ID", "error", err)
return
}
if lastBackup != nil && lastBackup.Status == BackupStatusInProgress {
s.logger.Error("Backup is in progress")
return
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(databaseID)
if err != nil {
s.logger.Error("Failed to get backup config by database ID", "error", err)
return
}
if backupConfig.StorageID == nil {
s.logger.Error("Backup config storage ID is not defined")
return
}
storage, err := s.storageService.GetStorageByID(*backupConfig.StorageID)
if err != nil {
s.logger.Error("Failed to get storage by ID", "error", err)
return
}
backup := &Backup{
DatabaseID: databaseID,
StorageID: storage.ID,
Status: BackupStatusInProgress,
BackupSizeMb: 0,
CreatedAt: time.Now().UTC(),
}
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("Failed to save backup", "error", err)
return
}
start := time.Now().UTC()
backupProgressListener := func(
completedMBs float64,
) {
backup.BackupSizeMb = completedMBs
backup.BackupDurationMs = time.Since(start).Milliseconds()
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("Failed to update backup progress", "error", err)
}
}
ctx, cancel := context.WithCancel(context.Background())
s.backupContextManager.RegisterBackup(backup.ID, cancel)
defer s.backupContextManager.UnregisterBackup(backup.ID)
backupMetadata, err := s.createBackupUseCase.Execute(
ctx,
backup.ID,
backupConfig,
database,
storage,
backupProgressListener,
)
if err != nil {
errMsg := err.Error()
// Check if backup was cancelled (not due to shutdown)
isCancelled := strings.Contains(errMsg, "backup cancelled") ||
strings.Contains(errMsg, "context canceled") ||
errors.Is(err, context.Canceled)
isShutdown := strings.Contains(errMsg, "shutdown")
if isCancelled && !isShutdown {
backup.Status = BackupStatusCanceled
backup.BackupDurationMs = time.Since(start).Milliseconds()
backup.BackupSizeMb = 0
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("Failed to save cancelled backup", "error", err)
}
// Delete partial backup from storage
storage, storageErr := s.storageService.GetStorageByID(backup.StorageID)
if storageErr == nil {
if deleteErr := storage.DeleteFile(s.fieldEncryptor, backup.ID); deleteErr != nil {
s.logger.Error(
"Failed to delete partial backup file",
"backupId",
backup.ID,
"error",
deleteErr,
)
}
}
return
}
backup.FailMessage = &errMsg
backup.Status = BackupStatusFailed
backup.BackupDurationMs = time.Since(start).Milliseconds()
backup.BackupSizeMb = 0
if updateErr := s.databaseService.SetBackupError(databaseID, errMsg); updateErr != nil {
s.logger.Error(
"Failed to update database last backup time",
"databaseId",
databaseID,
"error",
updateErr,
)
}
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("Failed to save backup", "error", err)
}
s.SendBackupNotification(
backupConfig,
backup,
backups_config.NotificationBackupFailed,
&errMsg,
)
return
}
backup.Status = BackupStatusCompleted
backup.BackupDurationMs = time.Since(start).Milliseconds()
// Update backup with encryption metadata if provided
if backupMetadata != nil {
backup.EncryptionSalt = backupMetadata.EncryptionSalt
backup.EncryptionIV = backupMetadata.EncryptionIV
backup.Encryption = backupMetadata.Encryption
}
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("Failed to save backup", "error", err)
return
}
// Update database last backup time
now := time.Now().UTC()
if updateErr := s.databaseService.SetLastBackupTime(databaseID, now); updateErr != nil {
s.logger.Error(
"Failed to update database last backup time",
"databaseId",
databaseID,
"error",
updateErr,
)
}
if backup.Status != BackupStatusCompleted && !isLastTry {
return
}
s.SendBackupNotification(
backupConfig,
backup,
backups_config.NotificationBackupSuccess,
nil,
)
}
func (s *BackupService) SendBackupNotification(
backupConfig *backups_config.BackupConfig,
backup *Backup,
notificationType backups_config.BackupNotificationType,
errorMessage *string,
) {
database, err := s.databaseService.GetDatabaseByID(backupConfig.DatabaseID)
if err != nil {
return
}
workspace, err := s.workspaceService.GetWorkspaceByID(*database.WorkspaceID)
if err != nil {
return
}
for _, notifier := range database.Notifiers {
if !slices.Contains(
backupConfig.SendNotificationsOn,
notificationType,
) {
continue
}
title := ""
switch notificationType {
case backups_config.NotificationBackupFailed:
title = fmt.Sprintf(
"❌ Backup failed for database \"%s\" (workspace \"%s\")",
database.Name,
workspace.Name,
)
case backups_config.NotificationBackupSuccess:
title = fmt.Sprintf(
"✅ Backup completed for database \"%s\" (workspace \"%s\")",
database.Name,
workspace.Name,
)
}
message := ""
if errorMessage != nil {
message = *errorMessage
} else {
// Format size conditionally
var sizeStr string
if backup.BackupSizeMb < 1024 {
sizeStr = fmt.Sprintf("%.2f MB", backup.BackupSizeMb)
} else {
sizeGB := backup.BackupSizeMb / 1024
sizeStr = fmt.Sprintf("%.2f GB", sizeGB)
}
// Format duration as "0m 0s 0ms"
totalMs := backup.BackupDurationMs
minutes := totalMs / (1000 * 60)
seconds := (totalMs % (1000 * 60)) / 1000
durationStr := fmt.Sprintf("%dm %ds", minutes, seconds)
message = fmt.Sprintf(
"Backup completed successfully in %s.\nCompressed backup size: %s",
durationStr,
sizeStr,
)
}
s.notificationSender.SendNotification(
&notifier,
title,
message,
)
}
}
func (s *BackupService) GetBackup(backupID uuid.UUID) (*Backup, error) {
func (s *BackupService) GetBackup(backupID uuid.UUID) (*backups_core.Backup, error) {
return s.backupRepository.FindByID(backupID)
}
@@ -473,11 +222,11 @@ func (s *BackupService) CancelBackup(
return errors.New("insufficient permissions to cancel backup for this database")
}
if backup.Status != BackupStatusInProgress {
if backup.Status != backups_core.BackupStatusInProgress {
return errors.New("backup is not in progress")
}
if err := s.backupContextManager.CancelBackup(backupID); err != nil {
if err := s.backupCancelManager.CancelBackup(backupID); err != nil {
return err
}
@@ -497,7 +246,7 @@ func (s *BackupService) CancelBackup(
func (s *BackupService) GetBackupFile(
user *users_models.User,
backupID uuid.UUID,
) (io.ReadCloser, *Backup, *databases.Database, error) {
) (io.ReadCloser, *backups_core.Backup, *databases.Database, error) {
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
return nil, nil, nil, err
@@ -543,7 +292,7 @@ func (s *BackupService) GetBackupFile(
return reader, backup, database, nil
}
func (s *BackupService) deleteBackup(backup *Backup) error {
func (s *BackupService) deleteBackup(backup *backups_core.Backup) error {
for _, listener := range s.backupRemoveListeners {
if err := listener.OnBeforeBackupRemove(backup); err != nil {
return err
@@ -569,7 +318,7 @@ func (s *BackupService) deleteBackup(backup *Backup) error {
func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
dbBackupsInProgress, err := s.backupRepository.FindByDatabaseIdAndStatus(
databaseID,
BackupStatusInProgress,
backups_core.BackupStatusInProgress,
)
if err != nil {
return err
@@ -678,8 +427,120 @@ func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, erro
s.logger.Info("Returning encrypted backup with decryption", "backupId", backupID)
return &decryptionReaderCloser{
decryptionReader,
fileReader,
return &DecryptionReaderCloser{
DecryptionReader: decryptionReader,
BaseReader: fileReader,
}, nil
}
func (s *BackupService) GenerateDownloadToken(
user *users_models.User,
backupID uuid.UUID,
) (*backups_download.GenerateDownloadTokenResponse, error) {
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
return nil, err
}
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return nil, err
}
if database.WorkspaceID == nil {
return nil, errors.New("cannot download backup for database without workspace")
}
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, user)
if err != nil {
return nil, err
}
if !canAccess {
return nil, errors.New("insufficient permissions to download backup for this database")
}
token, err := s.downloadTokenService.Generate(backupID, user.ID)
if err != nil {
return nil, err
}
filename := s.generateBackupFilename(backup, database)
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Download token generated for backup of database: %s", database.Name),
&user.ID,
database.WorkspaceID,
)
return &backups_download.GenerateDownloadTokenResponse{
Token: token,
Filename: filename,
BackupID: backupID,
}, nil
}
func (s *BackupService) ValidateDownloadToken(
token string,
) (*backups_download.DownloadToken, error) {
return s.downloadTokenService.ValidateAndConsume(token)
}
func (s *BackupService) GetBackupFileWithoutAuth(
backupID uuid.UUID,
) (io.ReadCloser, *backups_core.Backup, *databases.Database, error) {
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
return nil, nil, nil, err
}
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return nil, nil, nil, err
}
reader, err := s.getBackupReader(backupID)
if err != nil {
return nil, nil, nil, err
}
return reader, backup, database, nil
}
func (s *BackupService) WriteAuditLogForDownload(
userID uuid.UUID,
backup *backups_core.Backup,
database *databases.Database,
) {
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Backup file downloaded for database: %s (ID: %s)",
database.Name,
backup.ID.String(),
),
&userID,
database.WorkspaceID,
)
}
func (s *BackupService) generateBackupFilename(
backup *backups_core.Backup,
database *databases.Database,
) string {
timestamp := backup.CreatedAt.Format("2006-01-02_15-04-05")
safeName := sanitizeFilename(database.Name)
extension := s.getBackupExtension(database.Type)
return fmt.Sprintf("%s_backup_%s%s", safeName, timestamp, extension)
}
func (s *BackupService) getBackupExtension(dbType databases.DatabaseType) string {
switch dbType {
case databases.DatabaseTypeMysql, databases.DatabaseTypeMariadb:
return ".sql.zst"
case databases.DatabaseTypePostgres:
return ".dump"
case databases.DatabaseTypeMongodb:
return ".archive"
default:
return ".backup"
}
}

View File

@@ -4,6 +4,7 @@ import (
"testing"
"time"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
@@ -14,13 +15,19 @@ import (
)
func CreateTestRouter() *gin.Engine {
return workspaces_testing.CreateTestRouter(
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
databases.GetDatabaseController(),
backups_config.GetBackupConfigController(),
GetBackupController(),
)
// Register public routes (no auth required - token-based)
v1 := router.Group("/api/v1")
GetBackupController().RegisterPublicRoutes(v1)
return router
}
// WaitForBackupCompletion waits for a new backup to be created and completed (or failed)
@@ -52,9 +59,9 @@ func WaitForBackupCompletion(
newestBackup := backups[0]
t.Logf("WaitForBackupCompletion: newest backup status: %s", newestBackup.Status)
if newestBackup.Status == BackupStatusCompleted ||
newestBackup.Status == BackupStatusFailed ||
newestBackup.Status == BackupStatusCanceled {
if newestBackup.Status == backups_core.BackupStatusCompleted ||
newestBackup.Status == backups_core.BackupStatusFailed ||
newestBackup.Status == backups_core.BackupStatusCanceled {
t.Logf(
"WaitForBackupCompletion: backup finished with status %s",
newestBackup.Status,

View File

@@ -107,16 +107,22 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
"--user=" + mdb.Username,
"--single-transaction",
"--routines",
"--triggers",
"--events",
"--quick",
"--verbose",
}
if mdb.HasPrivilege("TRIGGER") {
args = append(args, "--triggers")
}
if mdb.HasPrivilege("EVENT") {
args = append(args, "--events")
}
args = append(args, "--compress")
if mdb.IsHttps {
args = append(args, "--ssl")
args = append(args, "--skip-ssl-verify-server-cert")
}
if mdb.Database != nil && *mdb.Database != "" {
@@ -260,11 +266,24 @@ func (uc *CreateMariadbBackupUsecase) createTempMyCnfFile(
mdbConfig *mariadbtypes.MariadbDatabase,
password string,
) (string, error) {
tempDir, err := os.MkdirTemp(config.GetEnv().TempFolder, "mycnf_"+uuid.New().String())
tempFolder := config.GetEnv().TempFolder
if err := os.MkdirAll(tempFolder, 0700); err != nil {
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
}
if err := os.Chmod(tempFolder, 0700); err != nil {
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
}
tempDir, err := os.MkdirTemp(tempFolder, "mycnf_"+uuid.New().String())
if err != nil {
return "", fmt.Errorf("failed to create temp directory: %w", err)
}
if err := os.Chmod(tempDir, 0700); err != nil {
_ = os.RemoveAll(tempDir)
return "", fmt.Errorf("failed to set temp directory permissions: %w", err)
}
myCnfFile := filepath.Join(tempDir, ".my.cnf")
content := fmt.Sprintf(`[client]
@@ -282,6 +301,7 @@ port=%d
err = os.WriteFile(myCnfFile, []byte(content), 0600)
if err != nil {
_ = os.RemoveAll(tempDir)
return "", fmt.Errorf("failed to write .my.cnf: %w", err)
}

View File

@@ -105,13 +105,18 @@ func (uc *CreateMysqlBackupUsecase) buildMysqldumpArgs(my *mysqltypes.MysqlDatab
"--user=" + my.Username,
"--single-transaction",
"--routines",
"--triggers",
"--events",
"--set-gtid-purged=OFF",
"--quick",
"--verbose",
}
if my.HasPrivilege("TRIGGER") {
args = append(args, "--triggers")
}
if my.HasPrivilege("EVENT") {
args = append(args, "--events")
}
args = append(args, uc.getNetworkCompressionArgs(my.Version)...)
if my.IsHttps {
@@ -275,11 +280,24 @@ func (uc *CreateMysqlBackupUsecase) createTempMyCnfFile(
myConfig *mysqltypes.MysqlDatabase,
password string,
) (string, error) {
tempDir, err := os.MkdirTemp(config.GetEnv().TempFolder, "mycnf_"+uuid.New().String())
tempFolder := config.GetEnv().TempFolder
if err := os.MkdirAll(tempFolder, 0700); err != nil {
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
}
if err := os.Chmod(tempFolder, 0700); err != nil {
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
}
tempDir, err := os.MkdirTemp(tempFolder, "mycnf_"+uuid.New().String())
if err != nil {
return "", fmt.Errorf("failed to create temp directory: %w", err)
}
if err := os.Chmod(tempDir, 0700); err != nil {
_ = os.RemoveAll(tempDir)
return "", fmt.Errorf("failed to set temp directory permissions: %w", err)
}
myCnfFile := filepath.Join(tempDir, ".my.cnf")
content := fmt.Sprintf(`[client]
@@ -295,6 +313,7 @@ port=%d
err = os.WriteFile(myCnfFile, []byte(content), 0600)
if err != nil {
_ = os.RemoveAll(tempDir)
return "", fmt.Errorf("failed to write .my.cnf: %w", err)
}

View File

@@ -135,7 +135,14 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
cmd := exec.CommandContext(ctx, pgBin, args...)
uc.logger.Info("Executing PostgreSQL backup command", "command", cmd.String())
if err := uc.setupPgEnvironment(cmd, pgpassFile, db.Postgresql.IsHttps, password, db.Postgresql.CpuCount, pgBin); err != nil {
if err := uc.setupPgEnvironment(
cmd,
pgpassFile,
db.Postgresql.IsHttps,
password,
db.Postgresql.CpuCount,
pgBin,
); err != nil {
return nil, err
}
@@ -750,14 +757,28 @@ func (uc *CreatePostgresqlBackupUsecase) createTempPgpassFile(
escapedPassword,
)
tempDir, err := os.MkdirTemp("", "pgpass")
tempFolder := config.GetEnv().TempFolder
if err := os.MkdirAll(tempFolder, 0700); err != nil {
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
}
if err := os.Chmod(tempFolder, 0700); err != nil {
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
}
tempDir, err := os.MkdirTemp(tempFolder, "pgpass_"+uuid.New().String())
if err != nil {
return "", fmt.Errorf("failed to create temporary directory: %w", err)
}
if err := os.Chmod(tempDir, 0700); err != nil {
_ = os.RemoveAll(tempDir)
return "", fmt.Errorf("failed to set temporary directory permissions: %w", err)
}
pgpassFile := filepath.Join(tempDir, ".pgpass")
err = os.WriteFile(pgpassFile, []byte(pgpassContent), 0600)
if err != nil {
_ = os.RemoveAll(tempDir)
return "", fmt.Errorf("failed to write temporary .pgpass file: %w", err)
}

View File

@@ -2,13 +2,16 @@ package backups_config
import (
"encoding/json"
"fmt"
"net/http"
"strconv"
"testing"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/intervals"
@@ -94,7 +97,13 @@ func Test_SaveBackupConfig_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
}
@@ -241,7 +250,13 @@ func Test_GetBackupConfigByDbID_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -1434,7 +1449,13 @@ func createTestDatabaseViaAPI(
token string,
router *gin.Engine,
) *databases.Database {
testDbName := "test_db"
env := config.GetEnv()
port, err := strconv.Atoi(env.TestPostgres16Port)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
}
testDbName := "testdb"
request := databases.Database{
WorkspaceID: &workspaceID,
Name: name,
@@ -1442,9 +1463,9 @@ func createTestDatabaseViaAPI(
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
},
@@ -1459,7 +1480,9 @@ func createTestDatabaseViaAPI(
)
if w.Code != http.StatusCreated {
panic("Failed to create database")
panic(
fmt.Sprintf("Failed to create database. Status: %d, Body: %s", w.Code, w.Body.String()),
)
}
var database databases.Database

View File

@@ -392,13 +392,13 @@ func (c *DatabaseController) IsUserReadOnly(ctx *gin.Context) {
return
}
isReadOnly, err := c.databaseService.IsUserReadOnly(user, &request)
isReadOnly, privileges, err := c.databaseService.IsUserReadOnly(user, &request)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, IsReadOnlyResponse{IsReadOnly: isReadOnly})
ctx.JSON(http.StatusOK, IsReadOnlyResponse{IsReadOnly: isReadOnly, Privileges: privileges})
}
// CreateReadOnlyUser

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
"testing"
@@ -11,6 +12,7 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
"databasus-backend/internal/features/databases/databases/mariadb"
"databasus-backend/internal/features/databases/databases/mongodb"
"databasus-backend/internal/features/databases/databases/postgresql"
@@ -32,6 +34,71 @@ func createTestRouter() *gin.Engine {
return router
}
func getTestPostgresConfig() *postgresql.PostgresqlDatabase {
env := config.GetEnv()
port, err := strconv.Atoi(env.TestPostgres16Port)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
}
testDbName := "testdb"
return &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
}
}
func getTestMariadbConfig() *mariadb.MariadbDatabase {
env := config.GetEnv()
portStr := env.TestMariadb1011Port
if portStr == "" {
portStr = "33111"
}
port, err := strconv.Atoi(portStr)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_MARIADB_1011_PORT: %v", err))
}
testDbName := "testdb"
return &mariadb.MariadbDatabase{
Version: tools.MariadbVersion1011,
Host: "localhost",
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
}
}
func getTestMongodbConfig() *mongodb.MongodbDatabase {
env := config.GetEnv()
portStr := env.TestMongodb70Port
if portStr == "" {
portStr = "27070"
}
port, err := strconv.Atoi(portStr)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_MONGODB_70_PORT: %v", err))
}
return &mongodb.MongodbDatabase{
Version: tools.MongodbVersion7,
Host: "localhost",
Port: port,
Username: "root",
Password: "rootpassword",
Database: "testdb",
AuthDatabase: "admin",
IsHttps: false,
CpuCount: 1,
}
}
func Test_CreateDatabase_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
@@ -84,24 +151,21 @@ func Test_CreateDatabase_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
}
testDbName := "test_db"
request := Database{
Name: "Test Database",
WorkspaceID: &workspace.ID,
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
CpuCount: 1,
},
Postgresql: getTestPostgresConfig(),
}
var response Database
@@ -132,20 +196,11 @@ func Test_CreateDatabase_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testin
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testDbName := "test_db"
request := Database{
Name: "Test Database",
WorkspaceID: &workspace.ID,
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
CpuCount: 1,
},
Postgresql: getTestPostgresConfig(),
}
testResp := test_utils.MakePostRequest(
@@ -214,7 +269,13 @@ func Test_UpdateDatabase_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
}
@@ -316,7 +377,13 @@ func Test_DeleteDatabase_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
}
@@ -381,7 +448,13 @@ func Test_GetDatabase_PermissionsEnforced(t *testing.T) {
testUser = admin.Token
} else if tt.userRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.userRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.userRole,
owner.Token,
router,
)
testUser = member.Token
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -605,7 +678,13 @@ func Test_CopyDatabase_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
}
@@ -737,7 +816,13 @@ func createTestDatabaseViaAPI(
token string,
router *gin.Engine,
) *Database {
testDbName := "test_db"
env := config.GetEnv()
port, err := strconv.Atoi(env.TestPostgres16Port)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
}
testDbName := "testdb"
request := Database{
Name: name,
WorkspaceID: &workspaceID,
@@ -745,9 +830,9 @@ func createTestDatabaseViaAPI(
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
},
@@ -780,21 +865,14 @@ func Test_CreateDatabase_PasswordIsEncryptedInDB(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
testDbName := "test_db"
plainPassword := "my-super-secret-password-123"
pgConfig := getTestPostgresConfig()
plainPassword := "testpassword"
pgConfig.Password = plainPassword
request := Database{
Name: "Test Database",
WorkspaceID: &workspace.ID,
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: plainPassword,
Database: &testDbName,
CpuCount: 1,
},
Postgresql: pgConfig,
}
var createdDatabase Database
@@ -854,38 +932,23 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
name: "PostgreSQL Database",
databaseType: DatabaseTypePostgres,
createDatabase: func(workspaceID uuid.UUID) *Database {
testDbName := "test_db"
pgConfig := getTestPostgresConfig()
return &Database{
WorkspaceID: &workspaceID,
Name: "Test PostgreSQL Database",
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "original-password-secret",
Database: &testDbName,
CpuCount: 1,
},
Postgresql: pgConfig,
}
},
updateDatabase: func(workspaceID uuid.UUID, databaseID uuid.UUID) *Database {
testDbName := "updated_test_db"
pgConfig := getTestPostgresConfig()
pgConfig.Password = ""
return &Database{
ID: databaseID,
WorkspaceID: &workspaceID,
Name: "Updated PostgreSQL Database",
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion17,
Host: "updated-host",
Port: 5433,
Username: "updated_user",
Password: "",
Database: &testDbName,
CpuCount: 1,
},
Postgresql: pgConfig,
}
},
verifySensitiveData: func(t *testing.T, database *Database) {
@@ -895,7 +958,7 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
encryptor := encryption.GetFieldEncryptor()
decrypted, err := encryptor.Decrypt(database.ID, database.Postgresql.Password)
assert.NoError(t, err)
assert.Equal(t, "original-password-secret", decrypted)
assert.Equal(t, "testpassword", decrypted)
},
verifyHiddenData: func(t *testing.T, database *Database) {
assert.Equal(t, "", database.Postgresql.Password)
@@ -905,36 +968,23 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
name: "MariaDB Database",
databaseType: DatabaseTypeMariadb,
createDatabase: func(workspaceID uuid.UUID) *Database {
testDbName := "test_db"
mariaConfig := getTestMariadbConfig()
return &Database{
WorkspaceID: &workspaceID,
Name: "Test MariaDB Database",
Type: DatabaseTypeMariadb,
Mariadb: &mariadb.MariadbDatabase{
Version: tools.MariadbVersion1011,
Host: "localhost",
Port: 3306,
Username: "root",
Password: "original-password-secret",
Database: &testDbName,
},
Mariadb: mariaConfig,
}
},
updateDatabase: func(workspaceID uuid.UUID, databaseID uuid.UUID) *Database {
testDbName := "updated_test_db"
mariaConfig := getTestMariadbConfig()
mariaConfig.Password = ""
return &Database{
ID: databaseID,
WorkspaceID: &workspaceID,
Name: "Updated MariaDB Database",
Type: DatabaseTypeMariadb,
Mariadb: &mariadb.MariadbDatabase{
Version: tools.MariadbVersion114,
Host: "updated-host",
Port: 3307,
Username: "updated_user",
Password: "",
Database: &testDbName,
},
Mariadb: mariaConfig,
}
},
verifySensitiveData: func(t *testing.T, database *Database) {
@@ -944,7 +994,7 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
encryptor := encryption.GetFieldEncryptor()
decrypted, err := encryptor.Decrypt(database.ID, database.Mariadb.Password)
assert.NoError(t, err)
assert.Equal(t, "original-password-secret", decrypted)
assert.Equal(t, "testpassword", decrypted)
},
verifyHiddenData: func(t *testing.T, database *Database) {
assert.Equal(t, "", database.Mariadb.Password)
@@ -954,40 +1004,23 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
name: "MongoDB Database",
databaseType: DatabaseTypeMongodb,
createDatabase: func(workspaceID uuid.UUID) *Database {
mongoConfig := getTestMongodbConfig()
return &Database{
WorkspaceID: &workspaceID,
Name: "Test MongoDB Database",
Type: DatabaseTypeMongodb,
Mongodb: &mongodb.MongodbDatabase{
Version: tools.MongodbVersion7,
Host: "localhost",
Port: 27017,
Username: "root",
Password: "original-password-secret",
Database: "test_db",
AuthDatabase: "admin",
IsHttps: false,
CpuCount: 1,
},
Mongodb: mongoConfig,
}
},
updateDatabase: func(workspaceID uuid.UUID, databaseID uuid.UUID) *Database {
mongoConfig := getTestMongodbConfig()
mongoConfig.Password = ""
return &Database{
ID: databaseID,
WorkspaceID: &workspaceID,
Name: "Updated MongoDB Database",
Type: DatabaseTypeMongodb,
Mongodb: &mongodb.MongodbDatabase{
Version: tools.MongodbVersion8,
Host: "updated-host",
Port: 27018,
Username: "updated_user",
Password: "",
Database: "updated_test_db",
AuthDatabase: "admin",
IsHttps: false,
CpuCount: 1,
},
Mongodb: mongoConfig,
}
},
verifySensitiveData: func(t *testing.T, database *Database) {
@@ -997,7 +1030,7 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
encryptor := encryption.GetFieldEncryptor()
decrypted, err := encryptor.Decrypt(database.ID, database.Mongodb.Password)
assert.NoError(t, err)
assert.Equal(t, "original-password-secret", decrypted)
assert.Equal(t, "rootpassword", decrypted)
},
verifyHiddenData: func(t *testing.T, database *Database) {
assert.Equal(t, "", database.Mongodb.Password)

View File

@@ -2,18 +2,20 @@ package mariadb
import (
"context"
"crypto/tls"
"database/sql"
"errors"
"fmt"
"log/slog"
"regexp"
"sort"
"strings"
"time"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/tools"
_ "github.com/go-sql-driver/mysql"
"github.com/go-sql-driver/mysql"
"github.com/google/uuid"
)
@@ -23,12 +25,13 @@ type MariadbDatabase struct {
Version tools.MariadbVersion `json:"version" gorm:"type:text;not null"`
Host string `json:"host" gorm:"type:text;not null"`
Port int `json:"port" gorm:"type:int;not null"`
Username string `json:"username" gorm:"type:text;not null"`
Password string `json:"password" gorm:"type:text;not null"`
Database *string `json:"database" gorm:"type:text"`
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
Host string `json:"host" gorm:"type:text;not null"`
Port int `json:"port" gorm:"type:int;not null"`
Username string `json:"username" gorm:"type:text;not null"`
Password string `json:"password" gorm:"type:text;not null"`
Database *string `json:"database" gorm:"type:text"`
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
}
func (m *MariadbDatabase) TableName() string {
@@ -94,6 +97,16 @@ func (m *MariadbDatabase) TestConnection(
}
m.Version = detectedVersion
privileges, err := detectPrivileges(ctx, db, *m.Database)
if err != nil {
return err
}
m.Privileges = privileges
if err := checkBackupPermissions(m.Privileges); err != nil {
return err
}
return nil
}
@@ -111,6 +124,7 @@ func (m *MariadbDatabase) Update(incoming *MariadbDatabase) {
m.Username = incoming.Username
m.Database = incoming.Database
m.IsHttps = incoming.IsHttps
m.Privileges = incoming.Privileges
if incoming.Password != "" {
m.Password = incoming.Password
@@ -131,15 +145,48 @@ func (m *MariadbDatabase) EncryptSensitiveFields(
return nil
}
func (m *MariadbDatabase) PopulateVersionIfEmpty(
func (m *MariadbDatabase) PopulateDbData(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error {
if m.Version != "" {
if m.Database == nil || *m.Database == "" {
return nil
}
return m.PopulateVersion(logger, encryptor, databaseID)
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
if err != nil {
return fmt.Errorf("failed to decrypt password: %w", err)
}
dsn := m.buildDSN(password, *m.Database)
db, err := sql.Open("mysql", dsn)
if err != nil {
return fmt.Errorf("failed to connect to database: %w", err)
}
defer func() {
if closeErr := db.Close(); closeErr != nil {
logger.Error("Failed to close connection", "error", closeErr)
}
}()
detectedVersion, err := detectMariadbVersion(ctx, db)
if err != nil {
return err
}
m.Version = detectedVersion
privileges, err := detectPrivileges(ctx, db, *m.Database)
if err != nil {
return err
}
m.Privileges = privileges
return nil
}
func (m *MariadbDatabase) PopulateVersion(
@@ -175,8 +222,8 @@ func (m *MariadbDatabase) PopulateVersion(
if err != nil {
return err
}
m.Version = detectedVersion
return nil
}
@@ -185,17 +232,17 @@ func (m *MariadbDatabase) IsUserReadOnly(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) (bool, error) {
) (bool, []string, error) {
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
if err != nil {
return false, fmt.Errorf("failed to decrypt password: %w", err)
return false, nil, fmt.Errorf("failed to decrypt password: %w", err)
}
dsn := m.buildDSN(password, *m.Database)
db, err := sql.Open("mysql", dsn)
if err != nil {
return false, fmt.Errorf("failed to connect to database: %w", err)
return false, nil, fmt.Errorf("failed to connect to database: %w", err)
}
defer func() {
if closeErr := db.Close(); closeErr != nil {
@@ -205,33 +252,44 @@ func (m *MariadbDatabase) IsUserReadOnly(
rows, err := db.QueryContext(ctx, "SHOW GRANTS FOR CURRENT_USER()")
if err != nil {
return false, fmt.Errorf("failed to check grants: %w", err)
return false, nil, fmt.Errorf("failed to check grants: %w", err)
}
defer func() { _ = rows.Close() }()
writePrivileges := []string{
"INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER",
"INDEX", "GRANT OPTION", "ALL PRIVILEGES", "SUPER",
"EXECUTE", "FILE", "RELOAD", "SHUTDOWN", "CREATE ROUTINE",
"ALTER ROUTINE", "CREATE USER",
"CREATE TABLESPACE", "DELETE HISTORY", "REFERENCES",
}
detectedPrivileges := make(map[string]bool)
for rows.Next() {
var grant string
if err := rows.Scan(&grant); err != nil {
return false, fmt.Errorf("failed to scan grant: %w", err)
return false, nil, fmt.Errorf("failed to scan grant: %w", err)
}
for _, priv := range writePrivileges {
if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
return false, nil
detectedPrivileges[priv] = true
}
}
}
if err := rows.Err(); err != nil {
return false, fmt.Errorf("error iterating grants: %w", err)
return false, nil, fmt.Errorf("error iterating grants: %w", err)
}
return true, nil
privileges := make([]string, 0, len(detectedPrivileges))
for priv := range detectedPrivileges {
privileges = append(privileges, priv)
}
isReadOnly := len(privileges) == 0
return isReadOnly, privileges, nil
}
func (m *MariadbDatabase) CreateReadOnlyUser(
@@ -261,7 +319,7 @@ func (m *MariadbDatabase) CreateReadOnlyUser(
for attempt := range maxRetries {
// MariaDB 5.5 has a 16-character username limit, use shorter prefix
newUsername := fmt.Sprintf("pgs-%s", uuid.New().String()[:8])
newPassword := uuid.New().String()
newPassword := encryption.GenerateComplexPassword()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
@@ -326,10 +384,31 @@ func (m *MariadbDatabase) CreateReadOnlyUser(
return "", "", errors.New("failed to generate unique username after 3 attempts")
}
func (m *MariadbDatabase) HasPrivilege(priv string) bool {
return HasPrivilege(m.Privileges, priv)
}
func HasPrivilege(privileges, priv string) bool {
for _, p := range strings.Split(privileges, ",") {
if strings.TrimSpace(p) == priv {
return true
}
}
return false
}
func (m *MariadbDatabase) buildDSN(password string, database string) string {
tlsConfig := "false"
if m.IsHttps {
tlsConfig = "true"
err := mysql.RegisterTLSConfig("mariadb-skip-verify", &tls.Config{
InsecureSkipVerify: true,
})
if err != nil {
// Config might already be registered, which is fine
_ = err
}
tlsConfig = "mariadb-skip-verify"
}
return fmt.Sprintf(
@@ -420,6 +499,104 @@ func mapMariadb11xVersion(minor string) (tools.MariadbVersion, error) {
}
}
// detectPrivileges detects backup-related privileges and returns them as comma-separated string
func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string, error) {
rows, err := db.QueryContext(ctx, "SHOW GRANTS FOR CURRENT_USER()")
if err != nil {
return "", fmt.Errorf("failed to check grants: %w", err)
}
defer func() { _ = rows.Close() }()
backupPrivileges := []string{
"SELECT", "SHOW VIEW", "LOCK TABLES", "TRIGGER", "EVENT",
}
detectedPrivileges := make(map[string]bool)
hasProcess := false
hasAllPrivileges := false
dbPatternStr := fmt.Sprintf(
`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
regexp.QuoteMeta(database),
)
dbPattern := regexp.MustCompile(dbPatternStr)
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)
allPrivilegesPattern := regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`)
for rows.Next() {
var grant string
if err := rows.Scan(&grant); err != nil {
return "", fmt.Errorf("failed to scan grant: %w", err)
}
isRelevantGrant := globalPattern.MatchString(grant) || dbPattern.MatchString(grant)
if allPrivilegesPattern.MatchString(grant) && isRelevantGrant {
hasAllPrivileges = true
}
if isRelevantGrant {
for _, priv := range backupPrivileges {
privPattern := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(priv) + `\b`)
if privPattern.MatchString(grant) {
detectedPrivileges[priv] = true
}
}
}
if globalPattern.MatchString(grant) {
processPattern := regexp.MustCompile(`(?i)\bPROCESS\b`)
if processPattern.MatchString(grant) {
hasProcess = true
}
}
}
if err := rows.Err(); err != nil {
return "", fmt.Errorf("error iterating grants: %w", err)
}
if hasAllPrivileges {
for _, priv := range backupPrivileges {
detectedPrivileges[priv] = true
}
hasProcess = true
}
privileges := make([]string, 0, len(detectedPrivileges)+1)
for priv := range detectedPrivileges {
privileges = append(privileges, priv)
}
if hasProcess {
privileges = append(privileges, "PROCESS")
}
sort.Strings(privileges)
return strings.Join(privileges, ","), nil
}
// checkBackupPermissions verifies the user has sufficient privileges for mariadb-dump backup.
// Required: SELECT, SHOW VIEW
func checkBackupPermissions(privileges string) error {
requiredPrivileges := []string{"SELECT", "SHOW VIEW"}
var missingPrivileges []string
for _, priv := range requiredPrivileges {
if !HasPrivilege(privileges, priv) {
missingPrivileges = append(missingPrivileges, priv)
}
}
if len(missingPrivileges) > 0 {
return fmt.Errorf(
"insufficient permissions for backup. Missing: %s. Required: SELECT, SHOW VIEW",
strings.Join(missingPrivileges, ", "),
)
}
return nil
}
func decryptPasswordIfNeeded(
password string,
encryptor encryption.FieldEncryptor,

View File

@@ -0,0 +1,757 @@
package mariadb
import (
"context"
"fmt"
"log/slog"
"os"
"strconv"
"strings"
"testing"
_ "github.com/go-sql-driver/mysql"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
"databasus-backend/internal/util/tools"
)
func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MariadbVersion
port string
}{
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMariadbContainer(t, tc.port, tc.version)
defer container.DB.Close()
_, err := container.DB.Exec(`DROP TABLE IF EXISTS permission_test`)
assert.NoError(t, err)
_, err = container.DB.Exec(`CREATE TABLE permission_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)`)
assert.NoError(t, err)
_, err = container.DB.Exec(`INSERT INTO permission_test (data) VALUES ('test1')`)
assert.NoError(t, err)
limitedUsername := fmt.Sprintf("limited_%s", uuid.New().String()[:8])
limitedPassword := "limitedpassword123"
_, err = container.DB.Exec(fmt.Sprintf(
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
limitedUsername,
limitedPassword,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
"GRANT SELECT ON `%s`.* TO '%s'@'%%'",
container.Database,
limitedUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec("FLUSH PRIVILEGES")
assert.NoError(t, err)
defer dropUserSafe(container.DB, limitedUsername)
mariadbModel := &MariadbDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Username: limitedUsername,
Password: limitedPassword,
Database: &container.Database,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mariadbModel.TestConnection(logger, nil, uuid.New())
assert.Error(t, err)
assert.Contains(t, err.Error(), "insufficient permissions")
})
}
}
func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MariadbVersion
port string
}{
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMariadbContainer(t, tc.port, tc.version)
defer container.DB.Close()
_, err := container.DB.Exec(`DROP TABLE IF EXISTS backup_test`)
assert.NoError(t, err)
_, err = container.DB.Exec(`CREATE TABLE backup_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)`)
assert.NoError(t, err)
_, err = container.DB.Exec(`INSERT INTO backup_test (data) VALUES ('test1')`)
assert.NoError(t, err)
backupUsername := fmt.Sprintf("backup_%s", uuid.New().String()[:8])
backupPassword := "backuppassword123"
_, err = container.DB.Exec(fmt.Sprintf(
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
backupUsername,
backupPassword,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
"GRANT SELECT, SHOW VIEW, LOCK TABLES, TRIGGER, EVENT ON `%s`.* TO '%s'@'%%'",
container.Database,
backupUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
"GRANT PROCESS ON *.* TO '%s'@'%%'",
backupUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec("FLUSH PRIVILEGES")
assert.NoError(t, err)
defer dropUserSafe(container.DB, backupUsername)
mariadbModel := &MariadbDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Username: backupUsername,
Password: backupPassword,
Database: &container.Database,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mariadbModel.TestConnection(logger, nil, uuid.New())
assert.NoError(t, err)
})
}
}
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MariadbVersion
port string
}{
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMariadbContainer(t, tc.port, tc.version)
defer container.DB.Close()
mariadbModel := createMariadbModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
isReadOnly, privileges, err := mariadbModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.False(t, isReadOnly, "Root user should not be read-only")
assert.NotEmpty(t, privileges, "Root user should have privileges")
})
}
}
func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
env := config.GetEnv()
container := connectToMariadbContainer(t, env.TestMariadb1011Port, tools.MariadbVersion1011)
defer container.DB.Close()
_, err := container.DB.Exec(`DROP TABLE IF EXISTS readonly_check_test`)
assert.NoError(t, err)
_, err = container.DB.Exec(`CREATE TABLE readonly_check_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)`)
assert.NoError(t, err)
_, err = container.DB.Exec(`INSERT INTO readonly_check_test (data) VALUES ('test1')`)
assert.NoError(t, err)
mariadbModel := createMariadbModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
readOnlyModel := &MariadbDatabase{
Version: mariadbModel.Version,
Host: mariadbModel.Host,
Port: mariadbModel.Port,
Username: username,
Password: password,
Database: mariadbModel.Database,
IsHttps: false,
}
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.True(t, isReadOnly, "Read-only user should be read-only")
assert.Empty(t, privileges, "Read-only user should have no write privileges")
dropUserSafe(container.DB, username)
}
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MariadbVersion
port string
}{
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMariadbContainer(t, tc.port, tc.version)
defer container.DB.Close()
_, err := container.DB.Exec(`DROP TABLE IF EXISTS readonly_test`)
assert.NoError(t, err)
_, err = container.DB.Exec(`DROP TABLE IF EXISTS hack_table`)
assert.NoError(t, err)
_, err = container.DB.Exec(`DROP TABLE IF EXISTS future_table`)
assert.NoError(t, err)
_, err = container.DB.Exec(`
CREATE TABLE readonly_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)
`)
assert.NoError(t, err)
_, err = container.DB.Exec(
`INSERT INTO readonly_test (data) VALUES ('test1'), ('test2')`,
)
assert.NoError(t, err)
mariadbModel := createMariadbModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.NotEmpty(t, username)
assert.NotEmpty(t, password)
assert.True(t, strings.HasPrefix(username, "pgs-"))
if err != nil {
return
}
readOnlyModel := &MariadbDatabase{
Version: mariadbModel.Version,
Host: mariadbModel.Host,
Port: mariadbModel.Port,
Username: username,
Password: password,
Database: mariadbModel.Database,
IsHttps: false,
}
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
ctx,
logger,
nil,
uuid.New(),
)
assert.NoError(t, err)
assert.True(t, isReadOnly, "Created user should be read-only")
assert.Empty(t, privileges, "Read-only user should have no write privileges")
readOnlyDSN := fmt.Sprintf(
"%s:%s@tcp(%s:%d)/%s?parseTime=true",
username,
password,
container.Host,
container.Port,
container.Database,
)
readOnlyConn, err := sqlx.Connect("mysql", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var count int
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM readonly_test")
assert.NoError(t, err)
assert.Equal(t, 2, count)
_, err = readOnlyConn.Exec("INSERT INTO readonly_test (data) VALUES ('should-fail')")
assert.Error(t, err)
assert.Contains(t, strings.ToLower(err.Error()), "denied")
_, err = readOnlyConn.Exec("UPDATE readonly_test SET data = 'hacked' WHERE id = 1")
assert.Error(t, err)
assert.Contains(t, strings.ToLower(err.Error()), "denied")
_, err = readOnlyConn.Exec("DELETE FROM readonly_test WHERE id = 1")
assert.Error(t, err)
assert.Contains(t, strings.ToLower(err.Error()), "denied")
_, err = readOnlyConn.Exec("CREATE TABLE hack_table (id INT)")
assert.Error(t, err)
assert.Contains(t, strings.ToLower(err.Error()), "denied")
dropUserSafe(container.DB, username)
})
}
}
func Test_ReadOnlyUser_FutureTables_NoSelectPermission(t *testing.T) {
env := config.GetEnv()
container := connectToMariadbContainer(t, env.TestMariadb1011Port, tools.MariadbVersion1011)
defer container.DB.Close()
mariadbModel := createMariadbModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
_, err = container.DB.Exec(`DROP TABLE IF EXISTS future_table`)
assert.NoError(t, err)
_, err = container.DB.Exec(`
CREATE TABLE future_table (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)
`)
assert.NoError(t, err)
_, err = container.DB.Exec(`INSERT INTO future_table (data) VALUES ('future_data')`)
assert.NoError(t, err)
readOnlyDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
username, password, container.Host, container.Port, container.Database)
readOnlyConn, err := sqlx.Connect("mysql", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var data string
err = readOnlyConn.Get(&data, "SELECT data FROM future_table LIMIT 1")
assert.NoError(t, err)
assert.Equal(t, "future_data", data)
dropUserSafe(container.DB, username)
}
func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) {
env := config.GetEnv()
container := connectToMariadbContainer(t, env.TestMariadb1011Port, tools.MariadbVersion1011)
defer container.DB.Close()
dashDbName := "test-db-with-dash"
_, err := container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dashDbName))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", dashDbName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dashDbName))
}()
dashDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
container.Username, container.Password, container.Host, container.Port, dashDbName)
dashDB, err := sqlx.Connect("mysql", dashDSN)
assert.NoError(t, err)
defer dashDB.Close()
_, err = dashDB.Exec(`
CREATE TABLE dash_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)
`)
assert.NoError(t, err)
_, err = dashDB.Exec(`INSERT INTO dash_test (data) VALUES ('test1'), ('test2')`)
assert.NoError(t, err)
mariadbModel := &MariadbDatabase{
Version: tools.MariadbVersion1011,
Host: container.Host,
Port: container.Port,
Username: container.Username,
Password: container.Password,
Database: &dashDbName,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.NotEmpty(t, username)
assert.NotEmpty(t, password)
assert.True(t, strings.HasPrefix(username, "pgs-"))
readOnlyDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
username, password, container.Host, container.Port, dashDbName)
readOnlyConn, err := sqlx.Connect("mysql", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var count int
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM dash_test")
assert.NoError(t, err)
assert.Equal(t, 2, count)
_, err = readOnlyConn.Exec("INSERT INTO dash_test (data) VALUES ('should-fail')")
assert.Error(t, err)
assert.Contains(t, strings.ToLower(err.Error()), "denied")
dropUserSafe(dashDB, username)
}
func Test_ReadOnlyUser_CannotDropOrAlterTables(t *testing.T) {
env := config.GetEnv()
container := connectToMariadbContainer(t, env.TestMariadb1011Port, tools.MariadbVersion1011)
defer container.DB.Close()
_, err := container.DB.Exec(`DROP TABLE IF EXISTS drop_test`)
assert.NoError(t, err)
_, err = container.DB.Exec(`
CREATE TABLE drop_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)
`)
assert.NoError(t, err)
_, err = container.DB.Exec(`INSERT INTO drop_test (data) VALUES ('test1')`)
assert.NoError(t, err)
mariadbModel := createMariadbModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
readOnlyDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
username, password, container.Host, container.Port, container.Database)
readOnlyConn, err := sqlx.Connect("mysql", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
_, err = readOnlyConn.Exec("DROP TABLE drop_test")
assert.Error(t, err)
assert.Contains(t, strings.ToLower(err.Error()), "denied")
_, err = readOnlyConn.Exec("ALTER TABLE drop_test ADD COLUMN new_col VARCHAR(100)")
assert.Error(t, err)
assert.Contains(t, strings.ToLower(err.Error()), "denied")
_, err = readOnlyConn.Exec("TRUNCATE TABLE drop_test")
assert.Error(t, err)
assert.Contains(t, strings.ToLower(err.Error()), "denied")
dropUserSafe(container.DB, username)
}
func Test_TestConnection_DatabaseSpecificPrivilegesWithGlobalProcess_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MariadbVersion
port string
}{
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMariadbContainer(t, tc.port, tc.version)
defer container.DB.Close()
_, err := container.DB.Exec(`DROP TABLE IF EXISTS privilege_test`)
assert.NoError(t, err)
_, err = container.DB.Exec(`CREATE TABLE privilege_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)`)
assert.NoError(t, err)
_, err = container.DB.Exec(`INSERT INTO privilege_test (data) VALUES ('test1')`)
assert.NoError(t, err)
specificUsername := fmt.Sprintf("spec_%s", uuid.New().String()[:8])
specificPassword := "specificpass123"
_, err = container.DB.Exec(fmt.Sprintf(
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
specificUsername,
specificPassword,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
"GRANT SELECT, SHOW VIEW ON %s.* TO '%s'@'%%'",
container.Database,
specificUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
"GRANT PROCESS ON *.* TO '%s'@'%%'",
specificUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec("FLUSH PRIVILEGES")
assert.NoError(t, err)
defer dropUserSafe(container.DB, specificUsername)
mariadbModel := &MariadbDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Username: specificUsername,
Password: specificPassword,
Database: &container.Database,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mariadbModel.TestConnection(logger, nil, uuid.New())
assert.NoError(t, err)
})
}
}
func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
env := config.GetEnv()
container := connectToMariadbContainer(t, env.TestMariadb1011Port, tools.MariadbVersion1011)
defer container.DB.Close()
underscoreDbName := "test_db_name"
_, err := container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
}()
underscoreDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
container.Username, container.Password, container.Host, container.Port, underscoreDbName)
underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
assert.NoError(t, err)
defer underscoreDB.Close()
_, err = underscoreDB.Exec(`
CREATE TABLE underscore_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)
`)
assert.NoError(t, err)
_, err = underscoreDB.Exec(`INSERT INTO underscore_test (data) VALUES ('test1')`)
assert.NoError(t, err)
underscoreUsername := fmt.Sprintf("under%s", uuid.New().String()[:8])
underscorePassword := "underscorepass123"
_, err = underscoreDB.Exec(fmt.Sprintf(
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
underscoreUsername,
underscorePassword,
))
assert.NoError(t, err)
_, err = underscoreDB.Exec(fmt.Sprintf(
"GRANT SELECT, SHOW VIEW ON `%s`.* TO '%s'@'%%'",
underscoreDbName,
underscoreUsername,
))
assert.NoError(t, err)
_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
assert.NoError(t, err)
defer dropUserSafe(underscoreDB, underscoreUsername)
mariadbModel := &MariadbDatabase{
Version: tools.MariadbVersion1011,
Host: container.Host,
Port: container.Port,
Username: underscoreUsername,
Password: underscorePassword,
Database: &underscoreDbName,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mariadbModel.TestConnection(logger, nil, uuid.New())
assert.NoError(t, err)
}
type MariadbContainer struct {
Host string
Port int
Username string
Password string
Database string
Version tools.MariadbVersion
DB *sqlx.DB
}
func connectToMariadbContainer(
t *testing.T,
port string,
version tools.MariadbVersion,
) *MariadbContainer {
if port == "" {
t.Skipf("MariaDB port not configured for version %s", version)
}
dbName := "testdb"
host := "127.0.0.1"
username := "root"
password := "rootpassword"
portInt, err := strconv.Atoi(port)
assert.NoError(t, err)
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
username, password, host, portInt, dbName)
db, err := sqlx.Connect("mysql", dsn)
if err != nil {
t.Skipf("Failed to connect to MariaDB %s: %v", version, err)
}
return &MariadbContainer{
Host: host,
Port: portInt,
Username: username,
Password: password,
Database: dbName,
Version: version,
DB: db,
}
}
func createMariadbModel(container *MariadbContainer) *MariadbDatabase {
return &MariadbDatabase{
Version: container.Version,
Host: container.Host,
Port: container.Port,
Username: container.Username,
Password: container.Password,
Database: &container.Database,
IsHttps: false,
}
}
func dropUserSafe(db *sqlx.DB, username string) {
_, _ = db.Exec(fmt.Sprintf("DROP USER '%s'@'%%'", username))
}

View File

@@ -1,387 +0,0 @@
package mariadb
import (
"context"
"fmt"
"log/slog"
"os"
"strconv"
"strings"
"testing"
_ "github.com/go-sql-driver/mysql"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
"databasus-backend/internal/util/tools"
)
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MariadbVersion
port string
}{
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMariadbContainer(t, tc.port, tc.version)
defer container.DB.Close()
mariadbModel := createMariadbModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
isReadOnly, err := mariadbModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.False(t, isReadOnly, "Root user should not be read-only")
})
}
}
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MariadbVersion
port string
}{
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMariadbContainer(t, tc.port, tc.version)
defer container.DB.Close()
_, err := container.DB.Exec(`DROP TABLE IF EXISTS readonly_test`)
assert.NoError(t, err)
_, err = container.DB.Exec(`DROP TABLE IF EXISTS hack_table`)
assert.NoError(t, err)
_, err = container.DB.Exec(`DROP TABLE IF EXISTS future_table`)
assert.NoError(t, err)
_, err = container.DB.Exec(`
CREATE TABLE readonly_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)
`)
assert.NoError(t, err)
_, err = container.DB.Exec(
`INSERT INTO readonly_test (data) VALUES ('test1'), ('test2')`,
)
assert.NoError(t, err)
mariadbModel := createMariadbModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.NotEmpty(t, username)
assert.NotEmpty(t, password)
assert.True(t, strings.HasPrefix(username, "pgs-"))
if err != nil {
return
}
readOnlyModel := &MariadbDatabase{
Version: mariadbModel.Version,
Host: mariadbModel.Host,
Port: mariadbModel.Port,
Username: username,
Password: password,
Database: mariadbModel.Database,
IsHttps: false,
}
isReadOnly, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.True(t, isReadOnly, "Created user should be read-only")
readOnlyDSN := fmt.Sprintf(
"%s:%s@tcp(%s:%d)/%s?parseTime=true",
username,
password,
container.Host,
container.Port,
container.Database,
)
readOnlyConn, err := sqlx.Connect("mysql", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var count int
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM readonly_test")
assert.NoError(t, err)
assert.Equal(t, 2, count)
_, err = readOnlyConn.Exec("INSERT INTO readonly_test (data) VALUES ('should-fail')")
assert.Error(t, err)
assert.Contains(t, strings.ToLower(err.Error()), "denied")
_, err = readOnlyConn.Exec("UPDATE readonly_test SET data = 'hacked' WHERE id = 1")
assert.Error(t, err)
assert.Contains(t, strings.ToLower(err.Error()), "denied")
_, err = readOnlyConn.Exec("DELETE FROM readonly_test WHERE id = 1")
assert.Error(t, err)
assert.Contains(t, strings.ToLower(err.Error()), "denied")
_, err = readOnlyConn.Exec("CREATE TABLE hack_table (id INT)")
assert.Error(t, err)
assert.Contains(t, strings.ToLower(err.Error()), "denied")
dropUserSafe(container.DB, username)
})
}
}
func Test_ReadOnlyUser_FutureTables_NoSelectPermission(t *testing.T) {
env := config.GetEnv()
container := connectToMariadbContainer(t, env.TestMariadb1011Port, tools.MariadbVersion1011)
defer container.DB.Close()
mariadbModel := createMariadbModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
_, err = container.DB.Exec(`DROP TABLE IF EXISTS future_table`)
assert.NoError(t, err)
_, err = container.DB.Exec(`
CREATE TABLE future_table (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)
`)
assert.NoError(t, err)
_, err = container.DB.Exec(`INSERT INTO future_table (data) VALUES ('future_data')`)
assert.NoError(t, err)
readOnlyDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
username, password, container.Host, container.Port, container.Database)
readOnlyConn, err := sqlx.Connect("mysql", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var data string
err = readOnlyConn.Get(&data, "SELECT data FROM future_table LIMIT 1")
assert.NoError(t, err)
assert.Equal(t, "future_data", data)
dropUserSafe(container.DB, username)
}
func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) {
env := config.GetEnv()
container := connectToMariadbContainer(t, env.TestMariadb1011Port, tools.MariadbVersion1011)
defer container.DB.Close()
dashDbName := "test-db-with-dash"
_, err := container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dashDbName))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", dashDbName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dashDbName))
}()
dashDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
container.Username, container.Password, container.Host, container.Port, dashDbName)
dashDB, err := sqlx.Connect("mysql", dashDSN)
assert.NoError(t, err)
defer dashDB.Close()
_, err = dashDB.Exec(`
CREATE TABLE dash_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)
`)
assert.NoError(t, err)
_, err = dashDB.Exec(`INSERT INTO dash_test (data) VALUES ('test1'), ('test2')`)
assert.NoError(t, err)
mariadbModel := &MariadbDatabase{
Version: tools.MariadbVersion1011,
Host: container.Host,
Port: container.Port,
Username: container.Username,
Password: container.Password,
Database: &dashDbName,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.NotEmpty(t, username)
assert.NotEmpty(t, password)
assert.True(t, strings.HasPrefix(username, "pgs-"))
readOnlyDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
username, password, container.Host, container.Port, dashDbName)
readOnlyConn, err := sqlx.Connect("mysql", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var count int
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM dash_test")
assert.NoError(t, err)
assert.Equal(t, 2, count)
_, err = readOnlyConn.Exec("INSERT INTO dash_test (data) VALUES ('should-fail')")
assert.Error(t, err)
assert.Contains(t, strings.ToLower(err.Error()), "denied")
dropUserSafe(dashDB, username)
}
func Test_ReadOnlyUser_CannotDropOrAlterTables(t *testing.T) {
env := config.GetEnv()
container := connectToMariadbContainer(t, env.TestMariadb1011Port, tools.MariadbVersion1011)
defer container.DB.Close()
_, err := container.DB.Exec(`DROP TABLE IF EXISTS drop_test`)
assert.NoError(t, err)
_, err = container.DB.Exec(`
CREATE TABLE drop_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)
`)
assert.NoError(t, err)
_, err = container.DB.Exec(`INSERT INTO drop_test (data) VALUES ('test1')`)
assert.NoError(t, err)
mariadbModel := createMariadbModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
readOnlyDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
username, password, container.Host, container.Port, container.Database)
readOnlyConn, err := sqlx.Connect("mysql", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
_, err = readOnlyConn.Exec("DROP TABLE drop_test")
assert.Error(t, err)
assert.Contains(t, strings.ToLower(err.Error()), "denied")
_, err = readOnlyConn.Exec("ALTER TABLE drop_test ADD COLUMN new_col VARCHAR(100)")
assert.Error(t, err)
assert.Contains(t, strings.ToLower(err.Error()), "denied")
_, err = readOnlyConn.Exec("TRUNCATE TABLE drop_test")
assert.Error(t, err)
assert.Contains(t, strings.ToLower(err.Error()), "denied")
dropUserSafe(container.DB, username)
}
type MariadbContainer struct {
Host string
Port int
Username string
Password string
Database string
Version tools.MariadbVersion
DB *sqlx.DB
}
func connectToMariadbContainer(
t *testing.T,
port string,
version tools.MariadbVersion,
) *MariadbContainer {
if port == "" {
t.Skipf("MariaDB port not configured for version %s", version)
}
dbName := "testdb"
host := "127.0.0.1"
username := "root"
password := "rootpassword"
portInt, err := strconv.Atoi(port)
assert.NoError(t, err)
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
username, password, host, portInt, dbName)
db, err := sqlx.Connect("mysql", dsn)
if err != nil {
t.Skipf("Failed to connect to MariaDB %s: %v", version, err)
}
return &MariadbContainer{
Host: host,
Port: portInt,
Username: username,
Password: password,
Database: dbName,
Version: version,
DB: db,
}
}
func createMariadbModel(container *MariadbContainer) *MariadbDatabase {
return &MariadbDatabase{
Version: container.Version,
Host: container.Host,
Port: container.Port,
Username: container.Username,
Password: container.Password,
Database: &container.Database,
IsHttps: false,
}
}
func dropUserSafe(db *sqlx.DB, username string) {
// MariaDB 5.5 doesn't support DROP USER IF EXISTS, so we ignore errors
_, _ = db.Exec(fmt.Sprintf("DROP USER '%s'@'%%'", username))
}

View File

@@ -5,7 +5,9 @@ import (
"errors"
"fmt"
"log/slog"
"net/url"
"regexp"
"strings"
"time"
"databasus-backend/internal/util/encryption"
@@ -95,6 +97,16 @@ func (m *MongodbDatabase) TestConnection(
}
m.Version = detectedVersion
if err := checkBackupPermissions(
ctx,
client,
m.Username,
m.Database,
m.AuthDatabase,
); err != nil {
return err
}
return nil
}
@@ -134,14 +146,11 @@ func (m *MongodbDatabase) EncryptSensitiveFields(
return nil
}
func (m *MongodbDatabase) PopulateVersionIfEmpty(
func (m *MongodbDatabase) PopulateDbData(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error {
if m.Version != "" {
return nil
}
return m.PopulateVersion(logger, encryptor, databaseID)
}
@@ -185,10 +194,10 @@ func (m *MongodbDatabase) IsUserReadOnly(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) (bool, error) {
) (bool, []string, error) {
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
if err != nil {
return false, fmt.Errorf("failed to decrypt password: %w", err)
return false, nil, fmt.Errorf("failed to decrypt password: %w", err)
}
uri := m.buildConnectionURI(password)
@@ -196,7 +205,7 @@ func (m *MongodbDatabase) IsUserReadOnly(
clientOptions := options.Client().ApplyURI(uri)
client, err := mongo.Connect(ctx, clientOptions)
if err != nil {
return false, fmt.Errorf("failed to connect to database: %w", err)
return false, nil, fmt.Errorf("failed to connect to database: %w", err)
}
defer func() {
if disconnectErr := client.Disconnect(ctx); disconnectErr != nil {
@@ -218,44 +227,153 @@ func (m *MongodbDatabase) IsUserReadOnly(
}},
}).Decode(&result)
if err != nil {
return false, fmt.Errorf("failed to get user info: %w", err)
return false, nil, fmt.Errorf("failed to get user info: %w", err)
}
writeRoles := []string{
"readWrite", "readWriteAnyDatabase", "dbAdmin", "dbAdminAnyDatabase",
"userAdmin", "userAdminAnyDatabase", "clusterAdmin", "root",
"dbOwner", "backup", "restore",
writeRoles := map[string]bool{
"readWrite": true,
"readWriteAnyDatabase": true,
"dbAdmin": true,
"dbAdminAnyDatabase": true,
"userAdmin": true,
"userAdminAnyDatabase": true,
"clusterAdmin": true,
"clusterManager": true,
"hostManager": true,
"root": true,
"dbOwner": true,
"restore": true,
"__system": true,
}
// Roles that are read-only for our backup purposes
// The "backup" role has insert/update on mms.backup collection but is needed for mongodump
readOnlyRoles := map[string]bool{
"read": true,
"backup": true,
}
writeActions := map[string]bool{
"insert": true,
"update": true,
"remove": true,
"createCollection": true,
"dropCollection": true,
"createIndex": true,
"dropIndex": true,
"convertToCapped": true,
"dropDatabase": true,
"renameCollection": true,
"createUser": true,
"dropUser": true,
"updateUser": true,
"grantRole": true,
"revokeRole": true,
"dropRole": true,
"createRole": true,
"updateRole": true,
"enableSharding": true,
"shardCollection": true,
"addShard": true,
"removeShard": true,
"shutdown": true,
"replSetReconfig": true,
"replSetStateChange": true,
}
var detectedRoles []string
users, ok := result["users"].(bson.A)
if !ok || len(users) == 0 {
return true, nil
return true, detectedRoles, nil
}
user, ok := users[0].(bson.M)
if !ok {
return true, nil
return true, detectedRoles, nil
}
roles, ok := user["roles"].(bson.A)
if !ok {
return true, nil
return true, detectedRoles, nil
}
// Collect all role names and check for write roles
for _, roleDoc := range roles {
role, ok := roleDoc.(bson.M)
if !ok {
continue
}
roleName, _ := role["role"].(string)
for _, writeRole := range writeRoles {
if roleName == writeRole {
return false, nil
if roleName != "" {
detectedRoles = append(detectedRoles, roleName)
}
}
// Check if any detected role is a write role
for _, roleName := range detectedRoles {
if writeRoles[roleName] {
return false, detectedRoles, nil
}
}
// If all roles are known read-only roles (read, backup), skip inherited privilege check
allRolesReadOnly := true
for _, roleName := range detectedRoles {
if !readOnlyRoles[roleName] {
allRolesReadOnly = false
break
}
}
if allRolesReadOnly && len(detectedRoles) > 0 {
return true, detectedRoles, nil
}
// Check inherited privileges for custom roles
var privResult bson.M
err = adminDB.RunCommand(ctx, bson.D{
{Key: "usersInfo", Value: bson.D{
{Key: "user", Value: m.Username},
{Key: "db", Value: authDB},
}},
{Key: "showPrivileges", Value: true},
}).Decode(&privResult)
if err != nil {
return false, nil, fmt.Errorf("failed to get user privileges: %w", err)
}
privUsers, ok := privResult["users"].(bson.A)
if !ok || len(privUsers) == 0 {
return true, detectedRoles, nil
}
privUser, ok := privUsers[0].(bson.M)
if !ok {
return true, detectedRoles, nil
}
// Check inheritedPrivileges for write actions
inheritedPrivileges, ok := privUser["inheritedPrivileges"].(bson.A)
if ok {
for _, privDoc := range inheritedPrivileges {
priv, ok := privDoc.(bson.M)
if !ok {
continue
}
actions, ok := priv["actions"].(bson.A)
if !ok {
continue
}
for _, action := range actions {
actionStr, ok := action.(string)
if ok && writeActions[actionStr] {
return false, detectedRoles, nil
}
}
}
}
return true, nil
return true, detectedRoles, nil
}
func (m *MongodbDatabase) CreateReadOnlyUser(
@@ -290,7 +408,7 @@ func (m *MongodbDatabase) CreateReadOnlyUser(
maxRetries := 3
for attempt := range maxRetries {
newUsername := fmt.Sprintf("databasus-%s", uuid.New().String()[:8])
newPassword := uuid.New().String()
newPassword := encryption.GenerateComplexPassword()
adminDB := client.Database(authDB)
err = adminDB.RunCommand(ctx, bson.D{
@@ -332,20 +450,20 @@ func (m *MongodbDatabase) buildConnectionURI(password string) string {
authDB = "admin"
}
tlsOption := "false"
tlsParams := ""
if m.IsHttps {
tlsOption = "true"
tlsParams = "&tls=true&tlsInsecure=true"
}
return fmt.Sprintf(
"mongodb://%s:%s@%s:%d/%s?authSource=%s&tls=%s&connectTimeoutMS=15000",
m.Username,
password,
"mongodb://%s:%s@%s:%d/%s?authSource=%s&connectTimeoutMS=15000%s",
url.QueryEscape(m.Username),
url.QueryEscape(password),
m.Host,
m.Port,
m.Database,
authDB,
tlsOption,
tlsParams,
)
}
@@ -356,19 +474,19 @@ func (m *MongodbDatabase) BuildMongodumpURI(password string) string {
authDB = "admin"
}
tlsOption := "false"
tlsParams := ""
if m.IsHttps {
tlsOption = "true"
tlsParams = "&tls=true&tlsInsecure=true"
}
return fmt.Sprintf(
"mongodb://%s:%s@%s:%d/?authSource=%s&tls=%s&connectTimeoutMS=15000",
m.Username,
password,
"mongodb://%s:%s@%s:%d/?authSource=%s&connectTimeoutMS=15000%s",
url.QueryEscape(m.Username),
url.QueryEscape(password),
m.Host,
m.Port,
authDB,
tlsOption,
tlsParams,
)
}
@@ -413,6 +531,128 @@ func detectMongodbVersion(ctx context.Context, client *mongo.Client) (tools.Mong
}
}
// checkBackupPermissions verifies the user has sufficient privileges for mongodump backup.
// Required: 'read' role on target database OR 'backup' role on admin OR 'readAnyDatabase' role.
func checkBackupPermissions(
ctx context.Context,
client *mongo.Client,
username, database, authDatabase string,
) error {
authDB := authDatabase
if authDB == "" {
authDB = "admin"
}
adminDB := client.Database(authDB)
var result bson.M
err := adminDB.RunCommand(ctx, bson.D{
{Key: "usersInfo", Value: bson.D{
{Key: "user", Value: username},
{Key: "db", Value: authDB},
}},
{Key: "showPrivileges", Value: true},
}).Decode(&result)
if err != nil {
return fmt.Errorf("failed to get user info: %w", err)
}
users, ok := result["users"].(bson.A)
if !ok || len(users) == 0 {
return errors.New("insufficient permissions for backup. User not found")
}
user, ok := users[0].(bson.M)
if !ok {
return errors.New("insufficient permissions for backup. Could not parse user info")
}
// Check roles for backup permissions
roles, ok := user["roles"].(bson.A)
if !ok {
return errors.New("insufficient permissions for backup. No roles assigned")
}
backupRoles := map[string]bool{
"backup": true,
"root": true,
"readAnyDatabase": true,
"dbOwner": true,
"__system": true,
"clusterAdmin": true,
"readWriteAnyDatabase": true,
}
var userRoles []string
hasBackupRole := false
hasReadOnTargetDB := false
for _, roleDoc := range roles {
role, ok := roleDoc.(bson.M)
if !ok {
continue
}
roleName, _ := role["role"].(string)
roleDB, _ := role["db"].(string)
if roleName != "" {
userRoles = append(userRoles, roleName)
}
if backupRoles[roleName] {
hasBackupRole = true
}
if roleName == "read" && (roleDB == database || roleDB == "") {
hasReadOnTargetDB = true
}
if roleName == "readWrite" && (roleDB == database || roleDB == "") {
hasReadOnTargetDB = true
}
}
if hasBackupRole || hasReadOnTargetDB {
return nil
}
// Check inherited privileges for 'find' action on target database
inheritedPrivileges, ok := user["inheritedPrivileges"].(bson.A)
if ok {
for _, privDoc := range inheritedPrivileges {
priv, ok := privDoc.(bson.M)
if !ok {
continue
}
resource, ok := priv["resource"].(bson.M)
if !ok {
continue
}
resourceDB, _ := resource["db"].(string)
resourceCluster, _ := resource["cluster"].(bool)
isTargetDB := resourceDB == database || resourceDB == "" || resourceCluster
actions, ok := priv["actions"].(bson.A)
if !ok {
continue
}
for _, action := range actions {
actionStr, ok := action.(string)
if ok && actionStr == "find" && isTargetDB {
return nil
}
}
}
}
return fmt.Errorf(
"insufficient permissions for backup. Current roles: %s. Required: 'read' role on database '%s' OR 'backup' role on admin OR 'readAnyDatabase' role",
strings.Join(userRoles, ", "),
database,
)
}
func decryptPasswordIfNeeded(
password string,
encryptor encryption.FieldEncryptor,

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"log/slog"
"net/url"
"os"
"strconv"
"strings"
@@ -19,6 +20,138 @@ import (
"databasus-backend/internal/util/tools"
)
func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MongodbVersion
port string
}{
{"MongoDB 4.0", tools.MongodbVersion4, env.TestMongodb40Port},
{"MongoDB 4.2", tools.MongodbVersion4, env.TestMongodb42Port},
{"MongoDB 4.4", tools.MongodbVersion4, env.TestMongodb44Port},
{"MongoDB 5.0", tools.MongodbVersion5, env.TestMongodb50Port},
{"MongoDB 6.0", tools.MongodbVersion6, env.TestMongodb60Port},
{"MongoDB 7.0", tools.MongodbVersion7, env.TestMongodb70Port},
{"MongoDB 8.2", tools.MongodbVersion8, env.TestMongodb82Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMongodbContainer(t, tc.port, tc.version)
defer container.Client.Disconnect(context.Background())
ctx := context.Background()
db := container.Client.Database(container.Database)
_ = db.Collection("permission_test").Drop(ctx)
_, err := db.Collection("permission_test").InsertOne(ctx, bson.M{"data": "test1"})
assert.NoError(t, err)
limitedUsername := fmt.Sprintf("limited_%s", uuid.New().String()[:8])
limitedPassword := "limitedpassword123"
adminDB := container.Client.Database(container.AuthDatabase)
err = adminDB.RunCommand(ctx, bson.D{
{Key: "createUser", Value: limitedUsername},
{Key: "pwd", Value: limitedPassword},
{Key: "roles", Value: bson.A{}},
}).Err()
assert.NoError(t, err)
defer dropUserSafe(container.Client, limitedUsername, container.AuthDatabase)
mongodbModel := &MongodbDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Username: limitedUsername,
Password: limitedPassword,
Database: container.Database,
AuthDatabase: container.AuthDatabase,
IsHttps: false,
CpuCount: 1,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mongodbModel.TestConnection(logger, nil, uuid.New())
assert.Error(t, err)
assert.Contains(t, err.Error(), "insufficient permissions")
})
}
}
func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MongodbVersion
port string
}{
{"MongoDB 4.0", tools.MongodbVersion4, env.TestMongodb40Port},
{"MongoDB 4.2", tools.MongodbVersion4, env.TestMongodb42Port},
{"MongoDB 4.4", tools.MongodbVersion4, env.TestMongodb44Port},
{"MongoDB 5.0", tools.MongodbVersion5, env.TestMongodb50Port},
{"MongoDB 6.0", tools.MongodbVersion6, env.TestMongodb60Port},
{"MongoDB 7.0", tools.MongodbVersion7, env.TestMongodb70Port},
{"MongoDB 8.2", tools.MongodbVersion8, env.TestMongodb82Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMongodbContainer(t, tc.port, tc.version)
defer container.Client.Disconnect(context.Background())
ctx := context.Background()
db := container.Client.Database(container.Database)
_ = db.Collection("backup_test").Drop(ctx)
_, err := db.Collection("backup_test").InsertOne(ctx, bson.M{"data": "test1"})
assert.NoError(t, err)
backupUsername := fmt.Sprintf("backup_%s", uuid.New().String()[:8])
backupPassword := "backuppassword123"
adminDB := container.Client.Database(container.AuthDatabase)
err = adminDB.RunCommand(ctx, bson.D{
{Key: "createUser", Value: backupUsername},
{Key: "pwd", Value: backupPassword},
{Key: "roles", Value: bson.A{
bson.D{
{Key: "role", Value: "read"},
{Key: "db", Value: container.Database},
},
}},
}).Err()
assert.NoError(t, err)
defer dropUserSafe(container.Client, backupUsername, container.AuthDatabase)
mongodbModel := &MongodbDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Username: backupUsername,
Password: backupPassword,
Database: container.Database,
AuthDatabase: container.AuthDatabase,
IsHttps: false,
CpuCount: 1,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mongodbModel.TestConnection(logger, nil, uuid.New())
assert.NoError(t, err)
})
}
}
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
env := config.GetEnv()
cases := []struct {
@@ -46,13 +179,52 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
isReadOnly, err := mongodbModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
isReadOnly, roles, err := mongodbModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.False(t, isReadOnly, "Root user should not be read-only")
assert.NotEmpty(t, roles, "Root user should have roles")
})
}
}
func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
env := config.GetEnv()
container := connectToMongodbContainer(t, env.TestMongodb70Port, tools.MongodbVersion7)
defer container.Client.Disconnect(context.Background())
ctx := context.Background()
db := container.Client.Database(container.Database)
_ = db.Collection("readonly_check_test").Drop(ctx)
_, err := db.Collection("readonly_check_test").InsertOne(ctx, bson.M{"data": "test1"})
assert.NoError(t, err)
mongodbModel := createMongodbModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
username, password, err := mongodbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
readOnlyModel := &MongodbDatabase{
Version: mongodbModel.Version,
Host: mongodbModel.Host,
Port: mongodbModel.Port,
Username: username,
Password: password,
Database: mongodbModel.Database,
AuthDatabase: mongodbModel.AuthDatabase,
IsHttps: false,
CpuCount: 1,
}
isReadOnly, roles, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.True(t, isReadOnly, "Read-only user should be read-only")
assert.NotEmpty(t, roles, "Read-only user should have roles (read, backup)")
dropUserSafe(container.Client, username, container.AuthDatabase)
}
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
env := config.GetEnv()
cases := []struct {
@@ -271,6 +443,7 @@ func createMongodbModel(container *MongodbContainer) *MongodbDatabase {
Database: container.Database,
AuthDatabase: container.AuthDatabase,
IsHttps: false,
CpuCount: 1,
}
}
@@ -281,7 +454,8 @@ func connectWithCredentials(
) *mongo.Client {
uri := fmt.Sprintf(
"mongodb://%s:%s@%s:%d/%s?authSource=%s",
username, password, container.Host, container.Port,
url.QueryEscape(username), url.QueryEscape(password),
container.Host, container.Port,
container.Database, container.AuthDatabase,
)

View File

@@ -2,17 +2,20 @@ package mysql
import (
"context"
"crypto/tls"
"database/sql"
"errors"
"fmt"
"log/slog"
"regexp"
"sort"
"strings"
"time"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/tools"
_ "github.com/go-sql-driver/mysql"
"github.com/go-sql-driver/mysql"
"github.com/google/uuid"
)
@@ -22,12 +25,13 @@ type MysqlDatabase struct {
Version tools.MysqlVersion `json:"version" gorm:"type:text;not null"`
Host string `json:"host" gorm:"type:text;not null"`
Port int `json:"port" gorm:"type:int;not null"`
Username string `json:"username" gorm:"type:text;not null"`
Password string `json:"password" gorm:"type:text;not null"`
Database *string `json:"database" gorm:"type:text"`
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
Host string `json:"host" gorm:"type:text;not null"`
Port int `json:"port" gorm:"type:int;not null"`
Username string `json:"username" gorm:"type:text;not null"`
Password string `json:"password" gorm:"type:text;not null"`
Database *string `json:"database" gorm:"type:text"`
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
}
func (m *MysqlDatabase) TableName() string {
@@ -93,6 +97,16 @@ func (m *MysqlDatabase) TestConnection(
}
m.Version = detectedVersion
privileges, err := detectPrivileges(ctx, db, *m.Database)
if err != nil {
return err
}
m.Privileges = privileges
if err := checkBackupPermissions(m.Privileges); err != nil {
return err
}
return nil
}
@@ -110,6 +124,7 @@ func (m *MysqlDatabase) Update(incoming *MysqlDatabase) {
m.Username = incoming.Username
m.Database = incoming.Database
m.IsHttps = incoming.IsHttps
m.Privileges = incoming.Privileges
if incoming.Password != "" {
m.Password = incoming.Password
@@ -130,15 +145,48 @@ func (m *MysqlDatabase) EncryptSensitiveFields(
return nil
}
func (m *MysqlDatabase) PopulateVersionIfEmpty(
func (m *MysqlDatabase) PopulateDbData(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error {
if m.Version != "" {
if m.Database == nil || *m.Database == "" {
return nil
}
return m.PopulateVersion(logger, encryptor, databaseID)
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
if err != nil {
return fmt.Errorf("failed to decrypt password: %w", err)
}
dsn := m.buildDSN(password, *m.Database)
db, err := sql.Open("mysql", dsn)
if err != nil {
return fmt.Errorf("failed to connect to database: %w", err)
}
defer func() {
if closeErr := db.Close(); closeErr != nil {
logger.Error("Failed to close connection", "error", closeErr)
}
}()
detectedVersion, err := detectMysqlVersion(ctx, db)
if err != nil {
return err
}
m.Version = detectedVersion
privileges, err := detectPrivileges(ctx, db, *m.Database)
if err != nil {
return err
}
m.Privileges = privileges
return nil
}
func (m *MysqlDatabase) PopulateVersion(
@@ -174,8 +222,8 @@ func (m *MysqlDatabase) PopulateVersion(
if err != nil {
return err
}
m.Version = detectedVersion
return nil
}
@@ -184,17 +232,17 @@ func (m *MysqlDatabase) IsUserReadOnly(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) (bool, error) {
) (bool, []string, error) {
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
if err != nil {
return false, fmt.Errorf("failed to decrypt password: %w", err)
return false, nil, fmt.Errorf("failed to decrypt password: %w", err)
}
dsn := m.buildDSN(password, *m.Database)
db, err := sql.Open("mysql", dsn)
if err != nil {
return false, fmt.Errorf("failed to connect to database: %w", err)
return false, nil, fmt.Errorf("failed to connect to database: %w", err)
}
defer func() {
if closeErr := db.Close(); closeErr != nil {
@@ -204,33 +252,45 @@ func (m *MysqlDatabase) IsUserReadOnly(
rows, err := db.QueryContext(ctx, "SHOW GRANTS FOR CURRENT_USER()")
if err != nil {
return false, fmt.Errorf("failed to check grants: %w", err)
return false, nil, fmt.Errorf("failed to check grants: %w", err)
}
defer func() { _ = rows.Close() }()
writePrivileges := []string{
"INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER",
"INDEX", "GRANT OPTION", "ALL PRIVILEGES", "SUPER",
"EXECUTE", "FILE", "RELOAD", "SHUTDOWN", "CREATE ROUTINE",
"ALTER ROUTINE", "CREATE USER",
"CREATE TABLESPACE", "REFERENCES",
}
detectedPrivileges := make(map[string]bool)
for rows.Next() {
var grant string
if err := rows.Scan(&grant); err != nil {
return false, fmt.Errorf("failed to scan grant: %w", err)
return false, nil, fmt.Errorf("failed to scan grant: %w", err)
}
for _, priv := range writePrivileges {
if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
return false, nil
detectedPrivileges[priv] = true
}
}
}
if err := rows.Err(); err != nil {
return false, fmt.Errorf("error iterating grants: %w", err)
return false, nil, fmt.Errorf("error iterating grants: %w", err)
}
return true, nil
privileges := make([]string, 0, len(detectedPrivileges))
for priv := range detectedPrivileges {
privileges = append(privileges, priv)
}
isReadOnly := len(privileges) == 0
return isReadOnly, privileges, nil
}
func (m *MysqlDatabase) CreateReadOnlyUser(
@@ -259,7 +319,7 @@ func (m *MysqlDatabase) CreateReadOnlyUser(
maxRetries := 3
for attempt := range maxRetries {
newUsername := fmt.Sprintf("databasus-%s", uuid.New().String()[:8])
newPassword := uuid.New().String()
newPassword := encryption.GenerateComplexPassword()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
@@ -325,10 +385,32 @@ func (m *MysqlDatabase) CreateReadOnlyUser(
return "", "", errors.New("failed to generate unique username after 3 attempts")
}
func (m *MysqlDatabase) HasPrivilege(priv string) bool {
return HasPrivilege(m.Privileges, priv)
}
func HasPrivilege(privileges, priv string) bool {
for p := range strings.SplitSeq(privileges, ",") {
if strings.TrimSpace(p) == priv {
return true
}
}
return false
}
func (m *MysqlDatabase) buildDSN(password string, database string) string {
tlsConfig := "false"
if m.IsHttps {
tlsConfig = "true"
err := mysql.RegisterTLSConfig("mysql-skip-verify", &tls.Config{
InsecureSkipVerify: true,
})
if err != nil {
// Config might already be registered, which is fine
_ = err
}
tlsConfig = "mysql-skip-verify"
}
return fmt.Sprintf(
@@ -388,6 +470,104 @@ func mapMysql8xVersion(minor string) tools.MysqlVersion {
}
}
// detectPrivileges detects backup-related privileges and returns them as comma-separated string
func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string, error) {
rows, err := db.QueryContext(ctx, "SHOW GRANTS FOR CURRENT_USER()")
if err != nil {
return "", fmt.Errorf("failed to check grants: %w", err)
}
defer func() { _ = rows.Close() }()
backupPrivileges := []string{
"SELECT", "SHOW VIEW", "LOCK TABLES", "TRIGGER", "EVENT",
}
detectedPrivileges := make(map[string]bool)
hasProcess := false
hasAllPrivileges := false
dbPatternStr := fmt.Sprintf(
`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
regexp.QuoteMeta(database),
)
dbPattern := regexp.MustCompile(dbPatternStr)
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)
allPrivilegesPattern := regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`)
for rows.Next() {
var grant string
if err := rows.Scan(&grant); err != nil {
return "", fmt.Errorf("failed to scan grant: %w", err)
}
isRelevantGrant := globalPattern.MatchString(grant) || dbPattern.MatchString(grant)
if allPrivilegesPattern.MatchString(grant) && isRelevantGrant {
hasAllPrivileges = true
}
if isRelevantGrant {
for _, priv := range backupPrivileges {
privPattern := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(priv) + `\b`)
if privPattern.MatchString(grant) {
detectedPrivileges[priv] = true
}
}
}
if globalPattern.MatchString(grant) {
processPattern := regexp.MustCompile(`(?i)\bPROCESS\b`)
if processPattern.MatchString(grant) {
hasProcess = true
}
}
}
if err := rows.Err(); err != nil {
return "", fmt.Errorf("error iterating grants: %w", err)
}
if hasAllPrivileges {
for _, priv := range backupPrivileges {
detectedPrivileges[priv] = true
}
hasProcess = true
}
privileges := make([]string, 0, len(detectedPrivileges)+1)
for priv := range detectedPrivileges {
privileges = append(privileges, priv)
}
if hasProcess {
privileges = append(privileges, "PROCESS")
}
sort.Strings(privileges)
return strings.Join(privileges, ","), nil
}
// checkBackupPermissions verifies the user has sufficient privileges for mysqldump backup.
// Required: SELECT, SHOW VIEW
func checkBackupPermissions(privileges string) error {
requiredPrivileges := []string{"SELECT", "SHOW VIEW"}
var missingPrivileges []string
for _, priv := range requiredPrivileges {
if !HasPrivilege(privileges, priv) {
missingPrivileges = append(missingPrivileges, priv)
}
}
if len(missingPrivileges) > 0 {
return fmt.Errorf(
"insufficient permissions for backup. Missing: %s. Required: SELECT, SHOW VIEW",
strings.Join(missingPrivileges, ", "),
)
}
return nil
}
func decryptPasswordIfNeeded(
password string,
encryptor encryption.FieldEncryptor,

View File

@@ -18,6 +18,165 @@ import (
"databasus-backend/internal/util/tools"
)
func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MysqlVersion
port string
}{
{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
{"MySQL 9", tools.MysqlVersion9, env.TestMysql90Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMysqlContainer(t, tc.port, tc.version)
defer container.DB.Close()
_, err := container.DB.Exec(`DROP TABLE IF EXISTS permission_test`)
assert.NoError(t, err)
_, err = container.DB.Exec(`CREATE TABLE permission_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)`)
assert.NoError(t, err)
_, err = container.DB.Exec(`INSERT INTO permission_test (data) VALUES ('test1')`)
assert.NoError(t, err)
limitedUsername := fmt.Sprintf("limited_%s", uuid.New().String()[:8])
limitedPassword := "limitedpassword123"
_, err = container.DB.Exec(fmt.Sprintf(
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
limitedUsername,
limitedPassword,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
"GRANT SELECT ON `%s`.* TO '%s'@'%%'",
container.Database,
limitedUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec("FLUSH PRIVILEGES")
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(
fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", limitedUsername),
)
}()
mysqlModel := &MysqlDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Username: limitedUsername,
Password: limitedPassword,
Database: &container.Database,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mysqlModel.TestConnection(logger, nil, uuid.New())
assert.Error(t, err)
assert.Contains(t, err.Error(), "insufficient permissions")
})
}
}
func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MysqlVersion
port string
}{
{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
{"MySQL 9", tools.MysqlVersion9, env.TestMysql90Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMysqlContainer(t, tc.port, tc.version)
defer container.DB.Close()
_, err := container.DB.Exec(`DROP TABLE IF EXISTS backup_test`)
assert.NoError(t, err)
_, err = container.DB.Exec(`CREATE TABLE backup_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)`)
assert.NoError(t, err)
_, err = container.DB.Exec(`INSERT INTO backup_test (data) VALUES ('test1')`)
assert.NoError(t, err)
backupUsername := fmt.Sprintf("backup_%s", uuid.New().String()[:8])
backupPassword := "backuppassword123"
_, err = container.DB.Exec(fmt.Sprintf(
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
backupUsername,
backupPassword,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
"GRANT SELECT, SHOW VIEW, LOCK TABLES, TRIGGER, EVENT ON `%s`.* TO '%s'@'%%'",
container.Database,
backupUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
"GRANT PROCESS ON *.* TO '%s'@'%%'",
backupUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec("FLUSH PRIVILEGES")
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(
fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", backupUsername),
)
}()
mysqlModel := &MysqlDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Username: backupUsername,
Password: backupPassword,
Database: &container.Database,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mysqlModel.TestConnection(logger, nil, uuid.New())
assert.NoError(t, err)
})
}
}
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
env := config.GetEnv()
cases := []struct {
@@ -42,13 +201,57 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
isReadOnly, err := mysqlModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
isReadOnly, privileges, err := mysqlModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.False(t, isReadOnly, "Root user should not be read-only")
assert.NotEmpty(t, privileges, "Root user should have privileges")
})
}
}
func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
env := config.GetEnv()
container := connectToMysqlContainer(t, env.TestMysql80Port, tools.MysqlVersion80)
defer container.DB.Close()
_, err := container.DB.Exec(`DROP TABLE IF EXISTS readonly_check_test`)
assert.NoError(t, err)
_, err = container.DB.Exec(`CREATE TABLE readonly_check_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)`)
assert.NoError(t, err)
_, err = container.DB.Exec(`INSERT INTO readonly_check_test (data) VALUES ('test1')`)
assert.NoError(t, err)
mysqlModel := createMysqlModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := mysqlModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
readOnlyModel := &MysqlDatabase{
Version: mysqlModel.Version,
Host: mysqlModel.Host,
Port: mysqlModel.Port,
Username: username,
Password: password,
Database: mysqlModel.Database,
IsHttps: false,
}
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.True(t, isReadOnly, "Read-only user should be read-only")
assert.Empty(t, privileges, "Read-only user should have no write privileges")
_, err = container.DB.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", username))
assert.NoError(t, err)
}
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
env := config.GetEnv()
cases := []struct {
@@ -109,9 +312,15 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
IsHttps: false,
}
isReadOnly, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
ctx,
logger,
nil,
uuid.New(),
)
assert.NoError(t, err)
assert.True(t, isReadOnly, "Created user should be read-only")
assert.Empty(t, privileges, "Read-only user should have no write privileges")
readOnlyDSN := fmt.Sprintf(
"%s:%s@tcp(%s:%d)/%s?parseTime=true",
@@ -309,6 +518,162 @@ func Test_ReadOnlyUser_CannotDropOrAlterTables(t *testing.T) {
assert.NoError(t, err)
}
func Test_TestConnection_DatabaseSpecificPrivilegesWithGlobalProcess_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MysqlVersion
port string
}{
{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
{"MySQL 9", tools.MysqlVersion9, env.TestMysql90Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMysqlContainer(t, tc.port, tc.version)
defer container.DB.Close()
_, err := container.DB.Exec(`DROP TABLE IF EXISTS privilege_test`)
assert.NoError(t, err)
_, err = container.DB.Exec(`CREATE TABLE privilege_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)`)
assert.NoError(t, err)
_, err = container.DB.Exec(`INSERT INTO privilege_test (data) VALUES ('test1')`)
assert.NoError(t, err)
specificUsername := fmt.Sprintf("specific_%s", uuid.New().String()[:8])
specificPassword := "specificpass123"
_, err = container.DB.Exec(fmt.Sprintf(
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
specificUsername,
specificPassword,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
"GRANT SELECT, SHOW VIEW ON %s.* TO '%s'@'%%'",
container.Database,
specificUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
"GRANT PROCESS ON *.* TO '%s'@'%%'",
specificUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec("FLUSH PRIVILEGES")
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(
fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", specificUsername),
)
}()
mysqlModel := &MysqlDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Username: specificUsername,
Password: specificPassword,
Database: &container.Database,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mysqlModel.TestConnection(logger, nil, uuid.New())
assert.NoError(t, err)
})
}
}
func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
env := config.GetEnv()
container := connectToMysqlContainer(t, env.TestMysql80Port, tools.MysqlVersion80)
defer container.DB.Close()
underscoreDbName := "test_db_name"
_, err := container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
}()
underscoreDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
container.Username, container.Password, container.Host, container.Port, underscoreDbName)
underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
assert.NoError(t, err)
defer underscoreDB.Close()
_, err = underscoreDB.Exec(`
CREATE TABLE underscore_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)
`)
assert.NoError(t, err)
_, err = underscoreDB.Exec(`INSERT INTO underscore_test (data) VALUES ('test1')`)
assert.NoError(t, err)
underscoreUsername := fmt.Sprintf("under_%s", uuid.New().String()[:8])
underscorePassword := "underscorepass123"
_, err = underscoreDB.Exec(fmt.Sprintf(
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
underscoreUsername,
underscorePassword,
))
assert.NoError(t, err)
_, err = underscoreDB.Exec(fmt.Sprintf(
"GRANT SELECT, SHOW VIEW ON `%s`.* TO '%s'@'%%'",
underscoreDbName,
underscoreUsername,
))
assert.NoError(t, err)
_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
assert.NoError(t, err)
defer func() {
_, _ = underscoreDB.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", underscoreUsername))
}()
mysqlModel := &MysqlDatabase{
Version: tools.MysqlVersion80,
Host: container.Host,
Port: container.Port,
Username: underscoreUsername,
Password: underscorePassword,
Database: &underscoreDbName,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mysqlModel.TestConnection(logger, nil, uuid.New())
assert.NoError(t, err)
}
type MysqlContainer struct {
Host string
Port int

View File

@@ -137,16 +137,13 @@ func (p *PostgresqlDatabase) EncryptSensitiveFields(
return nil
}
// PopulateVersionIfEmpty detects and sets the PostgreSQL version if not already set.
// PopulateDbData detects and sets the PostgreSQL version.
// This should be called before encrypting sensitive fields.
func (p *PostgresqlDatabase) PopulateVersionIfEmpty(
func (p *PostgresqlDatabase) PopulateDbData(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error {
if p.Version != "" {
return nil
}
return p.PopulateVersion(logger, encryptor, databaseID)
}
@@ -192,29 +189,33 @@ func (p *PostgresqlDatabase) PopulateVersion(
// IsUserReadOnly checks if the database user has read-only privileges.
//
// This method performs a comprehensive security check by examining:
// - Role-level attributes (superuser, createrole, createdb)
// - Role-level attributes (superuser, createrole, createdb, bypassrls, replication)
// - Database-level privileges (CREATE, TEMP)
// - Schema-level privileges (CREATE on any non-system schema)
// - Table-level write permissions (INSERT, UPDATE, DELETE, TRUNCATE, REFERENCES, TRIGGER)
// - Function-level privileges (EXECUTE on SECURITY DEFINER functions)
//
// A user is considered read-only only if they have ZERO write privileges
// across all three levels. This ensures the database user follows the
// across all levels. This ensures the database user follows the
// principle of least privilege for backup operations.
//
// Returns: (isReadOnly, detectedPrivileges, error)
func (p *PostgresqlDatabase) IsUserReadOnly(
ctx context.Context,
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) (bool, error) {
) (bool, []string, error) {
password, err := decryptPasswordIfNeeded(p.Password, encryptor, databaseID)
if err != nil {
return false, fmt.Errorf("failed to decrypt password: %w", err)
return false, nil, fmt.Errorf("failed to decrypt password: %w", err)
}
connStr := buildConnectionStringForDB(p, *p.Database, password)
conn, err := pgx.Connect(ctx, connStr)
if err != nil {
return false, fmt.Errorf("failed to connect to database: %w", err)
return false, nil, fmt.Errorf("failed to connect to database: %w", err)
}
defer func() {
if closeErr := conn.Close(ctx); closeErr != nil {
@@ -222,22 +223,38 @@ func (p *PostgresqlDatabase) IsUserReadOnly(
}
}()
var privileges []string
// LEVEL 1: Check role-level attributes
var isSuperuser, canCreateRole, canCreateDB bool
var isSuperuser, canCreateRole, canCreateDB, canBypassRLS, canReplication bool
err = conn.QueryRow(ctx, `
SELECT
rolsuper,
rolcreaterole,
rolcreatedb
rolcreatedb,
rolbypassrls,
rolreplication
FROM pg_roles
WHERE rolname = current_user
`).Scan(&isSuperuser, &canCreateRole, &canCreateDB)
`).Scan(&isSuperuser, &canCreateRole, &canCreateDB, &canBypassRLS, &canReplication)
if err != nil {
return false, fmt.Errorf("failed to check role attributes: %w", err)
return false, nil, fmt.Errorf("failed to check role attributes: %w", err)
}
if isSuperuser || canCreateRole || canCreateDB {
return false, nil
if isSuperuser {
privileges = append(privileges, "SUPERUSER")
}
if canCreateRole {
privileges = append(privileges, "CREATEROLE")
}
if canCreateDB {
privileges = append(privileges, "CREATEDB")
}
if canBypassRLS {
privileges = append(privileges, "BYPASSRLS")
}
if canReplication {
privileges = append(privileges, "REPLICATION")
}
// LEVEL 2: Check database-level privileges
@@ -248,46 +265,34 @@ func (p *PostgresqlDatabase) IsUserReadOnly(
has_database_privilege(current_user, current_database(), 'TEMP') as can_temp
`).Scan(&canCreate, &canTemp)
if err != nil {
return false, fmt.Errorf("failed to check database privileges: %w", err)
return false, nil, fmt.Errorf("failed to check database privileges: %w", err)
}
if canCreate || canTemp {
return false, nil
if canCreate {
privileges = append(privileges, "CREATE (database)")
}
if canTemp {
privileges = append(privileges, "TEMP")
}
// LEVEL 2.5: Check schema-level CREATE privileges
schemaRows, err := conn.Query(ctx, `
SELECT DISTINCT nspname
FROM pg_namespace n
WHERE has_schema_privilege(current_user, n.nspname, 'CREATE')
AND nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
`)
var hasSchemaCreate bool
err = conn.QueryRow(ctx, `
SELECT EXISTS(
SELECT 1
FROM pg_namespace n
WHERE has_schema_privilege(current_user, n.nspname, 'CREATE')
AND nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
)
`).Scan(&hasSchemaCreate)
if err != nil {
return false, fmt.Errorf("failed to check schema privileges: %w", err)
return false, nil, fmt.Errorf("failed to check schema privileges: %w", err)
}
defer schemaRows.Close()
// If user has CREATE privilege on any schema, they're not read-only
if schemaRows.Next() {
return false, nil
}
if err := schemaRows.Err(); err != nil {
return false, fmt.Errorf("error iterating schema privileges: %w", err)
if hasSchemaCreate {
privileges = append(privileges, "CREATE (schema)")
}
// LEVEL 3: Check table-level write permissions
rows, err := conn.Query(ctx, `
SELECT DISTINCT privilege_type
FROM information_schema.role_table_grants
WHERE grantee = current_user
AND table_schema NOT IN ('pg_catalog', 'information_schema')
`)
if err != nil {
return false, fmt.Errorf("failed to check table privileges: %w", err)
}
defer rows.Close()
writePrivileges := map[string]bool{
"INSERT": true,
"UPDATE": true,
@@ -297,22 +302,56 @@ func (p *PostgresqlDatabase) IsUserReadOnly(
"TRIGGER": true,
}
var tablePrivileges []string
rows, err := conn.Query(ctx, `
SELECT DISTINCT privilege_type
FROM information_schema.role_table_grants
WHERE grantee = current_user
AND table_schema NOT IN ('pg_catalog', 'information_schema')
`)
if err != nil {
return false, nil, fmt.Errorf("failed to check table privileges: %w", err)
}
for rows.Next() {
var privilege string
if err := rows.Scan(&privilege); err != nil {
return false, fmt.Errorf("failed to scan privilege: %w", err)
}
if writePrivileges[privilege] {
return false, nil
rows.Close()
return false, nil, fmt.Errorf("failed to scan privilege: %w", err)
}
tablePrivileges = append(tablePrivileges, privilege)
}
rows.Close()
if err := rows.Err(); err != nil {
return false, fmt.Errorf("error iterating privileges: %w", err)
return false, nil, fmt.Errorf("error iterating privileges: %w", err)
}
return true, nil
for _, privilege := range tablePrivileges {
if writePrivileges[privilege] {
privileges = append(privileges, privilege)
}
}
// LEVEL 4: Check for EXECUTE privilege on functions that are SECURITY DEFINER
var funcCount int
err = conn.QueryRow(ctx, `
SELECT COUNT(*)
FROM pg_proc p
JOIN pg_namespace n ON p.pronamespace = n.oid
WHERE n.nspname NOT IN ('pg_catalog', 'information_schema')
AND p.prosecdef = true
AND has_function_privilege(current_user, p.oid, 'EXECUTE')
`).Scan(&funcCount)
if err != nil {
return false, nil, fmt.Errorf("failed to check function privileges: %w", err)
}
if funcCount > 0 {
privileges = append(privileges, "EXECUTE (SECURITY DEFINER)")
}
isReadOnly := len(privileges) == 0
return isReadOnly, privileges, nil
}
// CreateReadOnlyUser creates a new PostgreSQL user with read-only privileges.
@@ -383,7 +422,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
}
}
newPassword := uuid.New().String()
newPassword := encryption.GenerateComplexPassword()
tx, err := conn.Begin(ctx)
if err != nil {
@@ -631,13 +670,9 @@ func testSingleDatabaseConnection(
}
postgresDb.Version = detectedVersion
// Test if we can perform basic operations (like pg_dump would need)
if err := testBasicOperations(ctx, conn, *postgresDb.Database); err != nil {
return fmt.Errorf(
"basic operations test failed for database '%s': %w",
*postgresDb.Database,
err,
)
// Verify user has sufficient permissions for backup operations
if err := checkBackupPermissions(ctx, conn, postgresDb.IncludeSchemas); err != nil {
return err
}
return nil
@@ -670,18 +705,128 @@ func detectDatabaseVersion(ctx context.Context, conn *pgx.Conn) (tools.Postgresq
}
}
// testBasicOperations tests basic operations that backup tools need
func testBasicOperations(ctx context.Context, conn *pgx.Conn, dbName string) error {
var hasCreatePriv bool
// checkBackupPermissions verifies the user has sufficient privileges for pg_dump backup.
// Required privileges: CONNECT on database, USAGE on schemas, SELECT on tables.
// If includeSchemas is specified, only checks permissions on those schemas.
func checkBackupPermissions(
ctx context.Context,
conn *pgx.Conn,
includeSchemas []string,
) error {
var missingPrivileges []string
// Check CONNECT privilege on database
var hasConnect bool
err := conn.QueryRow(ctx, "SELECT has_database_privilege(current_user, current_database(), 'CONNECT')").
Scan(&hasCreatePriv)
Scan(&hasConnect)
if err != nil {
return fmt.Errorf("cannot check database privileges: %w", err)
}
if !hasConnect {
missingPrivileges = append(missingPrivileges, "CONNECT on database")
}
if !hasCreatePriv {
return fmt.Errorf("user does not have CONNECT privilege on database '%s'", dbName)
// Check USAGE privilege on at least one non-system schema
var schemaCount int
if len(includeSchemas) > 0 {
// Check only the specified schemas
err = conn.QueryRow(ctx, `
SELECT COUNT(*)
FROM pg_namespace n
WHERE has_schema_privilege(current_user, n.nspname, 'USAGE')
AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
AND n.nspname NOT LIKE 'pg_temp_%'
AND n.nspname NOT LIKE 'pg_toast_temp_%'
AND n.nspname = ANY($1::text[])
`, includeSchemas).Scan(&schemaCount)
} else {
// Check all non-system schemas
err = conn.QueryRow(ctx, `
SELECT COUNT(*)
FROM pg_namespace n
WHERE has_schema_privilege(current_user, n.nspname, 'USAGE')
AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
AND n.nspname NOT LIKE 'pg_temp_%'
AND n.nspname NOT LIKE 'pg_toast_temp_%'
`).Scan(&schemaCount)
}
if err != nil {
return fmt.Errorf("cannot check schema privileges: %w", err)
}
if schemaCount == 0 {
missingPrivileges = append(missingPrivileges, "USAGE on at least one schema")
}
// Check SELECT privilege on at least one table (if tables exist)
// Use pg_tables from pg_catalog which shows all tables regardless of user privileges
var tableCount int
if len(includeSchemas) > 0 {
// Check only tables in the specified schemas
err = conn.QueryRow(ctx, `
SELECT COUNT(*)
FROM pg_catalog.pg_tables t
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
AND t.schemaname NOT LIKE 'pg_temp_%'
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
AND t.schemaname = ANY($1::text[])
`, includeSchemas).Scan(&tableCount)
} else {
// Check all tables in non-system schemas
err = conn.QueryRow(ctx, `
SELECT COUNT(*)
FROM pg_catalog.pg_tables t
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
AND t.schemaname NOT LIKE 'pg_temp_%'
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
`).Scan(&tableCount)
}
if err != nil {
return fmt.Errorf("cannot check table count: %w", err)
}
if tableCount > 0 {
// Check if user has SELECT on at least one of these tables
var selectableTableCount int
if len(includeSchemas) > 0 {
// Check only tables in the specified schemas
err = conn.QueryRow(ctx, `
SELECT COUNT(*)
FROM pg_catalog.pg_tables t
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
AND t.schemaname NOT LIKE 'pg_temp_%'
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
AND t.schemaname = ANY($1::text[])
AND has_table_privilege(current_user, quote_ident(t.schemaname) || '.' || quote_ident(t.tablename), 'SELECT')
`, includeSchemas).Scan(&selectableTableCount)
} else {
// Check all tables in non-system schemas
err = conn.QueryRow(ctx, `
SELECT COUNT(*)
FROM pg_catalog.pg_tables t
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
AND t.schemaname NOT LIKE 'pg_temp_%'
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
AND has_table_privilege(current_user, quote_ident(t.schemaname) || '.' || quote_ident(t.tablename), 'SELECT')
`).Scan(&selectableTableCount)
}
if err != nil {
return fmt.Errorf("cannot check SELECT privileges: %w", err)
}
if selectableTableCount == 0 {
missingPrivileges = append(missingPrivileges, "SELECT on tables")
}
}
if len(missingPrivileges) > 0 {
return fmt.Errorf(
"insufficient permissions for backup. Missing: %s. Required: CONNECT on database, USAGE on schemas, SELECT on tables",
strings.Join(missingPrivileges, ", "),
)
}
return nil

View File

@@ -1,16 +1,19 @@
package postgresql
import (
"context"
"fmt"
"log/slog"
"os"
"strconv"
"strings"
"testing"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"databasus-backend/internal/config"
"databasus-backend/internal/util/tools"
@@ -18,13 +21,23 @@ import (
func Test_TestConnection_PasswordContainingSpaces_TestedSuccessfully(t *testing.T) {
env := config.GetEnv()
container := connectToTestPostgresContainer(t, env.TestPostgres16Port)
container := connectToPostgresContainer(t, env.TestPostgres16Port)
defer container.DB.Close()
passwordWithSpaces := "test password with spaces"
usernameWithSpaces := fmt.Sprintf("testuser_spaces_%s", uuid.New().String()[:8])
_, err := container.DB.Exec(fmt.Sprintf(
_, err := container.DB.Exec(`
DROP TABLE IF EXISTS password_test CASCADE;
CREATE TABLE password_test (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO password_test (data) VALUES ('test1');
`)
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
usernameWithSpaces,
passwordWithSpaces,
@@ -38,6 +51,18 @@ func Test_TestConnection_PasswordContainingSpaces_TestedSuccessfully(t *testing.
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT USAGE ON SCHEMA public TO "%s"`,
usernameWithSpaces,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT SELECT ON ALL TABLES IN SCHEMA public TO "%s"`,
usernameWithSpaces,
))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, usernameWithSpaces))
}()
@@ -59,7 +84,628 @@ func Test_TestConnection_PasswordContainingSpaces_TestedSuccessfully(t *testing.
assert.NoError(t, err)
}
type testPostgresContainer struct {
func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version string
port string
}{
{"PostgreSQL 12", "12", env.TestPostgres12Port},
{"PostgreSQL 13", "13", env.TestPostgres13Port},
{"PostgreSQL 14", "14", env.TestPostgres14Port},
{"PostgreSQL 15", "15", env.TestPostgres15Port},
{"PostgreSQL 16", "16", env.TestPostgres16Port},
{"PostgreSQL 17", "17", env.TestPostgres17Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToPostgresContainer(t, tc.port)
defer container.DB.Close()
_, err := container.DB.Exec(`
DROP TABLE IF EXISTS permission_test CASCADE;
CREATE TABLE permission_test (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO permission_test (data) VALUES ('test1');
`)
assert.NoError(t, err)
limitedUsername := fmt.Sprintf("limited_user_%s", uuid.New().String()[:8])
limitedPassword := "limitedpassword123"
_, err = container.DB.Exec(fmt.Sprintf(
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
limitedUsername,
limitedPassword,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
container.Database,
limitedUsername,
))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, limitedUsername))
}()
pgModel := &PostgresqlDatabase{
Version: tools.GetPostgresqlVersionEnum(tc.version),
Host: container.Host,
Port: container.Port,
Username: limitedUsername,
Password: limitedPassword,
Database: &container.Database,
IsHttps: false,
CpuCount: 1,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = pgModel.TestConnection(logger, nil, uuid.New())
assert.Error(t, err)
if err != nil {
assert.Contains(t, err.Error(), "insufficient permissions")
}
})
}
}
func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version string
port string
}{
{"PostgreSQL 12", "12", env.TestPostgres12Port},
{"PostgreSQL 13", "13", env.TestPostgres13Port},
{"PostgreSQL 14", "14", env.TestPostgres14Port},
{"PostgreSQL 15", "15", env.TestPostgres15Port},
{"PostgreSQL 16", "16", env.TestPostgres16Port},
{"PostgreSQL 17", "17", env.TestPostgres17Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToPostgresContainer(t, tc.port)
defer container.DB.Close()
_, err := container.DB.Exec(`
DROP TABLE IF EXISTS backup_test CASCADE;
CREATE TABLE backup_test (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO backup_test (data) VALUES ('test1');
`)
assert.NoError(t, err)
backupUsername := fmt.Sprintf("backup_user_%s", uuid.New().String()[:8])
backupPassword := "backuppassword123"
_, err = container.DB.Exec(fmt.Sprintf(
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
backupUsername,
backupPassword,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
container.Database,
backupUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT USAGE ON SCHEMA public TO "%s"`,
backupUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT SELECT ON ALL TABLES IN SCHEMA public TO "%s"`,
backupUsername,
))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, backupUsername))
}()
pgModel := &PostgresqlDatabase{
Version: tools.GetPostgresqlVersionEnum(tc.version),
Host: container.Host,
Port: container.Port,
Username: backupUsername,
Password: backupPassword,
Database: &container.Database,
IsHttps: false,
CpuCount: 1,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = pgModel.TestConnection(logger, nil, uuid.New())
assert.NoError(t, err)
})
}
}
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version string
port string
}{
{"PostgreSQL 12", "12", env.TestPostgres12Port},
{"PostgreSQL 13", "13", env.TestPostgres13Port},
{"PostgreSQL 14", "14", env.TestPostgres14Port},
{"PostgreSQL 15", "15", env.TestPostgres15Port},
{"PostgreSQL 16", "16", env.TestPostgres16Port},
{"PostgreSQL 17", "17", env.TestPostgres17Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToPostgresContainer(t, tc.port)
defer container.DB.Close()
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
isReadOnly, privileges, err := pgModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.False(t, isReadOnly, "Admin user should not be read-only")
assert.NotEmpty(t, privileges, "Admin user should have privileges")
})
}
}
func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
env := config.GetEnv()
container := connectToPostgresContainer(t, env.TestPostgres16Port)
defer container.DB.Close()
_, err := container.DB.Exec(`
DROP TABLE IF EXISTS readonly_check_test CASCADE;
CREATE TABLE readonly_check_test (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO readonly_check_test (data) VALUES ('test1');
`)
assert.NoError(t, err)
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
readOnlyModel := &PostgresqlDatabase{
Version: pgModel.Version,
Host: pgModel.Host,
Port: pgModel.Port,
Username: username,
Password: password,
Database: pgModel.Database,
IsHttps: false,
CpuCount: 1,
}
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.True(t, isReadOnly, "Read-only user should be read-only")
assert.Empty(t, privileges, "Read-only user should have no write privileges")
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
}
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
assert.NoError(t, err)
}
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version string
port string
}{
{"PostgreSQL 12", "12", env.TestPostgres12Port},
{"PostgreSQL 13", "13", env.TestPostgres13Port},
{"PostgreSQL 14", "14", env.TestPostgres14Port},
{"PostgreSQL 15", "15", env.TestPostgres15Port},
{"PostgreSQL 16", "16", env.TestPostgres16Port},
{"PostgreSQL 17", "17", env.TestPostgres17Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToPostgresContainer(t, tc.port)
defer container.DB.Close()
_, err := container.DB.Exec(`
DROP TABLE IF EXISTS readonly_test CASCADE;
DROP TABLE IF EXISTS hack_table CASCADE;
DROP TABLE IF EXISTS future_table CASCADE;
CREATE TABLE readonly_test (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO readonly_test (data) VALUES ('test1'), ('test2');
`)
assert.NoError(t, err)
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.NotEmpty(t, username)
assert.NotEmpty(t, password)
assert.True(t, strings.HasPrefix(username, "databasus-"))
readOnlyModel := &PostgresqlDatabase{
Version: pgModel.Version,
Host: pgModel.Host,
Port: pgModel.Port,
Username: username,
Password: password,
Database: pgModel.Database,
IsHttps: false,
}
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
ctx,
logger,
nil,
uuid.New(),
)
assert.NoError(t, err)
assert.True(t, isReadOnly, "Created user should be read-only")
assert.Empty(t, privileges, "Read-only user should have no write privileges")
readOnlyDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
username,
password,
container.Database,
)
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var count int
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM readonly_test")
assert.NoError(t, err)
assert.Equal(t, 2, count)
_, err = readOnlyConn.Exec("INSERT INTO readonly_test (data) VALUES ('should-fail')")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec("UPDATE readonly_test SET data = 'hacked' WHERE id = 1")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec("DELETE FROM readonly_test WHERE id = 1")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec("CREATE TABLE hack_table (id INT)")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
}
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
assert.NoError(t, err)
})
}
}
func Test_ReadOnlyUser_FutureTables_HaveSelectPermission(t *testing.T) {
env := config.GetEnv()
container := connectToPostgresContainer(t, env.TestPostgres16Port)
defer container.DB.Close()
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
_, err = container.DB.Exec(`
CREATE TABLE future_table (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO future_table (data) VALUES ('future_data');
`)
assert.NoError(t, err)
readOnlyDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host, container.Port, username, password, container.Database)
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var data string
err = readOnlyConn.Get(&data, "SELECT data FROM future_table LIMIT 1")
assert.NoError(t, err)
assert.Equal(t, "future_data", data)
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
}
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
assert.NoError(t, err)
}
func Test_ReadOnlyUser_MultipleSchemas_AllAccessible(t *testing.T) {
env := config.GetEnv()
container := connectToPostgresContainer(t, env.TestPostgres16Port)
defer container.DB.Close()
_, err := container.DB.Exec(`
DROP SCHEMA IF EXISTS schema_a CASCADE;
DROP SCHEMA IF EXISTS schema_b CASCADE;
CREATE SCHEMA schema_a;
CREATE SCHEMA schema_b;
CREATE TABLE schema_a.table_a (id INT, data TEXT);
CREATE TABLE schema_b.table_b (id INT, data TEXT);
INSERT INTO schema_a.table_a VALUES (1, 'data_a');
INSERT INTO schema_b.table_b VALUES (2, 'data_b');
`)
assert.NoError(t, err)
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
readOnlyDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host, container.Port, username, password, container.Database)
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var dataA string
err = readOnlyConn.Get(&dataA, "SELECT data FROM schema_a.table_a LIMIT 1")
assert.NoError(t, err)
assert.Equal(t, "data_a", dataA)
var dataB string
err = readOnlyConn.Get(&dataB, "SELECT data FROM schema_b.table_b LIMIT 1")
assert.NoError(t, err)
assert.Equal(t, "data_b", dataB)
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
}
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
assert.NoError(t, err)
_, err = container.DB.Exec(`DROP SCHEMA schema_a CASCADE; DROP SCHEMA schema_b CASCADE;`)
assert.NoError(t, err)
}
func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) {
env := config.GetEnv()
container := connectToPostgresContainer(t, env.TestPostgres16Port)
defer container.DB.Close()
dashDbName := "test-db-with-dash"
_, err := container.DB.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS "%s"`, dashDbName))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(`CREATE DATABASE "%s"`, dashDbName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS "%s"`, dashDbName))
}()
dashDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host, container.Port, container.Username, container.Password, dashDbName)
dashDB, err := sqlx.Connect("postgres", dashDSN)
assert.NoError(t, err)
defer dashDB.Close()
_, err = dashDB.Exec(`
CREATE TABLE dash_test (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO dash_test (data) VALUES ('test1'), ('test2');
`)
assert.NoError(t, err)
pgModel := &PostgresqlDatabase{
Version: tools.GetPostgresqlVersionEnum("16"),
Host: container.Host,
Port: container.Port,
Username: container.Username,
Password: container.Password,
Database: &dashDbName,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.NotEmpty(t, username)
assert.NotEmpty(t, password)
assert.True(t, strings.HasPrefix(username, "databasus-"))
readOnlyDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host, container.Port, username, password, dashDbName)
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var count int
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM dash_test")
assert.NoError(t, err)
assert.Equal(t, 2, count)
_, err = readOnlyConn.Exec("INSERT INTO dash_test (data) VALUES ('should-fail')")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = dashDB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
}
_, err = dashDB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
assert.NoError(t, err)
}
func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
env := config.GetEnv()
if env.TestSupabaseHost == "" {
t.Skip("Skipping Supabase test: missing environment variables")
}
portInt, err := strconv.Atoi(env.TestSupabasePort)
assert.NoError(t, err)
dsn := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=require",
env.TestSupabaseHost,
portInt,
env.TestSupabaseUsername,
env.TestSupabasePassword,
env.TestSupabaseDatabase,
)
adminDB, err := sqlx.Connect("postgres", dsn)
require.NoError(t, err)
defer adminDB.Close()
tableName := fmt.Sprintf(
"readonly_test_%s",
strings.ReplaceAll(uuid.New().String()[:8], "-", ""),
)
_, err = adminDB.Exec(fmt.Sprintf(`
DROP TABLE IF EXISTS public.%s CASCADE;
CREATE TABLE public.%s (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO public.%s (data) VALUES ('test1'), ('test2');
`, tableName, tableName, tableName))
assert.NoError(t, err)
defer func() {
_, _ = adminDB.Exec(fmt.Sprintf(`DROP TABLE IF EXISTS public.%s CASCADE`, tableName))
}()
pgModel := &PostgresqlDatabase{
Host: env.TestSupabaseHost,
Port: portInt,
Username: env.TestSupabaseUsername,
Password: env.TestSupabasePassword,
Database: &env.TestSupabaseDatabase,
IsHttps: true,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
connectionUsername, newPassword, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.NotEmpty(t, connectionUsername)
assert.NotEmpty(t, newPassword)
assert.True(t, strings.HasPrefix(connectionUsername, "databasus-"))
baseUsername := connectionUsername
if idx := strings.Index(connectionUsername, "."); idx != -1 {
baseUsername = connectionUsername[:idx]
}
defer func() {
_, _ = adminDB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, baseUsername))
_, _ = adminDB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, baseUsername))
}()
readOnlyDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=require",
env.TestSupabaseHost,
portInt,
connectionUsername,
newPassword,
env.TestSupabaseDatabase,
)
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var count int
err = readOnlyConn.Get(&count, fmt.Sprintf("SELECT COUNT(*) FROM public.%s", tableName))
assert.NoError(t, err)
assert.Equal(t, 2, count)
_, err = readOnlyConn.Exec(
fmt.Sprintf("INSERT INTO public.%s (data) VALUES ('should-fail')", tableName),
)
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec(
fmt.Sprintf("UPDATE public.%s SET data = 'hacked' WHERE id = 1", tableName),
)
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec(fmt.Sprintf("DELETE FROM public.%s WHERE id = 1", tableName))
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec("CREATE TABLE public.hack_table (id INT)")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
}
type PostgresContainer struct {
Host string
Port int
Username string
@@ -68,7 +714,7 @@ type testPostgresContainer struct {
DB *sqlx.DB
}
func connectToTestPostgresContainer(t *testing.T, port string) *testPostgresContainer {
func connectToPostgresContainer(t *testing.T, port string) *PostgresContainer {
dbName := "testdb"
password := "testpassword"
username := "testuser"
@@ -83,7 +729,11 @@ func connectToTestPostgresContainer(t *testing.T, port string) *testPostgresCont
db, err := sqlx.Connect("postgres", dsn)
assert.NoError(t, err)
return &testPostgresContainer{
var versionStr string
err = db.Get(&versionStr, "SELECT version()")
assert.NoError(t, err)
return &PostgresContainer{
Host: host,
Port: portInt,
Username: username,
@@ -92,3 +742,42 @@ func connectToTestPostgresContainer(t *testing.T, port string) *testPostgresCont
DB: db,
}
}
func createPostgresModel(container *PostgresContainer) *PostgresqlDatabase {
var versionStr string
err := container.DB.Get(&versionStr, "SELECT version()")
if err != nil {
return nil
}
version := extractPostgresVersion(versionStr)
return &PostgresqlDatabase{
Version: version,
Host: container.Host,
Port: container.Port,
Username: container.Username,
Password: container.Password,
Database: &container.Database,
IsHttps: false,
CpuCount: 1,
}
}
func extractPostgresVersion(versionStr string) tools.PostgresqlVersion {
if strings.Contains(versionStr, "PostgreSQL 12") {
return tools.GetPostgresqlVersionEnum("12")
} else if strings.Contains(versionStr, "PostgreSQL 13") {
return tools.GetPostgresqlVersionEnum("13")
} else if strings.Contains(versionStr, "PostgreSQL 14") {
return tools.GetPostgresqlVersionEnum("14")
} else if strings.Contains(versionStr, "PostgreSQL 15") {
return tools.GetPostgresqlVersionEnum("15")
} else if strings.Contains(versionStr, "PostgreSQL 16") {
return tools.GetPostgresqlVersionEnum("16")
} else if strings.Contains(versionStr, "PostgreSQL 17") {
return tools.GetPostgresqlVersionEnum("17")
}
return tools.GetPostgresqlVersionEnum("16")
}

View File

@@ -1,508 +0,0 @@
package postgresql
import (
"context"
"fmt"
"log/slog"
"os"
"strconv"
"strings"
"testing"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"databasus-backend/internal/config"
"databasus-backend/internal/util/tools"
)
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version string
port string
}{
{"PostgreSQL 12", "12", env.TestPostgres12Port},
{"PostgreSQL 13", "13", env.TestPostgres13Port},
{"PostgreSQL 14", "14", env.TestPostgres14Port},
{"PostgreSQL 15", "15", env.TestPostgres15Port},
{"PostgreSQL 16", "16", env.TestPostgres16Port},
{"PostgreSQL 17", "17", env.TestPostgres17Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToPostgresContainer(t, tc.port)
defer container.DB.Close()
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
isReadOnly, err := pgModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.False(t, isReadOnly, "Admin user should not be read-only")
})
}
}
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version string
port string
}{
{"PostgreSQL 12", "12", env.TestPostgres12Port},
{"PostgreSQL 13", "13", env.TestPostgres13Port},
{"PostgreSQL 14", "14", env.TestPostgres14Port},
{"PostgreSQL 15", "15", env.TestPostgres15Port},
{"PostgreSQL 16", "16", env.TestPostgres16Port},
{"PostgreSQL 17", "17", env.TestPostgres17Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToPostgresContainer(t, tc.port)
defer container.DB.Close()
_, err := container.DB.Exec(`
DROP TABLE IF EXISTS readonly_test CASCADE;
DROP TABLE IF EXISTS hack_table CASCADE;
DROP TABLE IF EXISTS future_table CASCADE;
CREATE TABLE readonly_test (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO readonly_test (data) VALUES ('test1'), ('test2');
`)
assert.NoError(t, err)
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.NotEmpty(t, username)
assert.NotEmpty(t, password)
assert.True(t, strings.HasPrefix(username, "databasus-"))
readOnlyModel := &PostgresqlDatabase{
Version: pgModel.Version,
Host: pgModel.Host,
Port: pgModel.Port,
Username: username,
Password: password,
Database: pgModel.Database,
IsHttps: false,
}
isReadOnly, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.True(t, isReadOnly, "Created user should be read-only")
readOnlyDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
username,
password,
container.Database,
)
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var count int
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM readonly_test")
assert.NoError(t, err)
assert.Equal(t, 2, count)
_, err = readOnlyConn.Exec("INSERT INTO readonly_test (data) VALUES ('should-fail')")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec("UPDATE readonly_test SET data = 'hacked' WHERE id = 1")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec("DELETE FROM readonly_test WHERE id = 1")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec("CREATE TABLE hack_table (id INT)")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
// Clean up: Drop user with CASCADE to handle default privilege dependencies
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
}
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
assert.NoError(t, err)
})
}
}
func Test_ReadOnlyUser_FutureTables_HaveSelectPermission(t *testing.T) {
env := config.GetEnv()
container := connectToPostgresContainer(t, env.TestPostgres16Port)
defer container.DB.Close()
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
_, err = container.DB.Exec(`
CREATE TABLE future_table (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO future_table (data) VALUES ('future_data');
`)
assert.NoError(t, err)
readOnlyDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host, container.Port, username, password, container.Database)
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var data string
err = readOnlyConn.Get(&data, "SELECT data FROM future_table LIMIT 1")
assert.NoError(t, err)
assert.Equal(t, "future_data", data)
// Clean up: Drop user with CASCADE to handle default privilege dependencies
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
}
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
assert.NoError(t, err)
}
func Test_ReadOnlyUser_MultipleSchemas_AllAccessible(t *testing.T) {
env := config.GetEnv()
container := connectToPostgresContainer(t, env.TestPostgres16Port)
defer container.DB.Close()
_, err := container.DB.Exec(`
DROP SCHEMA IF EXISTS schema_a CASCADE;
DROP SCHEMA IF EXISTS schema_b CASCADE;
CREATE SCHEMA schema_a;
CREATE SCHEMA schema_b;
CREATE TABLE schema_a.table_a (id INT, data TEXT);
CREATE TABLE schema_b.table_b (id INT, data TEXT);
INSERT INTO schema_a.table_a VALUES (1, 'data_a');
INSERT INTO schema_b.table_b VALUES (2, 'data_b');
`)
assert.NoError(t, err)
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
readOnlyDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host, container.Port, username, password, container.Database)
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var dataA string
err = readOnlyConn.Get(&dataA, "SELECT data FROM schema_a.table_a LIMIT 1")
assert.NoError(t, err)
assert.Equal(t, "data_a", dataA)
var dataB string
err = readOnlyConn.Get(&dataB, "SELECT data FROM schema_b.table_b LIMIT 1")
assert.NoError(t, err)
assert.Equal(t, "data_b", dataB)
// Clean up: Drop user with CASCADE to handle default privilege dependencies
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
}
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
assert.NoError(t, err)
_, err = container.DB.Exec(`DROP SCHEMA schema_a CASCADE; DROP SCHEMA schema_b CASCADE;`)
assert.NoError(t, err)
}
func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) {
env := config.GetEnv()
container := connectToPostgresContainer(t, env.TestPostgres16Port)
defer container.DB.Close()
dashDbName := "test-db-with-dash"
_, err := container.DB.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS "%s"`, dashDbName))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(`CREATE DATABASE "%s"`, dashDbName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS "%s"`, dashDbName))
}()
dashDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host, container.Port, container.Username, container.Password, dashDbName)
dashDB, err := sqlx.Connect("postgres", dashDSN)
assert.NoError(t, err)
defer dashDB.Close()
_, err = dashDB.Exec(`
CREATE TABLE dash_test (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO dash_test (data) VALUES ('test1'), ('test2');
`)
assert.NoError(t, err)
pgModel := &PostgresqlDatabase{
Version: tools.GetPostgresqlVersionEnum("16"),
Host: container.Host,
Port: container.Port,
Username: container.Username,
Password: container.Password,
Database: &dashDbName,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.NotEmpty(t, username)
assert.NotEmpty(t, password)
assert.True(t, strings.HasPrefix(username, "databasus-"))
readOnlyDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host, container.Port, username, password, dashDbName)
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var count int
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM dash_test")
assert.NoError(t, err)
assert.Equal(t, 2, count)
_, err = readOnlyConn.Exec("INSERT INTO dash_test (data) VALUES ('should-fail')")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = dashDB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
}
_, err = dashDB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
assert.NoError(t, err)
}
func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
env := config.GetEnv()
if env.TestSupabaseHost == "" {
t.Skip("Skipping Supabase test: missing environment variables")
}
portInt, err := strconv.Atoi(env.TestSupabasePort)
assert.NoError(t, err)
dsn := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=require",
env.TestSupabaseHost,
portInt,
env.TestSupabaseUsername,
env.TestSupabasePassword,
env.TestSupabaseDatabase,
)
adminDB, err := sqlx.Connect("postgres", dsn)
require.NoError(t, err)
defer adminDB.Close()
tableName := fmt.Sprintf(
"readonly_test_%s",
strings.ReplaceAll(uuid.New().String()[:8], "-", ""),
)
_, err = adminDB.Exec(fmt.Sprintf(`
DROP TABLE IF EXISTS public.%s CASCADE;
CREATE TABLE public.%s (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO public.%s (data) VALUES ('test1'), ('test2');
`, tableName, tableName, tableName))
assert.NoError(t, err)
defer func() {
_, _ = adminDB.Exec(fmt.Sprintf(`DROP TABLE IF EXISTS public.%s CASCADE`, tableName))
}()
pgModel := &PostgresqlDatabase{
Host: env.TestSupabaseHost,
Port: portInt,
Username: env.TestSupabaseUsername,
Password: env.TestSupabasePassword,
Database: &env.TestSupabaseDatabase,
IsHttps: true,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
connectionUsername, newPassword, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.NotEmpty(t, connectionUsername)
assert.NotEmpty(t, newPassword)
assert.True(t, strings.HasPrefix(connectionUsername, "databasus-"))
baseUsername := connectionUsername
if idx := strings.Index(connectionUsername, "."); idx != -1 {
baseUsername = connectionUsername[:idx]
}
defer func() {
_, _ = adminDB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, baseUsername))
_, _ = adminDB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, baseUsername))
}()
readOnlyDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=require",
env.TestSupabaseHost,
portInt,
connectionUsername,
newPassword,
env.TestSupabaseDatabase,
)
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var count int
err = readOnlyConn.Get(&count, fmt.Sprintf("SELECT COUNT(*) FROM public.%s", tableName))
assert.NoError(t, err)
assert.Equal(t, 2, count)
_, err = readOnlyConn.Exec(
fmt.Sprintf("INSERT INTO public.%s (data) VALUES ('should-fail')", tableName),
)
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec(
fmt.Sprintf("UPDATE public.%s SET data = 'hacked' WHERE id = 1", tableName),
)
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec(fmt.Sprintf("DELETE FROM public.%s WHERE id = 1", tableName))
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec("CREATE TABLE public.hack_table (id INT)")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
}
type PostgresContainer struct {
Host string
Port int
Username string
Password string
Database string
DB *sqlx.DB
}
func connectToPostgresContainer(t *testing.T, port string) *PostgresContainer {
dbName := "testdb"
password := "testpassword"
username := "testuser"
host := "localhost"
portInt, err := strconv.Atoi(port)
assert.NoError(t, err)
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
host, portInt, username, password, dbName)
db, err := sqlx.Connect("postgres", dsn)
assert.NoError(t, err)
var versionStr string
err = db.Get(&versionStr, "SELECT version()")
assert.NoError(t, err)
return &PostgresContainer{
Host: host,
Port: portInt,
Username: username,
Password: password,
Database: dbName,
DB: db,
}
}
func createPostgresModel(container *PostgresContainer) *PostgresqlDatabase {
var versionStr string
err := container.DB.Get(&versionStr, "SELECT version()")
if err != nil {
return nil
}
version := extractPostgresVersion(versionStr)
return &PostgresqlDatabase{
Version: version,
Host: container.Host,
Port: container.Port,
Username: container.Username,
Password: container.Password,
Database: &container.Database,
IsHttps: false,
}
}
func extractPostgresVersion(versionStr string) tools.PostgresqlVersion {
if strings.Contains(versionStr, "PostgreSQL 12") {
return tools.GetPostgresqlVersionEnum("12")
} else if strings.Contains(versionStr, "PostgreSQL 13") {
return tools.GetPostgresqlVersionEnum("13")
} else if strings.Contains(versionStr, "PostgreSQL 14") {
return tools.GetPostgresqlVersionEnum("14")
} else if strings.Contains(versionStr, "PostgreSQL 15") {
return tools.GetPostgresqlVersionEnum("15")
} else if strings.Contains(versionStr, "PostgreSQL 16") {
return tools.GetPostgresqlVersionEnum("16")
} else if strings.Contains(versionStr, "PostgreSQL 17") {
return tools.GetPostgresqlVersionEnum("17")
}
return tools.GetPostgresqlVersionEnum("16")
}

View File

@@ -6,5 +6,6 @@ type CreateReadOnlyUserResponse struct {
}
type IsReadOnlyResponse struct {
IsReadOnly bool `json:"isReadOnly"`
IsReadOnly bool `json:"isReadOnly"`
Privileges []string `json:"privileges"`
}

View File

@@ -104,21 +104,21 @@ func (d *Database) EncryptSensitiveFields(encryptor encryption.FieldEncryptor) e
return nil
}
func (d *Database) PopulateVersionIfEmpty(
func (d *Database) PopulateDbData(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
) error {
if d.Postgresql != nil {
return d.Postgresql.PopulateVersionIfEmpty(logger, encryptor, d.ID)
return d.Postgresql.PopulateDbData(logger, encryptor, d.ID)
}
if d.Mysql != nil {
return d.Mysql.PopulateVersionIfEmpty(logger, encryptor, d.ID)
return d.Mysql.PopulateDbData(logger, encryptor, d.ID)
}
if d.Mariadb != nil {
return d.Mariadb.PopulateVersionIfEmpty(logger, encryptor, d.ID)
return d.Mariadb.PopulateDbData(logger, encryptor, d.ID)
}
if d.Mongodb != nil {
return d.Mongodb.PopulateVersionIfEmpty(logger, encryptor, d.ID)
return d.Mongodb.PopulateDbData(logger, encryptor, d.ID)
}
return nil
}

View File

@@ -82,8 +82,8 @@ func (s *DatabaseService) CreateDatabase(
return nil, err
}
if err := database.PopulateVersionIfEmpty(s.logger, s.fieldEncryptor); err != nil {
return nil, fmt.Errorf("failed to auto-detect database version: %w", err)
if err := database.PopulateDbData(s.logger, s.fieldEncryptor); err != nil {
return nil, fmt.Errorf("failed to auto-detect database data: %w", err)
}
if err := database.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
@@ -149,8 +149,8 @@ func (s *DatabaseService) UpdateDatabase(
return err
}
if err := existingDatabase.PopulateVersionIfEmpty(s.logger, s.fieldEncryptor); err != nil {
return fmt.Errorf("failed to auto-detect database version: %w", err)
if err := existingDatabase.PopulateDbData(s.logger, s.fieldEncryptor); err != nil {
return fmt.Errorf("failed to auto-detect database data: %w", err)
}
if err := existingDatabase.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
@@ -594,17 +594,17 @@ func (s *DatabaseService) OnBeforeWorkspaceDeletion(workspaceID uuid.UUID) error
func (s *DatabaseService) IsUserReadOnly(
user *users_models.User,
database *Database,
) (bool, error) {
) (bool, []string, error) {
var usingDatabase *Database
if database.ID != uuid.Nil {
existingDatabase, err := s.dbRepository.FindByID(database.ID)
if err != nil {
return false, err
return false, nil, err
}
if existingDatabase.WorkspaceID == nil {
return false, errors.New("cannot check user for database without workspace")
return false, nil, errors.New("cannot check user for database without workspace")
}
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(
@@ -612,31 +612,34 @@ func (s *DatabaseService) IsUserReadOnly(
user,
)
if err != nil {
return false, err
return false, nil, err
}
if !canAccess {
return false, errors.New("insufficient permissions to access this database")
return false, nil, errors.New("insufficient permissions to access this database")
}
if database.WorkspaceID != nil && *existingDatabase.WorkspaceID != *database.WorkspaceID {
return false, errors.New("database does not belong to this workspace")
return false, nil, errors.New("database does not belong to this workspace")
}
existingDatabase.Update(database)
if err := existingDatabase.Validate(); err != nil {
return false, err
return false, nil, err
}
usingDatabase = existingDatabase
} else {
if database.WorkspaceID != nil {
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, user)
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(
*database.WorkspaceID,
user,
)
if err != nil {
return false, err
return false, nil, err
}
if !canAccess {
return false, errors.New("insufficient permissions to access this workspace")
return false, nil, errors.New("insufficient permissions to access this workspace")
}
}
@@ -676,7 +679,7 @@ func (s *DatabaseService) IsUserReadOnly(
usingDatabase.ID,
)
default:
return false, errors.New("read-only check not supported for this database type")
return false, nil, errors.New("read-only check not supported for this database type")
}
}

View File

@@ -1,6 +1,12 @@
package databases
import (
"fmt"
"strconv"
"databasus-backend/internal/config"
"databasus-backend/internal/features/databases/databases/mariadb"
"databasus-backend/internal/features/databases/databases/mongodb"
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
@@ -9,6 +15,71 @@ import (
"github.com/google/uuid"
)
func GetTestPostgresConfig() *postgresql.PostgresqlDatabase {
env := config.GetEnv()
port, err := strconv.Atoi(env.TestPostgres16Port)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
}
testDbName := "testdb"
return &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
}
}
func GetTestMariadbConfig() *mariadb.MariadbDatabase {
env := config.GetEnv()
portStr := env.TestMariadb1011Port
if portStr == "" {
portStr = "33111"
}
port, err := strconv.Atoi(portStr)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_MARIADB_1011_PORT: %v", err))
}
testDbName := "testdb"
return &mariadb.MariadbDatabase{
Version: tools.MariadbVersion1011,
Host: "localhost",
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
}
}
func GetTestMongodbConfig() *mongodb.MongodbDatabase {
env := config.GetEnv()
portStr := env.TestMongodb70Port
if portStr == "" {
portStr = "27070"
}
port, err := strconv.Atoi(portStr)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_MONGODB_70_PORT: %v", err))
}
return &mongodb.MongodbDatabase{
Version: tools.MongodbVersion7,
Host: "localhost",
Port: port,
Username: "root",
Password: "rootpassword",
Database: "testdb",
AuthDatabase: "admin",
IsHttps: false,
CpuCount: 1,
}
}
func CreateTestDatabase(
workspaceID uuid.UUID,
storage *storages.Storage,
@@ -18,16 +89,7 @@ func CreateTestDatabase(
WorkspaceID: &workspaceID,
Name: "test " + uuid.New().String(),
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
CpuCount: 1,
},
Postgresql: GetTestPostgresConfig(),
Notifiers: []notifiers.Notifier{
*notifier,
},

View File

@@ -1,7 +1,7 @@
package healthcheck_attempt
import (
"databasus-backend/internal/config"
"context"
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
"log/slog"
"time"
@@ -13,18 +13,19 @@ type HealthcheckAttemptBackgroundService struct {
logger *slog.Logger
}
func (s *HealthcheckAttemptBackgroundService) Run() {
func (s *HealthcheckAttemptBackgroundService) Run(ctx context.Context) {
// first healthcheck immediately
s.checkDatabases()
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for range ticker.C {
if config.IsShouldShutdown() {
break
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.checkDatabases()
}
s.checkDatabases()
}
}

View File

@@ -12,13 +12,11 @@ import (
"github.com/stretchr/testify/assert"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/postgresql"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
test_utils "databasus-backend/internal/util/testing"
"databasus-backend/internal/util/tools"
)
func createTestRouter() *gin.Engine {
@@ -111,7 +109,13 @@ func Test_GetAttemptsByDatabase_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -205,20 +209,11 @@ func createTestDatabaseViaAPI(
token string,
router *gin.Engine,
) *databases.Database {
testDbName := "test_db"
request := databases.Database{
WorkspaceID: &workspaceID,
Name: name,
Type: databases.DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
CpuCount: 1,
},
Postgresql: databases.GetTestPostgresConfig(),
}
w := workspaces_testing.MakeAPIRequest(

View File

@@ -10,13 +10,11 @@ import (
"github.com/stretchr/testify/assert"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/postgresql"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
test_utils "databasus-backend/internal/util/testing"
"databasus-backend/internal/util/tools"
)
func createTestRouter() *gin.Engine {
@@ -90,7 +88,13 @@ func Test_SaveHealthcheckConfig_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
}
@@ -228,7 +232,13 @@ func Test_GetHealthcheckConfig_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -293,20 +303,11 @@ func createTestDatabaseViaAPI(
token string,
router *gin.Engine,
) *databases.Database {
testDbName := "test_db"
request := databases.Database{
WorkspaceID: &workspaceID,
Name: name,
Type: databases.DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
CpuCount: 1,
},
Postgresql: databases.GetTestPostgresConfig(),
}
w := workspaces_testing.MakeAPIRequest(

View File

@@ -263,7 +263,12 @@ func (c *NotifierController) TransferNotifierToWorkspace(ctx *gin.Context) {
return
}
if err := c.notifierService.TransferNotifierToWorkspace(user, id, request.TargetWorkspaceID, nil); err != nil {
if err := c.notifierService.TransferNotifierToWorkspace(
user,
id,
request.TargetWorkspaceID,
nil,
); err != nil {
if errors.Is(err, ErrInsufficientPermissionsInSourceWorkspace) ||
errors.Is(err, ErrInsufficientPermissionsInTargetWorkspace) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})

View File

@@ -675,6 +675,10 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
WebhookNotifier: &webhook_notifier.WebhookNotifier{
WebhookURL: "https://webhook.example.com/test",
WebhookMethod: webhook_notifier.WebhookMethodPOST,
Headers: []webhook_notifier.WebhookHeader{
{Key: "Authorization", Value: "Bearer my-secret-token"},
{Key: "X-Custom-Header", Value: "custom-value"},
},
},
}
},
@@ -687,14 +691,40 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
WebhookNotifier: &webhook_notifier.WebhookNotifier{
WebhookURL: "https://webhook.example.com/updated",
WebhookMethod: webhook_notifier.WebhookMethodGET,
Headers: []webhook_notifier.WebhookHeader{
{Key: "Authorization", Value: "Bearer updated-token"},
},
},
}
},
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
// No sensitive data to verify for webhook
assert.NotEmpty(
t,
notifier.WebhookNotifier.WebhookURL,
"WebhookURL should be visible",
)
// Verify header values are encrypted in DB
assert.True(
t,
isEncrypted(notifier.WebhookNotifier.Headers[0].Value),
"Header value should be encrypted in DB",
)
decrypted := decryptField(
t,
notifier.ID,
notifier.WebhookNotifier.Headers[0].Value,
)
assert.Equal(t, "Bearer updated-token", decrypted)
},
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
// No sensitive data to hide for webhook
assert.NotEmpty(
t,
notifier.WebhookNotifier.WebhookURL,
"WebhookURL should be visible",
)
for _, header := range notifier.WebhookNotifier.Headers {
assert.Empty(t, header.Value, "Header value should be hidden")
}
},
},
}
@@ -905,7 +935,7 @@ func Test_CreateNotifier_AllSensitiveFieldsEncryptedInDB(t *testing.T) {
},
},
{
name: "Webhook Notifier - WebhookURL encrypted",
name: "Webhook Notifier - Header values encrypted, URL not encrypted",
createNotifier: func(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
@@ -914,17 +944,48 @@ func Test_CreateNotifier_AllSensitiveFieldsEncryptedInDB(t *testing.T) {
WebhookNotifier: &webhook_notifier.WebhookNotifier{
WebhookURL: "https://webhook.example.com/test456",
WebhookMethod: webhook_notifier.WebhookMethodPOST,
Headers: []webhook_notifier.WebhookHeader{
{Key: "Authorization", Value: "Bearer secret-token-12345"},
{Key: "X-API-Key", Value: "api-key-67890"},
},
},
}
},
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
assert.True(
assert.False(
t,
isEncrypted(notifier.WebhookNotifier.WebhookURL),
"WebhookURL should be encrypted",
"WebhookURL should NOT be encrypted",
)
decrypted := decryptField(t, notifier.ID, notifier.WebhookNotifier.WebhookURL)
assert.Equal(t, "https://webhook.example.com/test456", decrypted)
assert.Equal(
t,
"https://webhook.example.com/test456",
notifier.WebhookNotifier.WebhookURL,
)
assert.True(
t,
isEncrypted(notifier.WebhookNotifier.Headers[0].Value),
"Header value should be encrypted",
)
decrypted1 := decryptField(
t,
notifier.ID,
notifier.WebhookNotifier.Headers[0].Value,
)
assert.Equal(t, "Bearer secret-token-12345", decrypted1)
assert.True(
t,
isEncrypted(notifier.WebhookNotifier.Headers[1].Value),
"Header value should be encrypted",
)
decrypted2 := decryptField(
t,
notifier.ID,
notifier.WebhookNotifier.Headers[1].Value,
)
assert.Equal(t, "api-key-67890", decrypted2)
},
},
}
@@ -1050,8 +1111,20 @@ func Test_TransferNotifier_PermissionsEnforced(t *testing.T) {
testUserToken = admin.Token
} else if tt.sourceRole != nil {
testUser := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(sourceWorkspace, testUser, *tt.sourceRole, sourceOwner.Token, router)
workspaces_testing.AddMemberToWorkspace(targetWorkspace, testUser, *tt.targetRole, targetOwner.Token, router)
workspaces_testing.AddMemberToWorkspace(
sourceWorkspace,
testUser,
*tt.sourceRole,
sourceOwner.Token,
router,
)
workspaces_testing.AddMemberToWorkspace(
targetWorkspace,
testUser,
*tt.targetRole,
targetOwner.Token,
router,
)
testUserToken = testUser.Token
}

View File

@@ -21,6 +21,10 @@ type WebhookHeader struct {
Value string `json:"value"`
}
// Before both WebhookURL, BodyTemplate and HeadersJSON were considered
// as sensetive data and it was causing issues. Now only headers values
// considered as sensetive data, but we try to decrypt webhook URL and
// body template for backward combability
type WebhookNotifier struct {
NotifierID uuid.UUID `json:"notifierId" gorm:"primaryKey;column:notifier_id"`
WebhookURL string `json:"webhookUrl" gorm:"not null;column:webhook_url"`
@@ -58,6 +62,20 @@ func (t *WebhookNotifier) AfterFind(_ *gorm.DB) error {
}
}
encryptor := encryption.GetFieldEncryptor()
if t.WebhookURL != "" {
if decrypted, err := encryptor.Decrypt(t.NotifierID, t.WebhookURL); err == nil {
t.WebhookURL = decrypted
}
}
if t.BodyTemplate != nil && *t.BodyTemplate != "" {
if decrypted, err := encryptor.Decrypt(t.NotifierID, *t.BodyTemplate); err == nil {
t.BodyTemplate = &decrypted
}
}
return nil
}
@@ -79,22 +97,24 @@ func (t *WebhookNotifier) Send(
heading string,
message string,
) error {
webhookURL, err := encryptor.Decrypt(t.NotifierID, t.WebhookURL)
if err != nil {
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
if err := t.decryptHeadersForSending(encryptor); err != nil {
return err
}
switch t.WebhookMethod {
case WebhookMethodGET:
return t.sendGET(webhookURL, heading, message, logger)
return t.sendGET(t.WebhookURL, heading, message, logger)
case WebhookMethodPOST:
return t.sendPOST(webhookURL, heading, message, logger)
return t.sendPOST(t.WebhookURL, heading, message, logger)
default:
return fmt.Errorf("unsupported webhook method: %s", t.WebhookMethod)
}
}
func (t *WebhookNotifier) HideSensitiveData() {
for i := range t.Headers {
t.Headers[i].Value = ""
}
}
func (t *WebhookNotifier) Update(incoming *WebhookNotifier) {
@@ -105,14 +125,15 @@ func (t *WebhookNotifier) Update(incoming *WebhookNotifier) {
}
func (t *WebhookNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if t.WebhookURL != "" {
encrypted, err := encryptor.Encrypt(t.NotifierID, t.WebhookURL)
for i := range t.Headers {
if t.Headers[i].Value != "" {
encrypted, err := encryptor.Encrypt(t.NotifierID, t.Headers[i].Value)
if err != nil {
return fmt.Errorf("failed to encrypt header value: %w", err)
}
if err != nil {
return fmt.Errorf("failed to encrypt webhook URL: %w", err)
t.Headers[i].Value = encrypted
}
t.WebhookURL = encrypted
}
return nil
@@ -241,3 +262,15 @@ func escapeJSONString(s string) string {
return string(b[1 : len(b)-1])
}
func (t *WebhookNotifier) decryptHeadersForSending(encryptor encryption.FieldEncryptor) error {
for i := range t.Headers {
if t.Headers[i].Value != "" {
if decrypted, err := encryptor.Decrypt(t.NotifierID, t.Headers[i].Value); err == nil {
t.Headers[i].Value = decrypted
}
}
}
return nil
}

View File

@@ -1,6 +1,7 @@
package restores
import (
"context"
"databasus-backend/internal/features/restores/enums"
"log/slog"
)
@@ -10,7 +11,7 @@ type RestoreBackgroundService struct {
logger *slog.Logger
}
func (s *RestoreBackgroundService) Run() {
func (s *RestoreBackgroundService) Run(ctx context.Context) {
if err := s.failRestoresInProgress(); err != nil {
s.logger.Error("Failed to fail restores in progress", "error", err)
panic(err)

View File

@@ -7,6 +7,7 @@ import (
"io"
"log/slog"
"net/http"
"strconv"
"strings"
"testing"
"time"
@@ -15,8 +16,10 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/mysql"
@@ -272,7 +275,7 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
var backup *backups.Backup
var backup *backups_core.Backup
var request RestoreBackupRequest
if tc.dbType == databases.DatabaseTypePostgres {
@@ -288,7 +291,12 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
},
}
} else {
mysqlDB := createTestMySQLDatabase("Test MySQL DB", workspace.ID, owner.Token, router)
mysqlDB := createTestMySQLDatabase(
"Test MySQL DB",
workspace.ID,
owner.Token,
router,
)
storage := createTestStorage(workspace.ID)
configService := backups_config.GetBackupConfigService()
@@ -314,7 +322,7 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
}
// Set huge backup size (10 TB) that would fail disk validation if checked
repo := &backups.BackupRepository{}
repo := &backups_core.BackupRepository{}
backup.BackupSizeMb = 10485760.0
err := repo.Save(backup)
assert.NoError(t, err)
@@ -361,7 +369,7 @@ func createTestDatabaseWithBackupForRestore(
workspace *workspaces_models.Workspace,
owner *users_dto.SignInResponseDTO,
router *gin.Engine,
) (*databases.Database, *backups.Backup) {
) (*databases.Database, *backups_core.Backup) {
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
@@ -390,20 +398,11 @@ func createTestDatabase(
token string,
router *gin.Engine,
) *databases.Database {
testDbName := "test_db"
request := databases.Database{
WorkspaceID: &workspaceID,
Name: name,
Type: databases.DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
CpuCount: 1,
},
Postgresql: databases.GetTestPostgresConfig(),
}
w := workspaces_testing.MakeAPIRequest(
@@ -434,7 +433,18 @@ func createTestMySQLDatabase(
token string,
router *gin.Engine,
) *databases.Database {
testDbName := "test_db"
env := config.GetEnv()
portStr := env.TestMysql80Port
if portStr == "" {
portStr = "33080"
}
port, err := strconv.Atoi(portStr)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_MYSQL_80_PORT: %v", err))
}
testDbName := "testdb"
request := databases.Database{
WorkspaceID: &workspaceID,
Name: name,
@@ -442,9 +452,9 @@ func createTestMySQLDatabase(
Mysql: &mysql.MysqlDatabase{
Version: tools.MysqlVersion80,
Host: "localhost",
Port: 3306,
Username: "root",
Password: "password",
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
},
}
@@ -495,7 +505,7 @@ func createTestStorage(workspaceID uuid.UUID) *storages.Storage {
func createTestBackup(
database *databases.Database,
owner *users_dto.SignInResponseDTO,
) *backups.Backup {
) *backups_core.Backup {
fieldEncryptor := util_encryption.GetFieldEncryptor()
userService := users_services.GetUserService()
user, err := userService.GetUserFromToken(owner.Token)
@@ -508,17 +518,17 @@ func createTestBackup(
panic("No storage found for workspace")
}
backup := &backups.Backup{
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storages[0].ID,
Status: backups.BackupStatusCompleted,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10.5,
BackupDurationMs: 1000,
CreatedAt: time.Now().UTC(),
}
repo := &backups.BackupRepository{}
repo := &backups_core.BackupRepository{}
if err := repo.Save(backup); err != nil {
panic(err)
}
@@ -526,7 +536,13 @@ func createTestBackup(
dummyContent := []byte("dummy backup content for testing")
reader := strings.NewReader(string(dummyContent))
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
if err := storages[0].SaveFile(context.Background(), fieldEncryptor, logger, backup.ID, reader); err != nil {
if err := storages[0].SaveFile(
context.Background(),
fieldEncryptor,
logger,
backup.ID,
reader,
); err != nil {
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
}

View File

@@ -1,7 +1,7 @@
package models
import (
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/restores/enums"
"time"
@@ -13,7 +13,7 @@ type Restore struct {
Status enums.RestoreStatus `json:"status" gorm:"column:status;type:text;not null"`
BackupID uuid.UUID `json:"backupId" gorm:"column:backup_id;type:uuid;not null"`
Backup *backups.Backup
Backup *backups_core.Backup
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`

View File

@@ -3,6 +3,7 @@ package restores
import (
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/disk"
@@ -36,7 +37,7 @@ type RestoreService struct {
diskService *disk.DiskService
}
func (s *RestoreService) OnBeforeBackupRemove(backup *backups.Backup) error {
func (s *RestoreService) OnBeforeBackupRemove(backup *backups_core.Backup) error {
restores, err := s.restoreRepository.FindByBackupID(backup.ID)
if err != nil {
return err
@@ -153,10 +154,10 @@ func (s *RestoreService) RestoreBackupWithAuth(
}
func (s *RestoreService) RestoreBackup(
backup *backups.Backup,
backup *backups_core.Backup,
requestDTO RestoreBackupRequest,
) error {
if backup.Status != backups.BackupStatusCompleted {
if backup.Status != backups_core.BackupStatusCompleted {
return errors.New("backup is not completed")
}
@@ -229,8 +230,8 @@ func (s *RestoreService) RestoreBackup(
Mongodb: requestDTO.MongodbDatabase,
}
if err := restoringToDB.PopulateVersionIfEmpty(s.logger, s.fieldEncryptor); err != nil {
return fmt.Errorf("failed to auto-detect database version: %w", err)
if err := restoringToDB.PopulateDbData(s.logger, s.fieldEncryptor); err != nil {
return fmt.Errorf("failed to auto-detect database data: %w", err)
}
isExcludeExtensions := false
@@ -370,7 +371,7 @@ func (s *RestoreService) validateVersionCompatibility(
}
func (s *RestoreService) validateDiskSpace(
backup *backups.Backup,
backup *backups_core.Backup,
requestDTO RestoreBackupRequest,
) error {
// Only validate disk space for PostgreSQL when file-based restore is needed:

View File

@@ -18,7 +18,7 @@ import (
"github.com/klauspost/compress/zstd"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -40,7 +40,7 @@ func (uc *RestoreMariadbBackupUsecase) Execute(
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
) error {
if originalDB.Type != databases.DatabaseTypeMariadb {
@@ -71,6 +71,7 @@ func (uc *RestoreMariadbBackupUsecase) Execute(
if mdb.IsHttps {
args = append(args, "--ssl")
args = append(args, "--skip-ssl-verify-server-cert")
}
if mdb.Database != nil && *mdb.Database != "" {
@@ -98,7 +99,7 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
mariadbBin string,
args []string,
password string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
mdbConfig *mariadbtypes.MariadbDatabase,
) error {
@@ -162,7 +163,7 @@ func (uc *RestoreMariadbBackupUsecase) executeMariadbRestore(
args []string,
myCnfFile string,
backupReader io.ReadCloser,
backup *backups.Backup,
backup *backups_core.Backup,
) error {
fullArgs := append([]string{"--defaults-file=" + myCnfFile}, args...)
@@ -225,7 +226,7 @@ func (uc *RestoreMariadbBackupUsecase) executeMariadbRestore(
func (uc *RestoreMariadbBackupUsecase) setupDecryption(
reader io.Reader,
backup *backups.Backup,
backup *backups_core.Backup,
) (io.Reader, error) {
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
return nil, fmt.Errorf("backup is encrypted but missing encryption metadata")
@@ -265,11 +266,24 @@ func (uc *RestoreMariadbBackupUsecase) createTempMyCnfFile(
mdbConfig *mariadbtypes.MariadbDatabase,
password string,
) (string, error) {
tempDir, err := os.MkdirTemp(config.GetEnv().TempFolder, "mycnf_"+uuid.New().String())
tempFolder := config.GetEnv().TempFolder
if err := os.MkdirAll(tempFolder, 0700); err != nil {
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
}
if err := os.Chmod(tempFolder, 0700); err != nil {
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
}
tempDir, err := os.MkdirTemp(tempFolder, "mycnf_"+uuid.New().String())
if err != nil {
return "", fmt.Errorf("failed to create temp directory: %w", err)
}
if err := os.Chmod(tempDir, 0700); err != nil {
_ = os.RemoveAll(tempDir)
return "", fmt.Errorf("failed to set temp directory permissions: %w", err)
}
myCnfFile := filepath.Join(tempDir, ".my.cnf")
content := fmt.Sprintf(`[client]
@@ -287,6 +301,7 @@ port=%d
err = os.WriteFile(myCnfFile, []byte(content), 0600)
if err != nil {
_ = os.RemoveAll(tempDir)
return "", fmt.Errorf("failed to write .my.cnf: %w", err)
}

View File

@@ -14,7 +14,7 @@ import (
"time"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -40,7 +40,7 @@ func (uc *RestoreMongodbBackupUsecase) Execute(
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
) error {
if originalDB.Type != databases.DatabaseTypeMongodb {
@@ -124,7 +124,7 @@ func (uc *RestoreMongodbBackupUsecase) buildMongorestoreArgs(
func (uc *RestoreMongodbBackupUsecase) restoreFromStorage(
mongorestoreBin string,
args []string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
) error {
ctx, cancel := context.WithTimeout(context.Background(), restoreTimeout)
@@ -166,7 +166,7 @@ func (uc *RestoreMongodbBackupUsecase) executeMongoRestore(
mongorestoreBin string,
args []string,
backupReader io.ReadCloser,
backup *backups.Backup,
backup *backups_core.Backup,
) error {
cmd := exec.CommandContext(ctx, mongorestoreBin, args...)
@@ -231,7 +231,7 @@ func (uc *RestoreMongodbBackupUsecase) executeMongoRestore(
func (uc *RestoreMongodbBackupUsecase) setupDecryption(
reader io.Reader,
backup *backups.Backup,
backup *backups_core.Backup,
) (io.Reader, error) {
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
return nil, errors.New("encrypted backup missing salt or IV")

View File

@@ -18,7 +18,7 @@ import (
"github.com/klauspost/compress/zstd"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -40,7 +40,7 @@ func (uc *RestoreMysqlBackupUsecase) Execute(
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
) error {
if originalDB.Type != databases.DatabaseTypeMysql {
@@ -98,7 +98,7 @@ func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
mysqlBin string,
args []string,
password string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
myConfig *mysqltypes.MysqlDatabase,
) error {
@@ -154,7 +154,7 @@ func (uc *RestoreMysqlBackupUsecase) executeMysqlRestore(
args []string,
myCnfFile string,
backupReader io.ReadCloser,
backup *backups.Backup,
backup *backups_core.Backup,
) error {
fullArgs := append([]string{"--defaults-file=" + myCnfFile}, args...)
@@ -217,7 +217,7 @@ func (uc *RestoreMysqlBackupUsecase) executeMysqlRestore(
func (uc *RestoreMysqlBackupUsecase) setupDecryption(
reader io.Reader,
backup *backups.Backup,
backup *backups_core.Backup,
) (io.Reader, error) {
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
return nil, fmt.Errorf("backup is encrypted but missing encryption metadata")
@@ -257,11 +257,24 @@ func (uc *RestoreMysqlBackupUsecase) createTempMyCnfFile(
myConfig *mysqltypes.MysqlDatabase,
password string,
) (string, error) {
tempDir, err := os.MkdirTemp(config.GetEnv().TempFolder, "mycnf_"+uuid.New().String())
tempFolder := config.GetEnv().TempFolder
if err := os.MkdirAll(tempFolder, 0700); err != nil {
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
}
if err := os.Chmod(tempFolder, 0700); err != nil {
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
}
tempDir, err := os.MkdirTemp(tempFolder, "mycnf_"+uuid.New().String())
if err != nil {
return "", fmt.Errorf("failed to create temp directory: %w", err)
}
if err := os.Chmod(tempDir, 0700); err != nil {
_ = os.RemoveAll(tempDir)
return "", fmt.Errorf("failed to set temp directory permissions: %w", err)
}
myCnfFile := filepath.Join(tempDir, ".my.cnf")
content := fmt.Sprintf(`[client]
@@ -277,6 +290,7 @@ port=%d
err = os.WriteFile(myCnfFile, []byte(content), 0600)
if err != nil {
_ = os.RemoveAll(tempDir)
return "", fmt.Errorf("failed to write .my.cnf: %w", err)
}

View File

@@ -15,7 +15,7 @@ import (
"time"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -39,7 +39,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
isExcludeExtensions bool,
) error {
@@ -86,7 +86,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
func (uc *RestorePostgresqlBackupUsecase) restoreCustomType(
originalDB *databases.Database,
pgBin string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
pg *pgtypes.PostgresqlDatabase,
isExcludeExtensions bool,
@@ -113,7 +113,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreCustomType(
func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
originalDB *databases.Database,
pgBin string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
pg *pgtypes.PostgresqlDatabase,
) error {
@@ -321,7 +321,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
func (uc *RestorePostgresqlBackupUsecase) restoreViaFile(
originalDB *databases.Database,
pgBin string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
pg *pgtypes.PostgresqlDatabase,
isExcludeExtensions bool,
@@ -371,7 +371,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
pgBin string,
args []string,
password string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
pgConfig *pgtypes.PostgresqlDatabase,
isExcludeExtensions bool,
@@ -469,7 +469,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
// downloadBackupToTempFile downloads backup data from storage to a temporary file
func (uc *RestorePostgresqlBackupUsecase) downloadBackupToTempFile(
ctx context.Context,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
) (string, func(), error) {
// Create temporary directory for backup data
@@ -925,14 +925,28 @@ func (uc *RestorePostgresqlBackupUsecase) createTempPgpassFile(
escapedPassword,
)
tempDir, err := os.MkdirTemp(config.GetEnv().TempFolder, "pgpass_"+uuid.New().String())
tempFolder := config.GetEnv().TempFolder
if err := os.MkdirAll(tempFolder, 0700); err != nil {
return "", fmt.Errorf("failed to ensure temp folder exists: %w", err)
}
if err := os.Chmod(tempFolder, 0700); err != nil {
return "", fmt.Errorf("failed to set temp folder permissions: %w", err)
}
tempDir, err := os.MkdirTemp(tempFolder, "pgpass_"+uuid.New().String())
if err != nil {
return "", fmt.Errorf("failed to create temporary directory: %w", err)
}
if err := os.Chmod(tempDir, 0700); err != nil {
_ = os.RemoveAll(tempDir)
return "", fmt.Errorf("failed to set temporary directory permissions: %w", err)
}
pgpassFile := filepath.Join(tempDir, ".pgpass")
err = os.WriteFile(pgpassFile, []byte(pgpassContent), 0600)
if err != nil {
_ = os.RemoveAll(tempDir)
return "", fmt.Errorf("failed to write temporary .pgpass file: %w", err)
}

View File

@@ -3,7 +3,7 @@ package usecases
import (
"errors"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/restores/models"
@@ -26,7 +26,7 @@ func (uc *RestoreBackupUsecase) Execute(
restore models.Restore,
originalDB *databases.Database,
restoringToDB *databases.Database,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
isExcludeExtensions bool,
) error {

View File

@@ -263,7 +263,12 @@ func (c *StorageController) TransferStorageToWorkspace(ctx *gin.Context) {
return
}
if err := c.storageService.TransferStorageToWorkspace(user, id, request.TargetWorkspaceID, nil); err != nil {
if err := c.storageService.TransferStorageToWorkspace(
user,
id,
request.TargetWorkspaceID,
nil,
); err != nil {
if errors.Is(err, ErrInsufficientPermissionsInSourceWorkspace) ||
errors.Is(err, ErrInsufficientPermissionsInTargetWorkspace) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})

View File

@@ -1071,8 +1071,20 @@ func Test_TransferStorage_PermissionsEnforced(t *testing.T) {
testUserToken = admin.Token
} else if tt.sourceRole != nil {
testUser := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(sourceWorkspace, testUser, *tt.sourceRole, sourceOwner.Token, router)
workspaces_testing.AddMemberToWorkspace(targetWorkspace, testUser, *tt.targetRole, targetOwner.Token, router)
workspaces_testing.AddMemberToWorkspace(
sourceWorkspace,
testUser,
*tt.sourceRole,
sourceOwner.Token,
router,
)
workspaces_testing.AddMemberToWorkspace(
targetWorkspace,
testUser,
*tt.targetRole,
targetOwner.Token,
router,
)
testUserToken = testUser.Token
}

View File

@@ -26,6 +26,7 @@ const (
azureResponseTimeout = 30 * time.Second
azureIdleConnTimeout = 90 * time.Second
azureTLSHandshakeTimeout = 30 * time.Second
azureDeleteTimeout = 30 * time.Second
// Chunk size for block blob uploads - 16MB provides good balance between
// memory usage and upload efficiency. This creates backpressure to pg_dump
@@ -186,8 +187,11 @@ func (s *AzureBlobStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileI
blobName := s.buildBlobName(fileID.String())
ctx, cancel := context.WithTimeout(context.Background(), azureDeleteTimeout)
defer cancel()
_, err = client.DeleteBlob(
context.TODO(),
ctx,
s.ContainerName,
blobName,
nil,

View File

@@ -18,6 +18,7 @@ import (
const (
ftpConnectTimeout = 30 * time.Second
ftpTestConnectTimeout = 10 * time.Second
ftpDeleteTimeout = 30 * time.Second
ftpChunkSize = 16 * 1024 * 1024
)
@@ -134,7 +135,10 @@ func (f *FTPStorage) GetFile(
}
func (f *FTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
conn, err := f.connect(encryptor, ftpConnectTimeout)
ctx, cancel := context.WithTimeout(context.Background(), ftpDeleteTimeout)
defer cancel()
conn, err := f.connectWithContext(ctx, encryptor, ftpDeleteTimeout)
if err != nil {
return fmt.Errorf("failed to connect to FTP: %w", err)
}

View File

@@ -27,6 +27,7 @@ const (
gdResponseTimeout = 30 * time.Second
gdIdleConnTimeout = 90 * time.Second
gdTLSHandshakeTimeout = 30 * time.Second
gdDeleteTimeout = 30 * time.Second
// Chunk size for Google Drive resumable uploads - 16MB provides good balance
// between memory usage and upload efficiency. Google Drive requires chunks
@@ -185,7 +186,9 @@ func (s *GoogleDriveStorage) DeleteFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
) error {
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), gdDeleteTimeout)
defer cancel()
return s.withRetryOnAuth(ctx, encryptor, func(driveService *drive.Service) error {
folderID, err := s.findBackupsFolder(driveService)
if err != nil {

View File

@@ -18,6 +18,8 @@ import (
)
const (
nasDeleteTimeout = 30 * time.Second
// Chunk size for NAS uploads - 16MB provides good balance between
// memory usage and upload efficiency. This creates backpressure to pg_dump
// by only reading one chunk at a time and waiting for NAS to confirm receipt.
@@ -193,7 +195,10 @@ func (n *NASStorage) GetFile(
}
func (n *NASStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
session, err := n.createSession(encryptor)
ctx, cancel := context.WithTimeout(context.Background(), nasDeleteTimeout)
defer cancel()
session, err := n.createSessionWithContext(ctx, encryptor)
if err != nil {
return fmt.Errorf("failed to create NAS session: %w", err)
}
@@ -211,10 +216,8 @@ func (n *NASStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid
filePath := n.getFilePath(fileID.String())
// Check if file exists before trying to delete
_, err = fs.Stat(filePath)
if err != nil {
// File doesn't exist, consider it already deleted
return nil
}

View File

@@ -22,6 +22,7 @@ import (
const (
rcloneOperationTimeout = 30 * time.Second
rcloneDeleteTimeout = 30 * time.Second
)
var rcloneConfigMu sync.Mutex
@@ -115,7 +116,8 @@ func (r *RcloneStorage) GetFile(
}
func (r *RcloneStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), rcloneDeleteTimeout)
defer cancel()
remoteFs, err := r.getFs(ctx, encryptor)
if err != nil {

View File

@@ -26,6 +26,7 @@ const (
s3ResponseTimeout = 30 * time.Second
s3IdleConnTimeout = 90 * time.Second
s3TLSHandshakeTimeout = 30 * time.Second
s3DeleteTimeout = 30 * time.Second
// Chunk size for multipart uploads - 16MB provides good balance between
// memory usage and upload efficiency. This creates backpressure to pg_dump
@@ -228,9 +229,11 @@ func (s *S3Storage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.
objectKey := s.buildObjectKey(fileID.String())
// Delete the object using MinIO client
ctx, cancel := context.WithTimeout(context.Background(), s3DeleteTimeout)
defer cancel()
err = client.RemoveObject(
context.TODO(),
ctx,
s.S3Bucket,
objectKey,
minio.RemoveObjectOptions{},

View File

@@ -19,6 +19,7 @@ import (
const (
sftpConnectTimeout = 30 * time.Second
sftpTestConnectTimeout = 10 * time.Second
sftpDeleteTimeout = 30 * time.Second
)
type SFTPStorage struct {
@@ -154,7 +155,10 @@ func (s *SFTPStorage) GetFile(
}
func (s *SFTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
client, sshConn, err := s.connect(encryptor, sftpConnectTimeout)
ctx, cancel := context.WithTimeout(context.Background(), sftpDeleteTimeout)
defer cancel()
client, sshConn, err := s.connectWithContext(ctx, encryptor, sftpDeleteTimeout)
if err != nil {
return fmt.Errorf("failed to connect to SFTP: %w", err)
}

View File

@@ -1,13 +1,14 @@
package system_healthcheck
import (
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/features/backups/backups/backuping"
"databasus-backend/internal/features/disk"
)
var healthcheckService = &HealthcheckService{
disk.GetDiskService(),
backups.GetBackupBackgroundService(),
backuping.GetBackupsScheduler(),
backuping.GetBackuperNode(),
}
var healthcheckController = &HealthcheckController{
healthcheckService,

View File

@@ -1,7 +1,8 @@
package system_healthcheck
import (
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups/backuping"
"databasus-backend/internal/features/disk"
"databasus-backend/internal/storage"
"errors"
@@ -9,7 +10,8 @@ import (
type HealthcheckService struct {
diskService *disk.DiskService
backupBackgroundService *backups.BackupBackgroundService
backupBackgroundService *backuping.BackupsScheduler
backuperNode *backuping.BackuperNode
}
func (s *HealthcheckService) IsHealthy() error {
@@ -29,8 +31,16 @@ func (s *HealthcheckService) IsHealthy() error {
return errors.New("cannot connect to the database")
}
if !s.backupBackgroundService.IsBackupsWorkerRunning() {
return errors.New("backups are not running for more than 5 minutes")
if config.GetEnv().IsPrimaryNode {
if !s.backupBackgroundService.IsSchedulerRunning() {
return errors.New("backups are not running for more than 5 minutes")
}
}
if config.GetEnv().IsBackupNode {
if !s.backuperNode.IsBackuperRunning() {
return errors.New("backuper node is not running for more than 5 minutes")
}
}
return nil

View File

@@ -17,7 +17,7 @@ import (
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
mariadbtypes "databasus-backend/internal/features/databases/databases/mariadb"
@@ -189,7 +189,7 @@ func testMariadbBackupRestoreForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_mariadb"
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -286,7 +286,7 @@ func testMariadbBackupRestoreWithEncryptionForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)
newDBName := "restoreddb_mariadb_encrypted"
@@ -394,7 +394,7 @@ func testMariadbBackupRestoreWithReadOnlyUserForVersion(
createBackupViaAPI(t, router, updatedDatabase.ID, user.Token)
backup := waitForBackupCompletion(t, router, updatedDatabase.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_mariadb_readonly"
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))

View File

@@ -19,7 +19,7 @@ import (
"go.mongodb.org/mongo-driver/mongo/options"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
mongodbtypes "databasus-backend/internal/features/databases/databases/mongodb"
@@ -161,7 +161,7 @@ func testMongodbBackupRestoreForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_mongodb_" + uuid.New().String()[:8]
@@ -239,7 +239,7 @@ func testMongodbBackupRestoreWithEncryptionForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)
newDBName := "restoreddb_mongodb_enc_" + uuid.New().String()[:8]
@@ -328,7 +328,7 @@ func testMongodbBackupRestoreWithReadOnlyUserForVersion(
createBackupViaAPI(t, router, updatedDatabase.ID, user.Token)
backup := waitForBackupCompletion(t, router, updatedDatabase.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_mongodb_ro_" + uuid.New().String()[:8]

View File

@@ -17,7 +17,7 @@ import (
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
mysqltypes "databasus-backend/internal/features/databases/databases/mysql"
@@ -164,7 +164,7 @@ func testMysqlBackupRestoreForVersion(t *testing.T, mysqlVersion tools.MysqlVers
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_mysql"
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -261,7 +261,7 @@ func testMysqlBackupRestoreWithEncryptionForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)
newDBName := "restoreddb_mysql_encrypted"
@@ -369,7 +369,7 @@ func testMysqlBackupRestoreWithReadOnlyUserForVersion(
createBackupViaAPI(t, router, updatedDatabase.ID, user.Token)
backup := waitForBackupCompletion(t, router, updatedDatabase.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_mysql_readonly"
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))

View File

@@ -18,6 +18,7 @@ import (
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
pgtypes "databasus-backend/internal/features/databases/databases/postgresql"
@@ -32,21 +33,23 @@ import (
test_utils "databasus-backend/internal/util/testing"
)
const createAndFillTableQuery = `
DROP TABLE IF EXISTS test_data;
func createAndFillTableQuery(tableName string) string {
return fmt.Sprintf(`
DROP TABLE IF EXISTS %s;
CREATE TABLE test_data (
CREATE TABLE %s (
id SERIAL PRIMARY KEY,
name TEXT NOT NULL,
value INTEGER NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
INSERT INTO test_data (name, value) VALUES
INSERT INTO %s (name, value) VALUES
('test1', 100),
('test2', 200),
('test3', 300);
`
`, tableName, tableName, tableName)
}
type PostgresContainer struct {
Host string
@@ -188,7 +191,7 @@ func Test_BackupAndRestoreSupabase_PublicSchemaOnly_RestoreIsSuccessful(t *testi
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
_, err = supabaseDB.Exec(fmt.Sprintf(`DELETE FROM public.%s`, tableName))
assert.NoError(t, err)
@@ -378,9 +381,14 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string, cp
}
}()
_, err = container.DB.Exec(createAndFillTableQuery)
tableName := fmt.Sprintf("test_data_%s", uuid.New().String()[:8])
_, err = container.DB.Exec(createAndFillTableQuery(tableName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s;", tableName))
}()
router := createTestRouter()
user := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
@@ -403,7 +411,7 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string, cp
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := fmt.Sprintf("restoreddb_%s_cpu%d_%s", pgVersion, cpuCount, uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -436,12 +444,19 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string, cp
var tableExists bool
err = newDB.Get(
&tableExists,
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'test_data')",
fmt.Sprintf(
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = '%s')",
tableName,
),
)
assert.NoError(t, err)
assert.True(t, tableExists, "Table 'test_data' should exist in restored database")
assert.True(
t,
tableExists,
fmt.Sprintf("Table '%s' should exist in restored database", tableName),
)
verifyDataIntegrity(t, container.DB, newDB)
verifyDataIntegrity(t, container.DB, newDB, tableName)
err = os.Remove(filepath.Join(config.GetEnv().DataFolder, backup.ID.String()))
if err != nil {
@@ -513,7 +528,7 @@ func testSchemaSelectionAllSchemasForVersion(t *testing.T, pgVersion string, por
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := fmt.Sprintf("restored_all_schemas_%s_%s", pgVersion, uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -641,7 +656,7 @@ func testBackupRestoreWithExcludeExtensionsForVersion(t *testing.T, pgVersion st
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := fmt.Sprintf("restored_exclude_ext_%s_%s", pgVersion, uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -775,7 +790,7 @@ func testBackupRestoreWithoutExcludeExtensionsForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := fmt.Sprintf("restored_with_ext_%s_%s", pgVersion, uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -875,9 +890,14 @@ func testBackupRestoreWithReadOnlyUserForVersion(t *testing.T, pgVersion string,
}
}()
_, err = container.DB.Exec(createAndFillTableQuery)
tableName := fmt.Sprintf("test_data_%s", uuid.New().String()[:8])
_, err = container.DB.Exec(createAndFillTableQuery(tableName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s;", tableName))
}()
router := createTestRouter()
user := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("ReadOnly Test Workspace", user, router)
@@ -909,7 +929,7 @@ func testBackupRestoreWithReadOnlyUserForVersion(t *testing.T, pgVersion string,
createBackupViaAPI(t, router, updatedDatabase.ID, user.Token)
backup := waitForBackupCompletion(t, router, updatedDatabase.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := fmt.Sprintf("restoreddb_readonly_%s", uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -941,12 +961,19 @@ func testBackupRestoreWithReadOnlyUserForVersion(t *testing.T, pgVersion string,
var tableExists bool
err = newDB.Get(
&tableExists,
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'test_data')",
fmt.Sprintf(
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = '%s')",
tableName,
),
)
assert.NoError(t, err)
assert.True(t, tableExists, "Table 'test_data' should exist in restored database")
assert.True(
t,
tableExists,
fmt.Sprintf("Table '%s' should exist in restored database", tableName),
)
verifyDataIntegrity(t, container.DB, newDB)
verifyDataIntegrity(t, container.DB, newDB, tableName)
err = os.Remove(filepath.Join(config.GetEnv().DataFolder, backup.ID.String()))
if err != nil {
@@ -1022,7 +1049,7 @@ func testSchemaSelectionOnlySpecifiedSchemasForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := fmt.Sprintf("restored_specific_schemas_%s_%s", pgVersion, uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -1106,9 +1133,14 @@ func testBackupRestoreWithEncryptionForVersion(t *testing.T, pgVersion string, p
}
}()
_, err = container.DB.Exec(createAndFillTableQuery)
tableName := fmt.Sprintf("test_data_%s", uuid.New().String()[:8])
_, err = container.DB.Exec(createAndFillTableQuery(tableName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s;", tableName))
}()
router := createTestRouter()
user := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
@@ -1130,7 +1162,7 @@ func testBackupRestoreWithEncryptionForVersion(t *testing.T, pgVersion string, p
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)
newDBName := fmt.Sprintf("restoreddb_encrypted_%s", uuid.New().String()[:8])
@@ -1163,12 +1195,19 @@ func testBackupRestoreWithEncryptionForVersion(t *testing.T, pgVersion string, p
var tableExists bool
err = newDB.Get(
&tableExists,
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'test_data')",
fmt.Sprintf(
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = '%s')",
tableName,
),
)
assert.NoError(t, err)
assert.True(t, tableExists, "Table 'test_data' should exist in restored database")
assert.True(
t,
tableExists,
fmt.Sprintf("Table '%s' should exist in restored database", tableName),
)
verifyDataIntegrity(t, container.DB, newDB)
verifyDataIntegrity(t, container.DB, newDB, tableName)
err = os.Remove(filepath.Join(config.GetEnv().DataFolder, backup.ID.String()))
if err != nil {
@@ -1204,7 +1243,7 @@ func waitForBackupCompletion(
databaseID uuid.UUID,
token string,
timeout time.Duration,
) *backups.Backup {
) *backups_core.Backup {
startTime := time.Now()
pollInterval := 500 * time.Millisecond
@@ -1225,10 +1264,10 @@ func waitForBackupCompletion(
if len(response.Backups) > 0 {
backup := response.Backups[0]
if backup.Status == backups.BackupStatusCompleted {
if backup.Status == backups_core.BackupStatusCompleted {
return backup
}
if backup.Status == backups.BackupStatusFailed {
if backup.Status == backups_core.BackupStatusFailed {
failMsg := "unknown error"
if backup.FailMessage != nil {
failMsg = *backup.FailMessage
@@ -1630,14 +1669,14 @@ func createSupabaseRestoreViaAPI(
)
}
func verifyDataIntegrity(t *testing.T, originalDB *sqlx.DB, restoredDB *sqlx.DB) {
func verifyDataIntegrity(t *testing.T, originalDB *sqlx.DB, restoredDB *sqlx.DB, tableName string) {
var originalData []TestDataItem
var restoredData []TestDataItem
err := originalDB.Select(&originalData, "SELECT * FROM test_data ORDER BY id")
err := originalDB.Select(&originalData, fmt.Sprintf("SELECT * FROM %s ORDER BY id", tableName))
assert.NoError(t, err)
err = restoredDB.Select(&restoredData, "SELECT * FROM test_data ORDER BY id")
err = restoredDB.Select(&restoredData, fmt.Sprintf("SELECT * FROM %s ORDER BY id", tableName))
assert.NoError(t, err)
assert.Equal(t, len(originalData), len(restoredData), "Should have same number of rows")

View File

@@ -0,0 +1,22 @@
package tests
import (
"os"
"testing"
"databasus-backend/internal/features/backups/backups/backuping"
cache_utils "databasus-backend/internal/util/cache"
)
func TestMain(m *testing.M) {
cache_utils.ClearAllCache()
backuperNode := backuping.CreateTestBackuperNode()
cancel := backuping.StartBackuperNodeForTest(&testing.T{}, backuperNode)
exitCode := m.Run()
backuping.StopBackuperNodeForTest(&testing.T{}, cancel, backuperNode)
os.Exit(exitCode)
}

View File

@@ -56,11 +56,17 @@ func (s *UserService) SignUp(request *users_dto.SignUpRequestDTO) error {
// If user exists with INVITED status, activate them and set password
if existingUser != nil && existingUser.Status == users_enums.UserStatusInvited {
if err := s.userRepository.UpdateUserPassword(existingUser.ID, hashedPasswordStr); err != nil {
if err := s.userRepository.UpdateUserPassword(
existingUser.ID,
hashedPasswordStr,
); err != nil {
return fmt.Errorf("failed to set password: %w", err)
}
if err := s.userRepository.UpdateUserStatus(existingUser.ID, users_enums.UserStatusActive); err != nil {
if err := s.userRepository.UpdateUserStatus(
existingUser.ID,
users_enums.UserStatusActive,
); err != nil {
return fmt.Errorf("failed to activate user: %w", err)
}
@@ -635,7 +641,10 @@ func (s *UserService) getOrCreateUserFromOAuth(
if userByEmail != nil {
if userByEmail.Status == users_enums.UserStatusInvited {
if err := s.userRepository.UpdateUserStatus(userByEmail.ID, users_enums.UserStatusActive); err != nil {
if err := s.userRepository.UpdateUserStatus(
userByEmail.ID,
users_enums.UserStatusActive,
); err != nil {
return nil, fmt.Errorf("failed to activate user: %w", err)
}

View File

@@ -161,7 +161,12 @@ func (c *MembershipController) ChangeMemberRole(ctx *gin.Context) {
return
}
if err := c.membershipService.ChangeMemberRole(workspaceID, userID, &request, user); err != nil {
if err := c.membershipService.ChangeMemberRole(
workspaceID,
userID,
&request,
user,
); err != nil {
if errors.Is(err, workspaces_errors.ErrInsufficientPermissionsToManageMembers) ||
errors.Is(err, workspaces_errors.ErrOnlyOwnerCanAddManageAdmins) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})

Some files were not shown because too many files have changed in this diff Show More