Compare commits

...

15 Commits

Author SHA1 Message Date
Rostislav Dugin
dd1072e230 Merge pull request #265 from databasus/develop
FIX (pre-commit): Add running go mod tidy in pre-commit
2026-01-14 15:18:35 +03:00
Rostislav Dugin
a495e5317a FIX (pre-commit): Add running go mod tidy in pre-commit 2026-01-14 15:18:06 +03:00
Rostislav Dugin
7eed647038 Merge pull request #264 from databasus/develop
Develop
2026-01-14 15:14:05 +03:00
Rostislav Dugin
6973241e25 FIX (backups): Throw error on parallel download token generation 2026-01-14 14:40:22 +03:00
Rostislav Dugin
ab181f5b81 FEATURE (bandwidth): Limit download throughput for backups to not exhaust more than 75% of server network bandwidth 2026-01-14 14:40:22 +03:00
Rostislav Dugin
b60a0cc170 FEATURE (backups): Allow single backup download to avoid exhausting of server throughput 2026-01-14 14:40:22 +03:00
Rostislav Dugin
f319a497b3 FEATURE (auth): Add rate limiting for sign in via email using sliding window 2026-01-14 14:40:22 +03:00
Rostislav Dugin
bc870b3f8e Merge pull request #261 from databasus/develop
FIX (webhook): Update webhook tests to not expect URL to be encrypted
2026-01-14 09:43:26 +03:00
Rostislav Dugin
15383c59eb FIX (webhook): Update webhook tests to not expect URL to be encrypted 2026-01-14 09:42:25 +03:00
Rostislav Dugin
d14c223a65 Merge pull request #259 from databasus/develop
Develop
2026-01-14 09:10:28 +03:00
Rostislav Dugin
2c0a294027 FIX (webhook): Do not encrypt webhook URL, keep encryption for headers only 2026-01-14 09:09:00 +03:00
Rostislav Dugin
5d851d73bd FIX (mysql \ mariadb): Decrease strictness of SELECT check for health check 2026-01-14 08:39:27 +03:00
Rostislav Dugin
699913c251 FIX (postgresql): Filter TEMP table SELECT checks 2026-01-14 07:42:29 +03:00
Rostislav Dugin
a2e3f30a6d Merge pull request #258 from databasus/develop
FEATURE (backups): Add support of multinode Databasus setup
2026-01-14 07:34:06 +03:00
Rostislav Dugin
80f1174ecd FEATURE (backups): Add support of multinode Databasus setup 2026-01-14 07:32:13 +03:00
75 changed files with 5095 additions and 1054 deletions

View File

@@ -27,3 +27,10 @@ repos:
language: system
files: ^backend/.*\.go$
pass_filenames: false
- id: backend-go-mod-tidy
name: Backend Go Mod Tidy
entry: bash -c "cd backend && go mod tidy"
language: system
files: ^backend/.*\.go$
pass_filenames: false

View File

@@ -2,10 +2,10 @@ run:
go run cmd/main.go
test:
go test -p=1 -count=1 -failfast -timeout 10m ./internal/...
go test -p=1 -count=1 -failfast -timeout 15m ./internal/...
lint:
golangci-lint fmt && golangci-lint run
golangci-lint fmt ./cmd/... ./internal/... && golangci-lint run ./cmd/... ./internal/...
migration-create:
goose create $(name) sql

View File

@@ -15,6 +15,9 @@ import (
"databasus-backend/internal/config"
"databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/features/backups/backups/backuping"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_download "databasus-backend/internal/features/backups/backups/download"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/disk"
@@ -54,16 +57,23 @@ func main() {
log := logger.GetLogger()
cache_utils.TestCacheConnection()
err := cache_utils.ClearAllCache()
if err != nil {
log.Error("Failed to clear cache", "error", err)
os.Exit(1)
if config.GetEnv().IsPrimaryNode {
err := cache_utils.ClearAllCache()
if err != nil {
log.Error("Failed to clear cache", "error", err)
os.Exit(1)
}
}
runMigrations(log)
if config.GetEnv().IsPrimaryNode {
runMigrations(log)
} else {
log.Info("Skipping migrations (IS_PRIMARY_NODE is false)")
}
// create directories that used for backups and restore
err = files_utils.EnsureDirectories([]string{
err := files_utils.EnsureDirectories([]string{
config.GetEnv().TempFolder,
config.GetEnv().DataFolder,
})
@@ -104,7 +114,9 @@ func main() {
enableCors(ginApp)
setUpRoutes(ginApp)
setUpDependencies()
runBackgroundTasks(log)
mountFrontend(ginApp)
startServerWithGracefulShutdown(log, ginApp)
@@ -227,35 +239,64 @@ func setUpDependencies() {
notifiers.SetupDependencies()
storages.SetupDependencies()
backups_config.SetupDependencies()
backups_cancellation.SetupDependencies()
}
func runBackgroundTasks(log *slog.Logger) {
log.Info("Preparing to run background tasks...")
err := files_utils.CleanFolder(config.GetEnv().TempFolder)
if err != nil {
log.Error("Failed to clean temp folder", "error", err)
// Create context that will be cancelled on shutdown
ctx, cancel := context.WithCancel(context.Background())
// Set up signal handling for graceful shutdown
quit := make(chan os.Signal, 1)
signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
go func() {
<-quit
log.Info("Shutdown signal received, cancelling all background tasks")
cancel()
}()
if config.GetEnv().IsPrimaryNode {
log.Info("Starting primary node background tasks...")
err := files_utils.CleanFolder(config.GetEnv().TempFolder)
if err != nil {
log.Error("Failed to clean temp folder", "error", err)
}
go runWithPanicLogging(log, "backup background service", func() {
backuping.GetBackupsScheduler().Run(ctx)
})
go runWithPanicLogging(log, "restore background service", func() {
restores.GetRestoreBackgroundService().Run(ctx)
})
go runWithPanicLogging(log, "healthcheck attempt background service", func() {
healthcheck_attempt.GetHealthcheckAttemptBackgroundService().Run(ctx)
})
go runWithPanicLogging(log, "audit log cleanup background service", func() {
audit_logs.GetAuditLogBackgroundService().Run(ctx)
})
go runWithPanicLogging(log, "download token cleanup background service", func() {
backups_download.GetDownloadTokenBackgroundService().Run(ctx)
})
} else {
log.Info("Skipping primary node tasks as not primary node")
}
go runWithPanicLogging(log, "backup background service", func() {
backups.GetBackupBackgroundService().Run()
})
if config.GetEnv().IsBackupNode {
log.Info("Starting backup node background tasks...")
go runWithPanicLogging(log, "restore background service", func() {
restores.GetRestoreBackgroundService().Run()
})
go runWithPanicLogging(log, "healthcheck attempt background service", func() {
healthcheck_attempt.GetHealthcheckAttemptBackgroundService().Run()
})
go runWithPanicLogging(log, "audit log cleanup background service", func() {
audit_logs.GetAuditLogBackgroundService().Run()
})
go runWithPanicLogging(log, "download token cleanup background service", func() {
backups.GetDownloadTokenBackgroundService().Run()
})
go runWithPanicLogging(log, "backup node", func() {
backuping.GetBackuperNode().Run(ctx)
})
} else {
log.Info("Skipping backup node tasks as not backup node")
}
}
func runWithPanicLogging(log *slog.Logger, serviceName string, fn func()) {

View File

@@ -28,7 +28,6 @@ require (
github.com/valkey-io/valkey-go v1.0.70
go.mongodb.org/mongo-driver v1.17.6
golang.org/x/crypto v0.46.0
golang.org/x/time v0.14.0
gorm.io/driver/postgres v1.5.11
gorm.io/gorm v1.26.1
)
@@ -186,6 +185,7 @@ require (
go.yaml.in/yaml/v2 v2.4.3 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/term v0.38.0 // indirect
golang.org/x/time v0.14.0 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
gopkg.in/validator.v2 v2.0.1 // indirect
moul.io/http2curl/v2 v2.3.0 // indirect

View File

@@ -9,6 +9,7 @@ import (
"strings"
"sync"
"github.com/google/uuid"
"github.com/ilyakaznacheev/cleanenv"
"github.com/joho/godotenv"
)
@@ -29,6 +30,12 @@ type EnvVariables struct {
MariadbInstallDir string `env:"MARIADB_INSTALL_DIR"`
MongodbInstallDir string `env:"MONGODB_INSTALL_DIR"`
NodeID string
IsManyNodesMode bool `env:"IS_MANY_NODES_MODE"`
IsPrimaryNode bool `env:"IS_PRIMARY_NODE"`
IsBackupNode bool `env:"IS_BACKUP_NODE"`
NodeNetworkThroughputMBs int `env:"NODE_NETWORK_THROUGHPUT_MBPS"`
DataFolder string
TempFolder string
SecretKeyPath string
@@ -196,6 +203,16 @@ func loadEnvVariables() {
env.MongodbInstallDir = filepath.Join(backendRoot, "tools", "mongodb")
tools.VerifyMongodbInstallation(log, env.EnvMode, env.MongodbInstallDir)
env.NodeID = uuid.New().String()
if env.NodeNetworkThroughputMBs == 0 {
env.NodeNetworkThroughputMBs = 125 // 1 Gbit/s
}
if !env.IsManyNodesMode {
env.IsPrimaryNode = true
env.IsBackupNode = true
}
// Valkey
if env.ValkeyHost == "" {
log.Error("VALKEY_HOST is empty")

View File

@@ -1,7 +1,7 @@
package audit_logs
import (
"databasus-backend/internal/config"
"context"
"log/slog"
"time"
)
@@ -11,23 +11,25 @@ type AuditLogBackgroundService struct {
logger *slog.Logger
}
func (s *AuditLogBackgroundService) Run() {
func (s *AuditLogBackgroundService) Run(ctx context.Context) {
s.logger.Info("Starting audit log cleanup background service")
if config.IsShouldShutdown() {
if ctx.Err() != nil {
return
}
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for {
if config.IsShouldShutdown() {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.cleanOldAuditLogs(); err != nil {
s.logger.Error("Failed to clean old audit logs", "error", err)
}
}
if err := s.cleanOldAuditLogs(); err != nil {
s.logger.Error("Failed to clean old audit logs", "error", err)
}
time.Sleep(1 * time.Hour)
}
}

View File

@@ -1,254 +0,0 @@
package backups
import (
"databasus-backend/internal/config"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/storages"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/period"
"log/slog"
"time"
)
// BackupBackgroundService periodically cleans expired backups and triggers
// pending scheduled backups for every backup config with backups enabled.
type BackupBackgroundService struct {
	backupService       *BackupService
	backupRepository    *BackupRepository
	backupConfigService *backups_config.BackupConfigService
	storageService      *storages.StorageService
	// lastBackupTime is refreshed on every worker-loop pass and serves as the
	// liveness marker consumed by IsBackupsWorkerRunning.
	lastBackupTime time.Time
	logger         *slog.Logger
}
// Run is the worker loop. It first marks any backups left "in progress" by a
// previous process as failed (panicking if that cleanup itself fails, since
// stale in-progress rows would corrupt scheduling), then once a minute cleans
// expired backups and triggers pending ones until shutdown is requested.
func (s *BackupBackgroundService) Run() {
	s.lastBackupTime = time.Now().UTC()
	if err := s.failBackupsInProgress(); err != nil {
		s.logger.Error("Failed to fail backups in progress", "error", err)
		panic(err)
	}
	if config.IsShouldShutdown() {
		return
	}
	for {
		// Re-check the shutdown flag each pass so the loop exits promptly.
		if config.IsShouldShutdown() {
			return
		}
		if err := s.cleanOldBackups(); err != nil {
			s.logger.Error("Failed to clean old backups", "error", err)
		}
		if err := s.runPendingBackups(); err != nil {
			s.logger.Error("Failed to run pending backups", "error", err)
		}
		s.lastBackupTime = time.Now().UTC()
		time.Sleep(1 * time.Minute)
	}
}
// IsBackupsWorkerRunning reports whether the worker loop completed a pass
// within the last five minutes.
func (s *BackupBackgroundService) IsBackupsWorkerRunning() bool {
	staleCutoff := time.Now().UTC().Add(-5 * time.Minute)
	return s.lastBackupTime.After(staleCutoff)
}
// failBackupsInProgress marks every backup still in "in progress" state as
// failed. It runs once at startup: anything in progress at that moment was
// interrupted by an application restart. A failure notification is sent for
// each affected backup.
func (s *BackupBackgroundService) failBackupsInProgress() error {
	backupsInProgress, err := s.backupRepository.FindByStatus(BackupStatusInProgress)
	if err != nil {
		return err
	}
	for _, backup := range backupsInProgress {
		backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(backup.DatabaseID)
		if err != nil {
			// Skip rather than abort so other interrupted backups still get
			// marked as failed.
			s.logger.Error("Failed to get backup config by database ID", "error", err)
			continue
		}
		failMessage := "Backup failed due to application restart"
		backup.FailMessage = &failMessage
		backup.Status = BackupStatusFailed
		backup.BackupSizeMb = 0
		s.backupService.SendBackupNotification(
			backupConfig,
			backup,
			backups_config.NotificationBackupFailed,
			&failMessage,
		)
		if err := s.backupRepository.Save(backup); err != nil {
			return err
		}
	}
	return nil
}
// cleanOldBackups deletes backups older than each config's retention period:
// first the stored file, then the database record. Per-item failures are
// logged and skipped so one bad backup does not block cleanup of the rest.
func (s *BackupBackgroundService) cleanOldBackups() error {
	enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
	if err != nil {
		return err
	}
	for _, backupConfig := range enabledBackupConfigs {
		backupStorePeriod := backupConfig.StorePeriod
		// "Forever" retention: nothing to clean for this config.
		if backupStorePeriod == period.PeriodForever {
			continue
		}
		storeDuration := backupStorePeriod.ToDuration()
		dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
		oldBackups, err := s.backupRepository.FindBackupsBeforeDate(
			backupConfig.DatabaseID,
			dateBeforeBackupsShouldBeDeleted,
		)
		if err != nil {
			s.logger.Error(
				"Failed to find old backups for database",
				"databaseId",
				backupConfig.DatabaseID,
				"error",
				err,
			)
			continue
		}
		for _, backup := range oldBackups {
			storage, err := s.storageService.GetStorageByID(backup.StorageID)
			if err != nil {
				s.logger.Error(
					"Failed to get storage by ID",
					"storageId",
					backup.StorageID,
					"error",
					err,
				)
				continue
			}
			encryptor := encryption.GetFieldEncryptor()
			// Best effort: a failed file delete still lets the DB record be
			// removed below.
			err = storage.DeleteFile(encryptor, backup.ID)
			if err != nil {
				s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
			}
			if err := s.backupRepository.DeleteByID(backup.ID); err != nil {
				s.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
				continue
			}
			s.logger.Info(
				"Deleted old backup",
				"backupId",
				backup.ID,
				"databaseId",
				backupConfig.DatabaseID,
			)
		}
	}
	return nil
}
// runPendingBackups triggers a backup for every enabled config whose interval
// says one is due, or which still has retry attempts left after a failure.
// Each backup is launched asynchronously in its own goroutine.
func (s *BackupBackgroundService) runPendingBackups() error {
	enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
	if err != nil {
		return err
	}
	for _, backupConfig := range enabledBackupConfigs {
		// No interval configured: this database is manual-backup only.
		if backupConfig.BackupInterval == nil {
			continue
		}
		lastBackup, err := s.backupRepository.FindLastByDatabaseID(backupConfig.DatabaseID)
		if err != nil {
			s.logger.Error(
				"Failed to get last backup for database",
				"databaseId",
				backupConfig.DatabaseID,
				"error",
				err,
			)
			continue
		}
		// lastBackupTime stays nil when there is no previous backup.
		var lastBackupTime *time.Time
		if lastBackup != nil {
			lastBackupTime = &lastBackup.CreatedAt
		}
		remainedBackupTryCount := s.GetRemainedBackupTryCount(lastBackup)
		if backupConfig.BackupInterval.ShouldTriggerBackup(time.Now().UTC(), lastBackupTime) ||
			remainedBackupTryCount > 0 {
			s.logger.Info(
				"Triggering scheduled backup",
				"databaseId",
				backupConfig.DatabaseID,
				"intervalType",
				backupConfig.BackupInterval.Interval,
			)
			// Second argument flags the final allowed retry attempt.
			go s.backupService.MakeBackup(backupConfig.DatabaseID, remainedBackupTryCount == 1)
			s.logger.Info(
				"Successfully triggered scheduled backup",
				"databaseId",
				backupConfig.DatabaseID,
			)
		}
	}
	return nil
}
// GetRemainedBackupTryCount returns the number of remaining backup tries for
// the given last backup. It returns 0 when there is no previous backup, the
// last backup did not fail, or the backup config does not allow retries;
// otherwise it returns how many tries are left out of MaxFailedTriesCount.
func (s *BackupBackgroundService) GetRemainedBackupTryCount(lastBackup *Backup) int {
	if lastBackup == nil {
		return 0
	}
	if lastBackup.Status != BackupStatusFailed {
		return 0
	}
	backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(lastBackup.DatabaseID)
	if err != nil {
		s.logger.Error("Failed to get backup config by database ID", "error", err)
		return 0
	}
	if !backupConfig.IsRetryIfFailed {
		return 0
	}
	maxFailedTriesCount := backupConfig.MaxFailedTriesCount
	// Inspect the most recent maxFailedTriesCount backups and count how many
	// failed; the difference is the number of tries still available.
	lastBackups, err := s.backupRepository.FindByDatabaseIDWithLimit(
		lastBackup.DatabaseID,
		maxFailedTriesCount,
	)
	if err != nil {
		s.logger.Error("Failed to find last backups by database ID", "error", err)
		return 0
	}
	lastFailedBackups := make([]*Backup, 0)
	for _, backup := range lastBackups {
		if backup.Status == BackupStatusFailed {
			lastFailedBackups = append(lastFailedBackups, backup)
		}
	}
	return maxFailedTriesCount - len(lastFailedBackups)
}

View File

@@ -0,0 +1,344 @@
package backuping
import (
"context"
"databasus-backend/internal/config"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/storages"
workspaces_services "databasus-backend/internal/features/workspaces/services"
util_encryption "databasus-backend/internal/util/encryption"
"errors"
"fmt"
"log/slog"
"slices"
"strings"
"time"
"github.com/google/uuid"
)
// BackuperNode is a worker node in the multinode backup setup. It registers
// itself in the shared nodes registry, subscribes to backup assignments,
// executes assigned backups, and publishes completion events back to the
// scheduler.
type BackuperNode struct {
	databaseService     *databases.DatabaseService
	fieldEncryptor      util_encryption.FieldEncryptor
	workspaceService    *workspaces_services.WorkspaceService
	backupRepository    *backups_core.BackupRepository
	backupConfigService *backups_config.BackupConfigService
	storageService      *storages.StorageService
	notificationSender  backups_core.NotificationSender
	backupCancelManager *backups_cancellation.BackupCancelManager
	nodesRegistry       *BackupNodesRegistry
	logger              *slog.Logger
	createBackupUseCase backups_core.CreateBackupUsecase
	// nodeID identifies this node in the registry (parsed from config).
	nodeID uuid.UUID
	// lastHeartbeat is the time of the most recent heartbeat sent from Run;
	// consumed by IsBackuperRunning as a liveness check.
	lastHeartbeat time.Time
}
// Run registers the node in the registry, subscribes for backup assignments,
// then sends a heartbeat every 15 seconds until ctx is cancelled, at which
// point the node unregisters itself and returns. It panics if the initial
// registration or subscription fails, since the node cannot operate without
// them.
func (n *BackuperNode) Run(ctx context.Context) {
	n.lastHeartbeat = time.Now().UTC()
	throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
	backupNode := BackupNode{
		ID:            n.nodeID,
		ThroughputMBs: throughputMBs,
		LastHeartbeat: time.Now().UTC(),
	}
	// NOTE(review): "Hearthbeat" is a typo for "Heartbeat" in the registry
	// API; renaming would touch the registry type, out of scope here.
	if err := n.nodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
		n.logger.Error("Failed to register node in registry", "error", err)
		panic(err)
	}
	// backupHandler executes an assigned backup, then notifies the scheduler
	// of completion so its load accounting can be updated.
	backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
		n.MakeBackup(backupID, isCallNotifier)
		if err := n.nodesRegistry.PublishBackupCompletion(n.nodeID.String(), backupID); err != nil {
			n.logger.Error(
				"Failed to publish backup completion",
				"error",
				err,
				"backupID",
				backupID,
			)
		}
	}
	if err := n.nodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID.String(), backupHandler); err != nil {
		n.logger.Error("Failed to subscribe to backup assignments", "error", err)
		panic(err)
	}
	defer func() {
		if err := n.nodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
			n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
		}
	}()
	ticker := time.NewTicker(15 * time.Second)
	defer ticker.Stop()
	n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
	for {
		select {
		case <-ctx.Done():
			n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
			if err := n.nodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
				n.logger.Error("Failed to unregister node from registry", "error", err)
			}
			return
		case <-ticker.C:
			n.sendHeartbeat(&backupNode)
		}
	}
}
// IsBackuperRunning reports whether this node sent a heartbeat within the
// last five minutes.
func (n *BackuperNode) IsBackuperRunning() bool {
	staleThreshold := time.Now().UTC().Add(-5 * time.Minute)
	return n.lastHeartbeat.After(staleThreshold)
}
// MakeBackup executes the backup identified by backupID: it loads the backup
// record, its database, backup config and target storage, runs the create
// backup usecase with cancellation support and progress reporting, and
// persists the resulting status (completed / failed / canceled).
//
// isCallNotifier controls whether a success notification is sent; failure
// notifications are always sent. All lookup errors are logged and abort the
// backup without panicking.
func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
	backup, err := n.backupRepository.FindByID(backupID)
	if err != nil {
		n.logger.Error("Failed to get backup by ID", "backupId", backupID, "error", err)
		return
	}
	databaseID := backup.DatabaseID
	database, err := n.databaseService.GetDatabaseByID(databaseID)
	if err != nil {
		n.logger.Error("Failed to get database by ID", "databaseId", databaseID, "error", err)
		return
	}
	backupConfig, err := n.backupConfigService.GetBackupConfigByDbId(databaseID)
	if err != nil {
		n.logger.Error("Failed to get backup config by database ID", "error", err)
		return
	}
	if backupConfig.StorageID == nil {
		n.logger.Error("Backup config storage ID is not defined")
		return
	}
	storage, err := n.storageService.GetStorageByID(*backupConfig.StorageID)
	if err != nil {
		n.logger.Error("Failed to get storage by ID", "error", err)
		return
	}
	start := time.Now().UTC()
	// Persist size/duration while the backup streams so progress is visible.
	backupProgressListener := func(
		completedMBs float64,
	) {
		backup.BackupSizeMb = completedMBs
		backup.BackupDurationMs = time.Since(start).Milliseconds()
		if err := n.backupRepository.Save(backup); err != nil {
			n.logger.Error("Failed to update backup progress", "error", err)
		}
	}
	// Register the cancel func so a user-initiated cancellation can abort the
	// usecase through the shared cancel manager.
	ctx, cancel := context.WithCancel(context.Background())
	n.backupCancelManager.RegisterBackup(backup.ID, cancel)
	defer n.backupCancelManager.UnregisterBackup(backup.ID)
	backupMetadata, err := n.createBackupUseCase.Execute(
		ctx,
		backup.ID,
		backupConfig,
		database,
		storage,
		backupProgressListener,
	)
	if err != nil {
		errMsg := err.Error()
		// Check if backup was cancelled (not due to shutdown)
		isCancelled := strings.Contains(errMsg, "backup cancelled") ||
			strings.Contains(errMsg, "context canceled") ||
			errors.Is(err, context.Canceled)
		isShutdown := strings.Contains(errMsg, "shutdown")
		if isCancelled && !isShutdown {
			backup.Status = backups_core.BackupStatusCanceled
			backup.BackupDurationMs = time.Since(start).Milliseconds()
			backup.BackupSizeMb = 0
			if err := n.backupRepository.Save(backup); err != nil {
				n.logger.Error("Failed to save cancelled backup", "error", err)
			}
			// Delete partial backup from storage (best effort).
			storage, storageErr := n.storageService.GetStorageByID(backup.StorageID)
			if storageErr == nil {
				if deleteErr := storage.DeleteFile(n.fieldEncryptor, backup.ID); deleteErr != nil {
					n.logger.Error(
						"Failed to delete partial backup file",
						"backupId",
						backup.ID,
						"error",
						deleteErr,
					)
				}
			}
			return
		}
		backup.FailMessage = &errMsg
		backup.Status = backups_core.BackupStatusFailed
		backup.BackupDurationMs = time.Since(start).Milliseconds()
		backup.BackupSizeMb = 0
		if updateErr := n.databaseService.SetBackupError(databaseID, errMsg); updateErr != nil {
			// Fixed copy-pasted log message (was "Failed to update database
			// last backup time").
			n.logger.Error(
				"Failed to set backup error on database",
				"databaseId",
				databaseID,
				"error",
				updateErr,
			)
		}
		if err := n.backupRepository.Save(backup); err != nil {
			n.logger.Error("Failed to save backup", "error", err)
		}
		n.SendBackupNotification(
			backupConfig,
			backup,
			backups_config.NotificationBackupFailed,
			&errMsg,
		)
		return
	}
	backup.Status = backups_core.BackupStatusCompleted
	backup.BackupDurationMs = time.Since(start).Milliseconds()
	// Update backup with encryption metadata if provided
	if backupMetadata != nil {
		backup.EncryptionSalt = backupMetadata.EncryptionSalt
		backup.EncryptionIV = backupMetadata.EncryptionIV
		backup.Encryption = backupMetadata.Encryption
	}
	if err := n.backupRepository.Save(backup); err != nil {
		n.logger.Error("Failed to save backup", "error", err)
		return
	}
	// Update database last backup time
	now := time.Now().UTC()
	if updateErr := n.databaseService.SetLastBackupTime(databaseID, now); updateErr != nil {
		n.logger.Error(
			"Failed to update database last backup time",
			"databaseId",
			databaseID,
			"error",
			updateErr,
		)
	}
	// BUG FIX: the original condition was
	//   backup.Status != BackupStatusCompleted && !isCallNotifier
	// but Status is always Completed here (every failure path returns
	// earlier), so the condition was always false and isCallNotifier was
	// silently ignored for success notifications. Honor the flag.
	if !isCallNotifier {
		return
	}
	n.SendBackupNotification(
		backupConfig,
		backup,
		backups_config.NotificationBackupSuccess,
		nil,
	)
}
// SendBackupNotification notifies every notifier attached to the backup's
// database about the given backup event, provided the backup config opted in
// to that notification type. errorMessage, when non-nil, becomes the message
// body; otherwise a success summary with size and duration is built. Lookup
// failures are silently ignored (best effort).
func (n *BackuperNode) SendBackupNotification(
	backupConfig *backups_config.BackupConfig,
	backup *backups_core.Backup,
	notificationType backups_config.BackupNotificationType,
	errorMessage *string,
) {
	// The opt-in filter does not depend on the notifier, so check it once up
	// front instead of once per notifier as the original did.
	if !slices.Contains(backupConfig.SendNotificationsOn, notificationType) {
		return
	}
	database, err := n.databaseService.GetDatabaseByID(backupConfig.DatabaseID)
	if err != nil {
		return
	}
	workspace, err := n.workspaceService.GetWorkspaceByID(*database.WorkspaceID)
	if err != nil {
		return
	}
	// Title and message are identical for every notifier; build them once.
	title := ""
	switch notificationType {
	case backups_config.NotificationBackupFailed:
		title = fmt.Sprintf(
			"❌ Backup failed for database \"%s\" (workspace \"%s\")",
			database.Name,
			workspace.Name,
		)
	case backups_config.NotificationBackupSuccess:
		title = fmt.Sprintf(
			"✅ Backup completed for database \"%s\" (workspace \"%s\")",
			database.Name,
			workspace.Name,
		)
	}
	message := ""
	if errorMessage != nil {
		message = *errorMessage
	} else {
		// Report size in MB below 1024 MB, otherwise in GB.
		var sizeStr string
		if backup.BackupSizeMb < 1024 {
			sizeStr = fmt.Sprintf("%.2f MB", backup.BackupSizeMb)
		} else {
			sizeGB := backup.BackupSizeMb / 1024
			sizeStr = fmt.Sprintf("%.2f GB", sizeGB)
		}
		// Format duration as "Xm Ys" (the original comment promised
		// milliseconds that were never printed).
		totalMs := backup.BackupDurationMs
		minutes := totalMs / (1000 * 60)
		seconds := (totalMs % (1000 * 60)) / 1000
		durationStr := fmt.Sprintf("%dm %ds", minutes, seconds)
		message = fmt.Sprintf(
			"Backup completed successfully in %s.\nCompressed backup size: %s",
			durationStr,
			sizeStr,
		)
	}
	for _, notifier := range database.Notifiers {
		n.notificationSender.SendNotification(
			&notifier,
			title,
			message,
		)
	}
}
// sendHeartbeat refreshes the node's heartbeat both locally and in the shared
// registry. The clock is read once so the local marker, the node record, and
// the registry call all carry the same timestamp (the original read
// time.Now() three times, producing slightly divergent values).
func (n *BackuperNode) sendHeartbeat(backupNode *BackupNode) {
	now := time.Now().UTC()
	n.lastHeartbeat = now
	backupNode.LastHeartbeat = now
	if err := n.nodesRegistry.HearthbeatNodeInRegistry(now, *backupNode); err != nil {
		n.logger.Error("Failed to send heartbeat", "error", err)
	}
}

View File

@@ -1,4 +1,4 @@
package backups
package backuping
import (
"context"
@@ -8,17 +8,15 @@ import (
"time"
common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_services "databasus-backend/internal/features/workspaces/services"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
cache_utils "databasus-backend/internal/util/cache"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
@@ -26,6 +24,7 @@ import (
)
func Test_BackupExecuted_NotificationSent(t *testing.T) {
cache_utils.ClearAllCache()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
@@ -50,23 +49,19 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
t.Run("BackupFailed_FailNotificationSent", func(t *testing.T) {
mockNotificationSender := &MockNotificationSender{}
backupService := &BackupService{
databases.GetDatabaseService(),
storages.GetStorageService(),
backupRepository,
notifiers.GetNotifierService(),
mockNotificationSender,
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
&CreateFailedBackupUsecase{},
logger.GetLogger(),
[]BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
nil,
NewBackupContextManager(),
nil,
backuperNode := CreateTestBackuperNode()
backuperNode.notificationSender = mockNotificationSender
backuperNode.createBackupUseCase = &CreateFailedBackupUsecase{}
// Create a backup record directly that will be looked up by MakeBackup
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err := backupRepository.Save(backup)
assert.NoError(t, err)
// Set up expectations
mockNotificationSender.On("SendNotification",
@@ -79,7 +74,7 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
}),
).Once()
backupService.MakeBackup(database.ID, true)
backuperNode.MakeBackup(backup.ID, true)
// Verify all expectations were met
mockNotificationSender.AssertExpectations(t)
@@ -87,6 +82,19 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
t.Run("BackupSuccess_SuccessNotificationSent", func(t *testing.T) {
mockNotificationSender := &MockNotificationSender{}
backuperNode := CreateTestBackuperNode()
backuperNode.notificationSender = mockNotificationSender
backuperNode.createBackupUseCase = &CreateSuccessBackupUsecase{}
// Create a backup record directly that will be looked up by MakeBackup
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err := backupRepository.Save(backup)
assert.NoError(t, err)
// Set up expectations
mockNotificationSender.On("SendNotification",
@@ -99,25 +107,7 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
}),
).Once()
backupService := &BackupService{
databases.GetDatabaseService(),
storages.GetStorageService(),
backupRepository,
notifiers.GetNotifierService(),
mockNotificationSender,
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
&CreateSuccessBackupUsecase{},
logger.GetLogger(),
[]BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
nil,
NewBackupContextManager(),
nil,
}
backupService.MakeBackup(database.ID, true)
backuperNode.MakeBackup(backup.ID, true)
// Verify all expectations were met
mockNotificationSender.AssertExpectations(t)
@@ -125,23 +115,19 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
t.Run("BackupSuccess_VerifyNotificationContent", func(t *testing.T) {
mockNotificationSender := &MockNotificationSender{}
backupService := &BackupService{
databases.GetDatabaseService(),
storages.GetStorageService(),
backupRepository,
notifiers.GetNotifierService(),
mockNotificationSender,
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
&CreateSuccessBackupUsecase{},
logger.GetLogger(),
[]BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
nil,
NewBackupContextManager(),
nil,
backuperNode := CreateTestBackuperNode()
backuperNode.notificationSender = mockNotificationSender
backuperNode.createBackupUseCase = &CreateSuccessBackupUsecase{}
// Create a backup record directly that will be looked up by MakeBackup
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err := backupRepository.Save(backup)
assert.NoError(t, err)
// capture arguments
var capturedNotifier *notifiers.Notifier
@@ -158,7 +144,7 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
capturedMessage = args.Get(2).(string)
}).Once()
backupService.MakeBackup(database.ID, true)
backuperNode.MakeBackup(backup.ID, true)
// Verify expectations were met
mockNotificationSender.AssertExpectations(t)

View File

@@ -0,0 +1,77 @@
package backuping
import (
"databasus-backend/internal/config"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/usecases"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
workspaces_services "databasus-backend/internal/features/workspaces/services"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
"time"
"github.com/google/uuid"
)
// Package-level singletons shared by the scheduler and the backup node.
var backupRepository = &backups_core.BackupRepository{}
var backupCancelManager = backups_cancellation.GetBackupCancelManager()

// nodesRegistry is initialized positionally; the argument order must match
// the field order of BackupNodesRegistry.
var nodesRegistry = &BackupNodesRegistry{
	cache_utils.GetValkeyClient(),
	logger.GetLogger(),
	cache_utils.DefaultCacheTimeout,
	cache_utils.NewPubSubManager(), // backup assignment channel
	cache_utils.NewPubSubManager(), // backup completion channel
}
// getNodeID parses the node ID generated in config into a uuid.UUID. It
// panics on a malformed value, since the node cannot participate in the
// registry without a valid ID.
func getNodeID() uuid.UUID {
	raw := config.GetEnv().NodeID
	parsed, err := uuid.Parse(raw)
	if err == nil {
		return parsed
	}
	logger.GetLogger().Error("Failed to parse node ID from config", "error", err)
	panic(err)
}
// backuperNode is initialized positionally; the argument order must match
// the field order of BackuperNode.
var backuperNode = &BackuperNode{
	databases.GetDatabaseService(),
	encryption.GetFieldEncryptor(),
	workspaces_services.GetWorkspaceService(),
	backupRepository,
	backups_config.GetBackupConfigService(),
	storages.GetStorageService(),
	notifiers.GetNotifierService(),
	backupCancelManager,
	nodesRegistry,
	logger.GetLogger(),
	usecases.GetCreateBackupUsecase(),
	getNodeID(),
	time.Time{}, // lastHeartbeat: zero until Run starts
}

// backupsScheduler is initialized positionally; the argument order must
// match the field order of BackupsScheduler.
var backupsScheduler = &BackupsScheduler{
	backupRepository,
	backups_config.GetBackupConfigService(),
	storages.GetStorageService(),
	backupCancelManager,
	nodesRegistry,
	time.Now().UTC(),
	logger.GetLogger(),
	make(map[uuid.UUID]BackupToNodeRelation),
	backuperNode,
}
// GetBackupsScheduler returns the package-level backups scheduler singleton.
func GetBackupsScheduler() *BackupsScheduler {
	return backupsScheduler
}
// GetBackuperNode returns the package-level backup worker node singleton.
func GetBackuperNode() *BackuperNode {
	return backuperNode
}

View File

@@ -0,0 +1,34 @@
package backuping
import (
"time"
"github.com/google/uuid"
)
// BackupNode describes a backup worker node as registered in the shared
// registry: its identity, declared network throughput, and last heartbeat.
type BackupNode struct {
	ID            uuid.UUID `json:"id"`
	ThroughputMBs int       `json:"throughputMBs"`
	LastHeartbeat time.Time `json:"lastHeartbeat"`
}

// BackupNodeStats carries the per-node count of currently active backups.
type BackupNodeStats struct {
	ID            uuid.UUID `json:"id"`
	ActiveBackups int       `json:"activeBackups"`
}

// BackupSubmitMessage is published by the scheduler to assign a backup to a
// specific node.
type BackupSubmitMessage struct {
	NodeID         string `json:"nodeId"`
	BackupID       string `json:"backupId"`
	IsCallNotifier bool   `json:"isCallNotifier"`
}

// BackupCompletionMessage is published by a node when an assigned backup has
// finished.
type BackupCompletionMessage struct {
	NodeID   string `json:"nodeId"`
	BackupID string `json:"backupId"`
}

// BackupToNodeRelation tracks which backups are currently assigned to a node.
type BackupToNodeRelation struct {
	NodeID     uuid.UUID   `json:"nodeId"`
	BackupsIDs []uuid.UUID `json:"backupsIds"`
}

View File

@@ -1,4 +1,4 @@
package backups
package backuping
import (
"databasus-backend/internal/features/notifiers"

View File

@@ -0,0 +1,448 @@
package backuping
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"strconv"
	"strings"
	"time"

	cache_utils "databasus-backend/internal/util/cache"

	"github.com/google/uuid"
	"github.com/valkey-io/valkey-go"
)
// Valkey key fragments and pub/sub channel names used by the registry.
const (
	// Node info keys: "node:<uuid>:info" -> JSON-encoded BackupNode.
	nodeInfoKeyPrefix = "node:"
	nodeInfoKeySuffix = ":info"
	// Active-backups counter keys: "node:<uuid>:active_backups" -> integer.
	nodeActiveBackupsPrefix = "node:"
	nodeActiveBackupsSuffix = ":active_backups"
	// Pub/sub channels for assigning backups and reporting completions.
	backupSubmitChannel     = "backup:submit"
	backupCompletionChannel = "backup:completion"
)
// BackupNodesRegistry helps to sync backups scheduler and backup nodes.
// Features:
//   - Track node availability and load level
//   - Assign from scheduler to node backups needed to be processed
//   - Notify scheduler from node about backup completion
//
// All state lives in Valkey, so instances in different processes share it.
type BackupNodesRegistry struct {
	client  valkey.Client
	logger  *slog.Logger
	timeout time.Duration // per-operation deadline applied to Valkey calls
	// Separate pub/sub managers so assignment and completion subscriptions
	// can be opened and closed independently.
	pubsubBackups     *cache_utils.PubSubManager
	pubsubCompletions *cache_utils.PubSubManager
}
// GetAvailableNodes lists every node currently present in the registry by
// scanning all "node:<id>:info" keys and decoding their JSON payloads in a
// single pipelined fetch. Entries with malformed JSON are logged and skipped.
func (r *BackupNodesRegistry) GetAvailableNodes() ([]BackupNode, error) {
	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
	defer cancel()

	pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix
	var keys []string
	for cursor := uint64(0); ; {
		resp := r.client.Do(
			ctx,
			r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
		)
		if err := resp.Error(); err != nil {
			return nil, fmt.Errorf("failed to scan node keys: %w", err)
		}
		entry, err := resp.AsScanEntry()
		if err != nil {
			return nil, fmt.Errorf("failed to parse scan result: %w", err)
		}
		keys = append(keys, entry.Elements...)
		if cursor = entry.Cursor; cursor == 0 {
			break
		}
	}
	if len(keys) == 0 {
		return []BackupNode{}, nil
	}

	payloads, err := r.pipelineGetKeys(keys)
	if err != nil {
		return nil, fmt.Errorf("failed to pipeline get node keys: %w", err)
	}

	var nodes []BackupNode
	for key, raw := range payloads {
		var node BackupNode
		if err := json.Unmarshal(raw, &node); err != nil {
			r.logger.Warn("Failed to unmarshal node data", "key", key, "error", err)
			continue
		}
		nodes = append(nodes, node)
	}
	return nodes, nil
}
// GetBackupNodesStats returns the active-backups counter for every node that
// has a "node:<id>:active_backups" key, scanning the keyspace and fetching
// all counters in one pipeline. Unparsable counters are logged and skipped.
func (r *BackupNodesRegistry) GetBackupNodesStats() ([]BackupNodeStats, error) {
	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
	defer cancel()

	pattern := nodeActiveBackupsPrefix + "*" + nodeActiveBackupsSuffix
	var keys []string
	for cursor := uint64(0); ; {
		resp := r.client.Do(
			ctx,
			r.client.B().Scan().Cursor(cursor).Match(pattern).Count(100).Build(),
		)
		if err := resp.Error(); err != nil {
			return nil, fmt.Errorf("failed to scan active backups keys: %w", err)
		}
		entry, err := resp.AsScanEntry()
		if err != nil {
			return nil, fmt.Errorf("failed to parse scan result: %w", err)
		}
		keys = append(keys, entry.Elements...)
		if cursor = entry.Cursor; cursor == 0 {
			break
		}
	}
	if len(keys) == 0 {
		return []BackupNodeStats{}, nil
	}

	payloads, err := r.pipelineGetKeys(keys)
	if err != nil {
		return nil, fmt.Errorf("failed to pipeline get active backups keys: %w", err)
	}

	var stats []BackupNodeStats
	for key, raw := range payloads {
		id := r.extractNodeIDFromKey(key, nodeActiveBackupsPrefix, nodeActiveBackupsSuffix)
		count, err := r.parseIntFromBytes(raw)
		if err != nil {
			r.logger.Warn("Failed to parse active backups count", "key", key, "error", err)
			continue
		}
		stats = append(stats, BackupNodeStats{
			ID:            id,
			ActiveBackups: int(count),
		})
	}
	return stats, nil
}
// IncrementBackupsInProgress atomically increases the active-backups counter
// for the given node (INCR creates the key at 1 if it does not exist).
func (r *BackupNodesRegistry) IncrementBackupsInProgress(nodeID string) error {
	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
	defer cancel()

	counterKey := nodeActiveBackupsPrefix + nodeID + nodeActiveBackupsSuffix
	resp := r.client.Do(ctx, r.client.B().Incr().Key(counterKey).Build())
	if err := resp.Error(); err != nil {
		return fmt.Errorf(
			"failed to increment backups in progress for node %s: %w",
			nodeID,
			err,
		)
	}
	return nil
}
// DecrementBackupsInProgress atomically decrements the node's active-backups
// counter. If the result is negative (a decrement without a matching
// increment, or DECR creating a missing key at -1), the counter is reset to
// zero and a warning is logged.
//
// NOTE(review): the negative-check + SET reset is not atomic — a concurrent
// INCR landing between the DECR and the SET can be lost. Presumably
// acceptable for a best-effort load gauge; confirm.
func (r *BackupNodesRegistry) DecrementBackupsInProgress(nodeID string) error {
	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
	defer cancel()
	key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID, nodeActiveBackupsSuffix)
	result := r.client.Do(ctx, r.client.B().Decr().Key(key).Build())
	if result.Error() != nil {
		return fmt.Errorf(
			"failed to decrement backups in progress for node %s: %w",
			nodeID,
			result.Error(),
		)
	}
	newValue, err := result.AsInt64()
	if err != nil {
		return fmt.Errorf("failed to parse decremented value for node %s: %w", nodeID, err)
	}
	if newValue < 0 {
		// Fresh context for the corrective SET; the reset error is
		// intentionally ignored (best-effort repair, warning below).
		setCtx, setCancel := context.WithTimeout(context.Background(), r.timeout)
		r.client.Do(setCtx, r.client.B().Set().Key(key).Value("0").Build())
		setCancel()
		r.logger.Warn("Active backups counter went below 0, reset to 0", "nodeID", nodeID)
	}
	return nil
}
// HearthbeatNodeInRegistry upserts this node's info key ("node:<id>:info")
// with LastHeartbeat set to now, making the node visible to GetAvailableNodes.
//
// NOTE(review): "Hearthbeat" is a typo for "Heartbeat", but the name is
// exported so renaming is a breaking change. Also note the SET carries no
// TTL — a crashed node stays "available" until UnregisterNodeFromRegistry
// runs or readers filter on LastHeartbeat; confirm callers do so.
func (r *BackupNodesRegistry) HearthbeatNodeInRegistry(now time.Time, backupNode BackupNode) error {
	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
	defer cancel()
	backupNode.LastHeartbeat = now
	data, err := json.Marshal(backupNode)
	if err != nil {
		return fmt.Errorf("failed to marshal backup node: %w", err)
	}
	key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, backupNode.ID.String(), nodeInfoKeySuffix)
	result := r.client.Do(
		ctx,
		r.client.B().Set().Key(key).Value(string(data)).Build(),
	)
	if result.Error() != nil {
		return fmt.Errorf("failed to register node %s: %w", backupNode.ID, result.Error())
	}
	return nil
}
// UnregisterNodeFromRegistry removes both the node's info key and its
// active-backups counter in a single DEL, then logs the removal.
func (r *BackupNodesRegistry) UnregisterNodeFromRegistry(backupNode BackupNode) error {
	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
	defer cancel()

	id := backupNode.ID.String()
	infoKey := nodeInfoKeyPrefix + id + nodeInfoKeySuffix
	counterKey := nodeActiveBackupsPrefix + id + nodeActiveBackupsSuffix

	resp := r.client.Do(ctx, r.client.B().Del().Key(infoKey, counterKey).Build())
	if err := resp.Error(); err != nil {
		return fmt.Errorf("failed to unregister node %s: %w", backupNode.ID, err)
	}
	r.logger.Info("Unregistered node from registry", "nodeID", backupNode.ID)
	return nil
}
// AssignBackupToNode publishes a BackupSubmitMessage addressed to
// targetNodeID on the backup-submit channel; every node receives it but
// only the target processes it.
func (r *BackupNodesRegistry) AssignBackupToNode(
	targetNodeID string,
	backupID uuid.UUID,
	isCallNotifier bool,
) error {
	payload, err := json.Marshal(BackupSubmitMessage{
		NodeID:         targetNodeID,
		BackupID:       backupID.String(),
		IsCallNotifier: isCallNotifier,
	})
	if err != nil {
		return fmt.Errorf("failed to marshal backup submit message: %w", err)
	}
	if err := r.pubsubBackups.Publish(context.Background(), backupSubmitChannel, string(payload)); err != nil {
		return fmt.Errorf("failed to publish backup submit message: %w", err)
	}
	return nil
}
// SubscribeNodeForBackupsAssignment subscribes to the backup-submit channel
// and invokes handler for every message addressed to nodeID. Malformed JSON
// and unparsable backup IDs are logged and dropped; messages for other
// nodes are silently ignored.
func (r *BackupNodesRegistry) SubscribeNodeForBackupsAssignment(
	nodeID string,
	handler func(backupID uuid.UUID, isCallNotifier bool),
) error {
	onMessage := func(raw string) {
		var submit BackupSubmitMessage
		if err := json.Unmarshal([]byte(raw), &submit); err != nil {
			r.logger.Warn("Failed to unmarshal backup submit message", "error", err)
			return
		}
		if submit.NodeID != nodeID {
			return
		}
		parsedID, err := uuid.Parse(submit.BackupID)
		if err != nil {
			r.logger.Warn(
				"Failed to parse backup ID from message",
				"backupId",
				submit.BackupID,
				"error",
				err,
			)
			return
		}
		handler(parsedID, submit.IsCallNotifier)
	}
	if err := r.pubsubBackups.Subscribe(context.Background(), backupSubmitChannel, onMessage); err != nil {
		return fmt.Errorf("failed to subscribe to backup submit channel: %w", err)
	}
	r.logger.Info("Subscribed to backup submit channel", "nodeID", nodeID)
	return nil
}
// UnsubscribeNodeForBackupsAssignments closes the assignment pub/sub
// subscription so this node stops receiving backup submit messages.
func (r *BackupNodesRegistry) UnsubscribeNodeForBackupsAssignments() error {
	if err := r.pubsubBackups.Close(); err != nil {
		return fmt.Errorf("failed to unsubscribe from backup submit channel: %w", err)
	}
	r.logger.Info("Unsubscribed from backup submit channel")
	return nil
}
// PublishBackupCompletion announces on the completion channel that nodeID
// has finished processing backupID.
func (r *BackupNodesRegistry) PublishBackupCompletion(nodeID string, backupID uuid.UUID) error {
	payload, err := json.Marshal(BackupCompletionMessage{
		NodeID:   nodeID,
		BackupID: backupID.String(),
	})
	if err != nil {
		return fmt.Errorf("failed to marshal backup completion message: %w", err)
	}
	if err := r.pubsubCompletions.Publish(context.Background(), backupCompletionChannel, string(payload)); err != nil {
		return fmt.Errorf("failed to publish backup completion message: %w", err)
	}
	return nil
}
// SubscribeForBackupsCompletions subscribes to the completion channel and
// invokes handler for every valid completion message. Malformed JSON and
// unparsable backup IDs are logged and dropped.
func (r *BackupNodesRegistry) SubscribeForBackupsCompletions(
	handler func(nodeID string, backupID uuid.UUID),
) error {
	onMessage := func(raw string) {
		var completion BackupCompletionMessage
		if err := json.Unmarshal([]byte(raw), &completion); err != nil {
			r.logger.Warn("Failed to unmarshal backup completion message", "error", err)
			return
		}
		parsedID, err := uuid.Parse(completion.BackupID)
		if err != nil {
			r.logger.Warn(
				"Failed to parse backup ID from completion message",
				"backupId",
				completion.BackupID,
				"error",
				err,
			)
			return
		}
		handler(completion.NodeID, parsedID)
	}
	if err := r.pubsubCompletions.Subscribe(context.Background(), backupCompletionChannel, onMessage); err != nil {
		return fmt.Errorf("failed to subscribe to backup completion channel: %w", err)
	}
	r.logger.Info("Subscribed to backup completion channel")
	return nil
}
// UnsubscribeForBackupsCompletions closes the completion pub/sub
// subscription so this process stops receiving completion messages.
func (r *BackupNodesRegistry) UnsubscribeForBackupsCompletions() error {
	if err := r.pubsubCompletions.Close(); err != nil {
		return fmt.Errorf("failed to unsubscribe from backup completion channel: %w", err)
	}
	r.logger.Info("Unsubscribed from backup completion channel")
	return nil
}
// extractNodeIDFromKey strips prefix and suffix from a registry key and
// parses the remainder as a UUID; returns uuid.Nil (with a warning) when
// the remainder is not a valid UUID.
func (r *BackupNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uuid.UUID {
	trimmed := strings.TrimSuffix(strings.TrimPrefix(key, prefix), suffix)
	id, err := uuid.Parse(trimmed)
	if err != nil {
		r.logger.Warn("Failed to parse node ID from key", "key", key, "error", err)
		return uuid.Nil
	}
	return id
}
// pipelineGetKeys fetches all given keys with one pipelined batch of GET
// commands and returns a key -> raw value map. Keys whose GET errors
// (including missing keys) or whose value cannot be read as bytes are
// logged and omitted from the result.
func (r *BackupNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, error) {
	out := make(map[string][]byte, len(keys))
	if len(keys) == 0 {
		return out, nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
	defer cancel()

	cmds := make([]valkey.Completed, 0, len(keys))
	for _, k := range keys {
		cmds = append(cmds, r.client.B().Get().Key(k).Build())
	}

	for i, resp := range r.client.DoMulti(ctx, cmds...) {
		if err := resp.Error(); err != nil {
			r.logger.Warn("Failed to get key in pipeline", "key", keys[i], "error", err)
			continue
		}
		raw, err := resp.AsBytes()
		if err != nil {
			r.logger.Warn("Failed to parse key data in pipeline", "key", keys[i], "error", err)
			continue
		}
		out[keys[i]] = raw
	}
	return out, nil
}
// parseIntFromBytes parses data as a base-10 signed 64-bit integer.
//
// Uses strconv.ParseInt instead of fmt.Sscanf: Sscanf("%d") silently
// accepts trailing garbage (e.g. "12abc" parses as 12 with a nil error),
// which would let a corrupted counter slip through. Leading/trailing
// whitespace is still tolerated, matching Sscanf's leading-space behavior.
func (r *BackupNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
	count, err := strconv.ParseInt(strings.TrimSpace(string(data)), 10, 64)
	if err != nil {
		return 0, fmt.Errorf("failed to parse integer from bytes: %w", err)
	}
	return count, nil
}

View File

@@ -0,0 +1,904 @@
package backuping
import (
"context"
"testing"
"time"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/logger"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
// Verifies a heartbeat registers a node that GetAvailableNodes then returns.
// NOTE(review): the name says "WithTTL" but no TTL/expiry is asserted (and
// HearthbeatNodeInRegistry issues a plain SET) — consider renaming.
func Test_HearthbeatNodeInRegistry_RegistersNodeWithTTL(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	defer cleanupTestNode(registry, node)
	err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
	assert.NoError(t, err)
	nodes, err := registry.GetAvailableNodes()
	assert.NoError(t, err)
	assert.Len(t, nodes, 1)
	assert.Equal(t, node.ID, nodes[0].ID)
	assert.Equal(t, node.ThroughputMBs, nodes[0].ThroughputMBs)
}

// Unregistering must delete both the info key and the active-backups counter.
func Test_UnregisterNodeFromRegistry_RemovesNodeAndCounter(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
	assert.NoError(t, err)
	err = registry.IncrementBackupsInProgress(node.ID.String())
	assert.NoError(t, err)
	err = registry.UnregisterNodeFromRegistry(node)
	assert.NoError(t, err)
	nodes, err := registry.GetAvailableNodes()
	assert.NoError(t, err)
	assert.Empty(t, nodes)
	stats, err := registry.GetBackupNodesStats()
	assert.NoError(t, err)
	assert.Empty(t, stats)
}

// All heartbeated nodes must be returned, regardless of registration order.
func Test_GetAvailableNodes_ReturnsAllRegisteredNodes(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node1 := createTestBackupNode()
	node2 := createTestBackupNode()
	node3 := createTestBackupNode()
	defer cleanupTestNode(registry, node1)
	defer cleanupTestNode(registry, node2)
	defer cleanupTestNode(registry, node3)
	err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
	assert.NoError(t, err)
	err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
	assert.NoError(t, err)
	err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
	assert.NoError(t, err)
	nodes, err := registry.GetAvailableNodes()
	assert.NoError(t, err)
	assert.Len(t, nodes, 3)
	nodeIDs := make(map[uuid.UUID]bool)
	for _, node := range nodes {
		nodeIDs[node.ID] = true
	}
	assert.True(t, nodeIDs[node1.ID])
	assert.True(t, nodeIDs[node2.ID])
	assert.True(t, nodeIDs[node3.ID])
}

// An empty registry must yield a non-nil empty slice, not nil.
func Test_GetAvailableNodes_WhenNoNodesExist_ReturnsEmptySlice(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	nodes, err := registry.GetAvailableNodes()
	assert.NoError(t, err)
	assert.NotNil(t, nodes)
	assert.Empty(t, nodes)
}
// Each increment must raise the node's active-backups stat by exactly one.
func Test_IncrementBackupsInProgress_IncrementsCounter(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	defer cleanupTestNode(registry, node)
	err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
	assert.NoError(t, err)
	err = registry.IncrementBackupsInProgress(node.ID.String())
	assert.NoError(t, err)
	stats, err := registry.GetBackupNodesStats()
	assert.NoError(t, err)
	assert.Len(t, stats, 1)
	assert.Equal(t, node.ID, stats[0].ID)
	assert.Equal(t, 1, stats[0].ActiveBackups)
	err = registry.IncrementBackupsInProgress(node.ID.String())
	assert.NoError(t, err)
	stats, err = registry.GetBackupNodesStats()
	assert.NoError(t, err)
	assert.Len(t, stats, 1)
	assert.Equal(t, 2, stats[0].ActiveBackups)
}

// Decrements must step the counter back down one at a time.
func Test_DecrementBackupsInProgress_DecrementsCounter(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	defer cleanupTestNode(registry, node)
	err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
	assert.NoError(t, err)
	err = registry.IncrementBackupsInProgress(node.ID.String())
	assert.NoError(t, err)
	err = registry.IncrementBackupsInProgress(node.ID.String())
	assert.NoError(t, err)
	err = registry.IncrementBackupsInProgress(node.ID.String())
	assert.NoError(t, err)
	stats, err := registry.GetBackupNodesStats()
	assert.NoError(t, err)
	assert.Equal(t, 3, stats[0].ActiveBackups)
	err = registry.DecrementBackupsInProgress(node.ID.String())
	assert.NoError(t, err)
	stats, err = registry.GetBackupNodesStats()
	assert.NoError(t, err)
	assert.Equal(t, 2, stats[0].ActiveBackups)
	err = registry.DecrementBackupsInProgress(node.ID.String())
	assert.NoError(t, err)
	stats, err = registry.GetBackupNodesStats()
	assert.NoError(t, err)
	assert.Equal(t, 1, stats[0].ActiveBackups)
}

// Decrementing a fresh (missing) counter would go to -1; the registry must
// clamp it back to zero.
func Test_DecrementBackupsInProgress_WhenNegative_ResetsToZero(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	defer cleanupTestNode(registry, node)
	err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
	assert.NoError(t, err)
	err = registry.DecrementBackupsInProgress(node.ID.String())
	assert.NoError(t, err)
	stats, err := registry.GetBackupNodesStats()
	assert.NoError(t, err)
	assert.Len(t, stats, 1)
	assert.Equal(t, 0, stats[0].ActiveBackups)
}
// Stats must reflect the distinct increment counts of each node (1, 2, 3).
func Test_GetBackupNodesStats_ReturnsStatsForAllNodes(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node1 := createTestBackupNode()
	node2 := createTestBackupNode()
	node3 := createTestBackupNode()
	defer cleanupTestNode(registry, node1)
	defer cleanupTestNode(registry, node2)
	defer cleanupTestNode(registry, node3)
	err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
	assert.NoError(t, err)
	err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
	assert.NoError(t, err)
	err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
	assert.NoError(t, err)
	err = registry.IncrementBackupsInProgress(node1.ID.String())
	assert.NoError(t, err)
	err = registry.IncrementBackupsInProgress(node2.ID.String())
	assert.NoError(t, err)
	err = registry.IncrementBackupsInProgress(node2.ID.String())
	assert.NoError(t, err)
	err = registry.IncrementBackupsInProgress(node3.ID.String())
	assert.NoError(t, err)
	err = registry.IncrementBackupsInProgress(node3.ID.String())
	assert.NoError(t, err)
	err = registry.IncrementBackupsInProgress(node3.ID.String())
	assert.NoError(t, err)
	stats, err := registry.GetBackupNodesStats()
	assert.NoError(t, err)
	assert.Len(t, stats, 3)
	statsMap := make(map[uuid.UUID]int)
	for _, stat := range stats {
		statsMap[stat.ID] = stat.ActiveBackups
	}
	assert.Equal(t, 1, statsMap[node1.ID])
	assert.Equal(t, 2, statsMap[node2.ID])
	assert.Equal(t, 3, statsMap[node3.ID])
}

// With no counters present, stats must be a non-nil empty slice.
func Test_GetBackupNodesStats_WhenNoStats_ReturnsEmptySlice(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	stats, err := registry.GetBackupNodesStats()
	assert.NoError(t, err)
	assert.NotNil(t, stats)
	assert.Empty(t, stats)
}

// Per-node fields (ThroughputMBs) must round-trip through the registry.
func Test_MultipleNodes_RegisteredAndQueriedCorrectly(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node1 := createTestBackupNode()
	node1.ThroughputMBs = 50
	node2 := createTestBackupNode()
	node2.ThroughputMBs = 100
	node3 := createTestBackupNode()
	node3.ThroughputMBs = 150
	defer cleanupTestNode(registry, node1)
	defer cleanupTestNode(registry, node2)
	defer cleanupTestNode(registry, node3)
	err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
	assert.NoError(t, err)
	err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
	assert.NoError(t, err)
	err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
	assert.NoError(t, err)
	nodes, err := registry.GetAvailableNodes()
	assert.NoError(t, err)
	assert.Len(t, nodes, 3)
	nodeMap := make(map[uuid.UUID]BackupNode)
	for _, node := range nodes {
		nodeMap[node.ID] = node
	}
	assert.Equal(t, 50, nodeMap[node1.ID].ThroughputMBs)
	assert.Equal(t, 100, nodeMap[node2.ID].ThroughputMBs)
	assert.Equal(t, 150, nodeMap[node3.ID].ThroughputMBs)
}
// Counters of different nodes must not interfere with each other.
func Test_BackupCounters_TrackedSeparatelyPerNode(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node1 := createTestBackupNode()
	node2 := createTestBackupNode()
	defer cleanupTestNode(registry, node1)
	defer cleanupTestNode(registry, node2)
	err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
	assert.NoError(t, err)
	err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
	assert.NoError(t, err)
	err = registry.IncrementBackupsInProgress(node1.ID.String())
	assert.NoError(t, err)
	err = registry.IncrementBackupsInProgress(node1.ID.String())
	assert.NoError(t, err)
	err = registry.IncrementBackupsInProgress(node2.ID.String())
	assert.NoError(t, err)
	stats, err := registry.GetBackupNodesStats()
	assert.NoError(t, err)
	assert.Len(t, stats, 2)
	statsMap := make(map[uuid.UUID]int)
	for _, stat := range stats {
		statsMap[stat.ID] = stat.ActiveBackups
	}
	assert.Equal(t, 2, statsMap[node1.ID])
	assert.Equal(t, 1, statsMap[node2.ID])
	err = registry.DecrementBackupsInProgress(node1.ID.String())
	assert.NoError(t, err)
	stats, err = registry.GetBackupNodesStats()
	assert.NoError(t, err)
	statsMap = make(map[uuid.UUID]int)
	for _, stat := range stats {
		statsMap[stat.ID] = stat.ActiveBackups
	}
	assert.Equal(t, 1, statsMap[node1.ID])
	assert.Equal(t, 1, statsMap[node2.ID])
}

// A corrupt info key must be skipped without failing the whole listing.
func Test_GetAvailableNodes_SkipsInvalidJsonData(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	defer cleanupTestNode(registry, node)
	err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
	assert.NoError(t, err)
	ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
	defer cancel()
	// Plant a node-info key holding non-JSON data directly in Valkey.
	invalidKey := nodeInfoKeyPrefix + uuid.New().String() + nodeInfoKeySuffix
	registry.client.Do(
		ctx,
		registry.client.B().Set().Key(invalidKey).Value("invalid json data").Build(),
	)
	defer func() {
		cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), registry.timeout)
		defer cleanupCancel()
		registry.client.Do(cleanupCtx, registry.client.B().Del().Key(invalidKey).Build())
	}()
	nodes, err := registry.GetAvailableNodes()
	assert.NoError(t, err)
	assert.Len(t, nodes, 1)
	assert.Equal(t, node.ID, nodes[0].ID)
}

// An empty key list must produce a non-nil empty map, not an error.
func Test_PipelineGetKeys_HandlesEmptyKeysList(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	keyDataMap, err := registry.pipelineGetKeys([]string{})
	assert.NoError(t, err)
	assert.NotNil(t, keyDataMap)
	assert.Empty(t, keyDataMap)
}

// The heartbeat must overwrite LastHeartbeat with the supplied "now".
func Test_HearthbeatNodeInRegistry_UpdatesLastHeartbeat(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	originalHeartbeat := node.LastHeartbeat
	defer cleanupTestNode(registry, node)
	time.Sleep(10 * time.Millisecond)
	err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
	assert.NoError(t, err)
	nodes, err := registry.GetAvailableNodes()
	assert.NoError(t, err)
	assert.Len(t, nodes, 1)
	assert.True(t, nodes[0].LastHeartbeat.After(originalHeartbeat))
}
// createTestRegistry builds a registry wired to the shared test Valkey
// client with fresh pub/sub managers per call.
func createTestRegistry() *BackupNodesRegistry {
	return &BackupNodesRegistry{
		cache_utils.GetValkeyClient(),
		logger.GetLogger(),
		cache_utils.DefaultCacheTimeout,
		cache_utils.NewPubSubManager(),
		cache_utils.NewPubSubManager(),
	}
}

// createTestBackupNode returns a node with a random ID, fixed throughput,
// and a current heartbeat timestamp.
func createTestBackupNode() BackupNode {
	return BackupNode{
		ID:            uuid.New(),
		ThroughputMBs: 100,
		LastHeartbeat: time.Now().UTC(),
	}
}

// cleanupTestNode removes a test node's registry keys; the error is
// intentionally ignored — best-effort cleanup in deferred calls.
func cleanupTestNode(registry *BackupNodesRegistry, node BackupNode) {
	registry.UnregisterNodeFromRegistry(node)
}
// Publishing an assignment with no subscriber must still succeed.
// NOTE(review): "Tonode" is a typo for "ToNode" in the test name.
func Test_AssignBackupTonode_PublishesJsonMessageToChannel(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	backupID := uuid.New()
	err := registry.AssignBackupToNode(node.ID.String(), backupID, true)
	assert.NoError(t, err)
}

// A subscribed node must receive assignments addressed to its own ID.
func Test_SubscribeNodeForBackupsAssignment_ReceivesSubmittedBackupsForMatchingNode(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	backupID := uuid.New()
	defer registry.UnsubscribeNodeForBackupsAssignments()
	receivedBackupID := make(chan uuid.UUID, 1)
	handler := func(id uuid.UUID, isCallNotifier bool) {
		receivedBackupID <- id
	}
	err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
	assert.NoError(t, err)
	// Give the subscription time to become active before publishing.
	time.Sleep(100 * time.Millisecond)
	err = registry.AssignBackupToNode(node.ID.String(), backupID, true)
	assert.NoError(t, err)
	select {
	case received := <-receivedBackupID:
		assert.Equal(t, backupID, received)
	case <-time.After(2 * time.Second):
		t.Fatal("Timeout waiting for backup message")
	}
}

// Assignments addressed to another node must be filtered out by NodeID.
func Test_SubscribeNodeForBackupsAssignment_FiltersOutBackupsForDifferentNode(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node1 := createTestBackupNode()
	node2 := createTestBackupNode()
	backupID := uuid.New()
	defer registry.UnsubscribeNodeForBackupsAssignments()
	receivedBackupID := make(chan uuid.UUID, 1)
	handler := func(id uuid.UUID, isCallNotifier bool) {
		receivedBackupID <- id
	}
	err := registry.SubscribeNodeForBackupsAssignment(node1.ID.String(), handler)
	assert.NoError(t, err)
	time.Sleep(100 * time.Millisecond)
	err = registry.AssignBackupToNode(node2.ID.String(), backupID, false)
	assert.NoError(t, err)
	select {
	case <-receivedBackupID:
		t.Fatal("Should not receive backup for different node")
	case <-time.After(500 * time.Millisecond):
	}
}
// Two assignments must both arrive with their backup IDs intact
// (order is not asserted — pub/sub delivery order is not guaranteed here).
func Test_SubscribeNodeForBackupsAssignment_ParsesJsonAndBackupIdCorrectly(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	backupID1 := uuid.New()
	backupID2 := uuid.New()
	defer registry.UnsubscribeNodeForBackupsAssignments()
	receivedBackups := make(chan uuid.UUID, 2)
	handler := func(id uuid.UUID, isCallNotifier bool) {
		receivedBackups <- id
	}
	err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
	assert.NoError(t, err)
	time.Sleep(100 * time.Millisecond)
	err = registry.AssignBackupToNode(node.ID.String(), backupID1, true)
	assert.NoError(t, err)
	err = registry.AssignBackupToNode(node.ID.String(), backupID2, false)
	assert.NoError(t, err)
	received1 := <-receivedBackups
	received2 := <-receivedBackups
	receivedIDs := []uuid.UUID{received1, received2}
	assert.Contains(t, receivedIDs, backupID1)
	assert.Contains(t, receivedIDs, backupID2)
}

// Malformed JSON on the channel must be dropped, never reach the handler.
func Test_SubscribeNodeForBackupsAssignment_HandlesInvalidJson(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	defer registry.UnsubscribeNodeForBackupsAssignments()
	receivedBackupID := make(chan uuid.UUID, 1)
	handler := func(id uuid.UUID, isCallNotifier bool) {
		receivedBackupID <- id
	}
	err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
	assert.NoError(t, err)
	time.Sleep(100 * time.Millisecond)
	ctx := context.Background()
	err = registry.pubsubBackups.Publish(ctx, "backup:submit", "invalid json")
	assert.NoError(t, err)
	select {
	case <-receivedBackupID:
		t.Fatal("Should not receive backup for invalid JSON")
	case <-time.After(500 * time.Millisecond):
	}
}

// After unsubscribing, further assignments must not reach the handler.
func Test_UnsubscribeNodeForBackupsAssignments_StopsReceivingMessages(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	backupID1 := uuid.New()
	backupID2 := uuid.New()
	receivedBackupID := make(chan uuid.UUID, 2)
	handler := func(id uuid.UUID, isCallNotifier bool) {
		receivedBackupID <- id
	}
	err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
	assert.NoError(t, err)
	time.Sleep(100 * time.Millisecond)
	err = registry.AssignBackupToNode(node.ID.String(), backupID1, true)
	assert.NoError(t, err)
	received := <-receivedBackupID
	assert.Equal(t, backupID1, received)
	err = registry.UnsubscribeNodeForBackupsAssignments()
	assert.NoError(t, err)
	time.Sleep(100 * time.Millisecond)
	err = registry.AssignBackupToNode(node.ID.String(), backupID2, false)
	assert.NoError(t, err)
	select {
	case <-receivedBackupID:
		t.Fatal("Should not receive backup after unsubscribe")
	case <-time.After(500 * time.Millisecond):
	}
}
// A second subscription on the same manager must fail with "already subscribed".
func Test_SubscribeNodeForBackupsAssignment_WhenAlreadySubscribed_ReturnsError(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	defer registry.UnsubscribeNodeForBackupsAssignments()
	handler := func(id uuid.UUID, isCallNotifier bool) {}
	err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
	assert.NoError(t, err)
	err = registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "already subscribed")
}

// Three independent registries: each node must receive exactly its own
// assignment and nothing else.
func Test_MultipleNodes_EachReceivesOnlyTheirBackups(t *testing.T) {
	cache_utils.ClearAllCache()
	registry1 := createTestRegistry()
	registry2 := createTestRegistry()
	registry3 := createTestRegistry()
	node1 := createTestBackupNode()
	node2 := createTestBackupNode()
	node3 := createTestBackupNode()
	backupID1 := uuid.New()
	backupID2 := uuid.New()
	backupID3 := uuid.New()
	defer registry1.UnsubscribeNodeForBackupsAssignments()
	defer registry2.UnsubscribeNodeForBackupsAssignments()
	defer registry3.UnsubscribeNodeForBackupsAssignments()
	receivedBackups1 := make(chan uuid.UUID, 3)
	receivedBackups2 := make(chan uuid.UUID, 3)
	receivedBackups3 := make(chan uuid.UUID, 3)
	handler1 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups1 <- id }
	handler2 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups2 <- id }
	handler3 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups3 <- id }
	err := registry1.SubscribeNodeForBackupsAssignment(node1.ID.String(), handler1)
	assert.NoError(t, err)
	err = registry2.SubscribeNodeForBackupsAssignment(node2.ID.String(), handler2)
	assert.NoError(t, err)
	err = registry3.SubscribeNodeForBackupsAssignment(node3.ID.String(), handler3)
	assert.NoError(t, err)
	time.Sleep(100 * time.Millisecond)
	// Publish through a fourth, publisher-only registry instance.
	submitRegistry := createTestRegistry()
	err = submitRegistry.AssignBackupToNode(node1.ID.String(), backupID1, true)
	assert.NoError(t, err)
	err = submitRegistry.AssignBackupToNode(node2.ID.String(), backupID2, false)
	assert.NoError(t, err)
	err = submitRegistry.AssignBackupToNode(node3.ID.String(), backupID3, true)
	assert.NoError(t, err)
	select {
	case received := <-receivedBackups1:
		assert.Equal(t, backupID1, received)
	case <-time.After(2 * time.Second):
		t.Fatal("Node 1 timeout waiting for backup message")
	}
	select {
	case received := <-receivedBackups2:
		assert.Equal(t, backupID2, received)
	case <-time.After(2 * time.Second):
		t.Fatal("Node 2 timeout waiting for backup message")
	}
	select {
	case received := <-receivedBackups3:
		assert.Equal(t, backupID3, received)
	case <-time.After(2 * time.Second):
		t.Fatal("Node 3 timeout waiting for backup message")
	}
	select {
	case <-receivedBackups1:
		t.Fatal("Node 1 should not receive additional backups")
	case <-time.After(300 * time.Millisecond):
	}
	select {
	case <-receivedBackups2:
		t.Fatal("Node 2 should not receive additional backups")
	case <-time.After(300 * time.Millisecond):
	}
	select {
	case <-receivedBackups3:
		t.Fatal("Node 3 should not receive additional backups")
	case <-time.After(300 * time.Millisecond):
	}
}
func Test_PublishBackupCompletion_PublishesMessageToChannel(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
backupID := uuid.New()
err := registry.PublishBackupCompletion(node.ID.String(), backupID)
assert.NoError(t, err)
}
func Test_SubscribeForBackupsCompletions_ReceivesCompletedBackups(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
backupID := uuid.New()
defer registry.UnsubscribeForBackupsCompletions()
receivedBackupID := make(chan uuid.UUID, 1)
receivedNodeID := make(chan string, 1)
handler := func(nodeID string, backupID uuid.UUID) {
receivedNodeID <- nodeID
receivedBackupID <- backupID
}
err := registry.SubscribeForBackupsCompletions(handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.PublishBackupCompletion(node.ID.String(), backupID)
assert.NoError(t, err)
select {
case receivedNode := <-receivedNodeID:
assert.Equal(t, node.ID.String(), receivedNode)
case <-time.After(2 * time.Second):
t.Fatal("Timeout waiting for node ID")
}
select {
case received := <-receivedBackupID:
assert.Equal(t, backupID, received)
case <-time.After(2 * time.Second):
t.Fatal("Timeout waiting for backup completion message")
}
}
func Test_SubscribeForBackupsCompletions_ParsesJsonCorrectly(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
backupID1 := uuid.New()
backupID2 := uuid.New()
defer registry.UnsubscribeForBackupsCompletions()
receivedBackups := make(chan uuid.UUID, 2)
handler := func(nodeID string, backupID uuid.UUID) {
receivedBackups <- backupID
}
err := registry.SubscribeForBackupsCompletions(handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.PublishBackupCompletion(node.ID.String(), backupID1)
assert.NoError(t, err)
err = registry.PublishBackupCompletion(node.ID.String(), backupID2)
assert.NoError(t, err)
received1 := <-receivedBackups
received2 := <-receivedBackups
receivedIDs := []uuid.UUID{received1, received2}
assert.Contains(t, receivedIDs, backupID1)
assert.Contains(t, receivedIDs, backupID2)
}
// Test_SubscribeForBackupsCompletions_HandlesInvalidJson verifies that a
// malformed payload published on the completion channel is silently dropped:
// the handler must never be invoked for it.
func Test_SubscribeForBackupsCompletions_HandlesInvalidJson(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	defer registry.UnsubscribeForBackupsCompletions()
	receivedBackupID := make(chan uuid.UUID, 1)
	handler := func(nodeID string, backupID uuid.UUID) {
		receivedBackupID <- backupID
	}
	err := registry.SubscribeForBackupsCompletions(handler)
	assert.NoError(t, err)
	// Give the subscription goroutine time to attach before publishing.
	time.Sleep(100 * time.Millisecond)
	ctx := context.Background()
	// Publish through the underlying pubsub client directly, bypassing the
	// registry's JSON encoding, to inject an invalid payload.
	err = registry.pubsubCompletions.Publish(ctx, "backup:completion", "invalid json")
	assert.NoError(t, err)
	select {
	case <-receivedBackupID:
		t.Fatal("Should not receive backup for invalid JSON")
	case <-time.After(500 * time.Millisecond):
		// No message within the grace period — the invalid JSON was ignored.
	}
}
// Test_UnsubscribeForBackupsCompletions_StopsReceivingMessages verifies that
// a handler receives messages while subscribed and stops receiving them after
// UnsubscribeForBackupsCompletions is called.
func Test_UnsubscribeForBackupsCompletions_StopsReceivingMessages(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	backupID1 := uuid.New()
	backupID2 := uuid.New()
	receivedBackupID := make(chan uuid.UUID, 2)
	handler := func(nodeID string, backupID uuid.UUID) {
		receivedBackupID <- backupID
	}
	err := registry.SubscribeForBackupsCompletions(handler)
	assert.NoError(t, err)
	// Give the subscription goroutine time to attach before publishing.
	time.Sleep(100 * time.Millisecond)
	err = registry.PublishBackupCompletion(node.ID.String(), backupID1)
	assert.NoError(t, err)
	// Receive with a timeout so a lost message fails fast instead of
	// blocking the whole test run (the previous bare receive could hang).
	select {
	case received := <-receivedBackupID:
		assert.Equal(t, backupID1, received)
	case <-time.After(2 * time.Second):
		t.Fatal("Timeout waiting for backup completion message")
	}
	err = registry.UnsubscribeForBackupsCompletions()
	assert.NoError(t, err)
	time.Sleep(100 * time.Millisecond)
	err = registry.PublishBackupCompletion(node.ID.String(), backupID2)
	assert.NoError(t, err)
	select {
	case <-receivedBackupID:
		t.Fatal("Should not receive backup after unsubscribe")
	case <-time.After(500 * time.Millisecond):
		// No message arrived — unsubscription took effect.
	}
}
// Test_SubscribeForBackupsCompletions_WhenAlreadySubscribed_ReturnsError
// verifies that a second subscription on the same registry is rejected.
func Test_SubscribeForBackupsCompletions_WhenAlreadySubscribed_ReturnsError(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	defer registry.UnsubscribeForBackupsCompletions()

	noopHandler := func(nodeID string, backupID uuid.UUID) {}
	assert.NoError(t, registry.SubscribeForBackupsCompletions(noopHandler))

	err := registry.SubscribeForBackupsCompletions(noopHandler)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "already subscribed")
}
// Test_MultipleSubscribers_EachReceivesCompletionMessages verifies fan-out
// semantics: every registry subscribed to completions receives every
// published completion message, regardless of which node published it.
func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) {
	cache_utils.ClearAllCache()
	registry1 := createTestRegistry()
	registry2 := createTestRegistry()
	registry3 := createTestRegistry()
	node1 := createTestBackupNode()
	node2 := createTestBackupNode()
	node3 := createTestBackupNode()
	backupID1 := uuid.New()
	backupID2 := uuid.New()
	backupID3 := uuid.New()
	defer registry1.UnsubscribeForBackupsCompletions()
	defer registry2.UnsubscribeForBackupsCompletions()
	defer registry3.UnsubscribeForBackupsCompletions()
	receivedBackups1 := make(chan uuid.UUID, 3)
	receivedBackups2 := make(chan uuid.UUID, 3)
	receivedBackups3 := make(chan uuid.UUID, 3)
	handler1 := func(nodeID string, backupID uuid.UUID) { receivedBackups1 <- backupID }
	handler2 := func(nodeID string, backupID uuid.UUID) { receivedBackups2 <- backupID }
	handler3 := func(nodeID string, backupID uuid.UUID) { receivedBackups3 <- backupID }
	err := registry1.SubscribeForBackupsCompletions(handler1)
	assert.NoError(t, err)
	err = registry2.SubscribeForBackupsCompletions(handler2)
	assert.NoError(t, err)
	err = registry3.SubscribeForBackupsCompletions(handler3)
	assert.NoError(t, err)
	// Give the subscription goroutines time to attach before publishing.
	time.Sleep(100 * time.Millisecond)
	publishRegistry := createTestRegistry()
	err = publishRegistry.PublishBackupCompletion(node1.ID.String(), backupID1)
	assert.NoError(t, err)
	err = publishRegistry.PublishBackupCompletion(node2.ID.String(), backupID2)
	assert.NoError(t, err)
	err = publishRegistry.PublishBackupCompletion(node3.ID.String(), backupID3)
	assert.NoError(t, err)
	// drain collects exactly three completion messages from ch, failing the
	// test if any of them does not arrive in time. Extracted as a local
	// helper to avoid triplicating the receive loop per subscriber.
	drain := func(name string, ch chan uuid.UUID) []uuid.UUID {
		received := make([]uuid.UUID, 0, 3)
		for range 3 {
			select {
			case id := <-ch:
				received = append(received, id)
			case <-time.After(2 * time.Second):
				t.Fatalf("%s timeout waiting for completion message", name)
			}
		}
		return received
	}
	receivedAll1 := drain("Subscriber 1", receivedBackups1)
	receivedAll2 := drain("Subscriber 2", receivedBackups2)
	receivedAll3 := drain("Subscriber 3", receivedBackups3)
	assert.Contains(t, receivedAll1, backupID1)
	assert.Contains(t, receivedAll1, backupID2)
	assert.Contains(t, receivedAll1, backupID3)
	assert.Contains(t, receivedAll2, backupID1)
	assert.Contains(t, receivedAll2, backupID2)
	assert.Contains(t, receivedAll2, backupID3)
	assert.Contains(t, receivedAll3, backupID1)
	assert.Contains(t, receivedAll3, backupID2)
	assert.Contains(t, receivedAll3, backupID3)
}

View File

@@ -0,0 +1,600 @@
package backuping
import (
"context"
"databasus-backend/internal/config"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/storages"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/period"
"fmt"
"log/slog"
"time"
"github.com/google/uuid"
)
// BackupsScheduler coordinates periodic backup creation across backup nodes:
// it picks the least busy node for each pending backup, tracks which backups
// are assigned to which node, and fails backups owned by dead nodes.
type BackupsScheduler struct {
	backupRepository    *backups_core.BackupRepository
	backupConfigService *backups_config.BackupConfigService
	storageService      *storages.StorageService
	backupCancelManager *backups_cancellation.BackupCancelManager
	nodesRegistry       *BackupNodesRegistry
	// lastBackupTime records when the last scheduler tick finished; read by
	// IsSchedulerRunning as a liveness signal.
	lastBackupTime time.Time
	logger         *slog.Logger
	// backupToNodeRelations maps node ID -> backups currently assigned to
	// that node.
	// NOTE(review): this map is mutated from the ticker loop
	// (checkDeadNodesAndFailBackups) and from the pubsub completion handler
	// (onBackupCompleted) — presumably different goroutines; confirm whether
	// access needs a mutex.
	backupToNodeRelations map[uuid.UUID]BackupToNodeRelation
	backuperNode          *BackuperNode
}
// Run starts the scheduler loop and blocks until ctx is cancelled. On startup
// it fails any backups left in progress by a previous process and subscribes
// to backup completion notifications; then, once per minute, it cleans
// expired backups, fails backups owned by dead nodes, and triggers pending
// backups.
func (s *BackupsScheduler) Run(ctx context.Context) {
	s.lastBackupTime = time.Now().UTC()
	if config.GetEnv().IsManyNodesMode {
		// wait other nodes to start
		time.Sleep(1 * time.Minute)
	}
	// Startup recovery must succeed; an inconsistent scheduler state is fatal.
	if err := s.failBackupsInProgress(); err != nil {
		s.logger.Error("Failed to fail backups in progress", "error", err)
		panic(err)
	}
	if err := s.nodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted); err != nil {
		s.logger.Error("Failed to subscribe to backup completions", "error", err)
		panic(err)
	}
	defer func() {
		if err := s.nodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
			s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
		}
	}()
	// Bail out early if the context was cancelled during startup.
	if ctx.Err() != nil {
		return
	}
	ticker := time.NewTicker(1 * time.Minute)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			// Each tick runs all maintenance steps; individual failures are
			// logged but do not stop the loop.
			if err := s.cleanOldBackups(); err != nil {
				s.logger.Error("Failed to clean old backups", "error", err)
			}
			if err := s.checkDeadNodesAndFailBackups(); err != nil {
				s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
			}
			if err := s.runPendingBackups(); err != nil {
				s.logger.Error("Failed to run pending backups", "error", err)
			}
			// Record tick completion; IsSchedulerRunning reads this as a
			// liveness signal.
			s.lastBackupTime = time.Now().UTC()
		}
	}
}
// IsSchedulerRunning reports whether the scheduler ticked within the last
// five minutes; a stale lastBackupTime means the loop is stalled or stopped.
func (s *BackupsScheduler) IsSchedulerRunning() bool {
	staleThreshold := time.Now().UTC().Add(-5 * time.Minute)
	return s.lastBackupTime.After(staleThreshold)
}
// failBackupsInProgress marks every backup still in the in-progress state as
// failed. It runs on scheduler startup: any backup that was in progress when
// the previous process stopped can never complete, so it is cancelled,
// flagged with a restart message, and a failure notification is sent.
// Returns an error only when reading or saving backups fails.
func (s *BackupsScheduler) failBackupsInProgress() error {
	backupsInProgress, err := s.backupRepository.FindByStatus(backups_core.BackupStatusInProgress)
	if err != nil {
		return err
	}
	// Structured log instead of the leftover fmt.Println debug print, keeping
	// this consistent with the slog usage in the rest of the file.
	s.logger.Info("Failing backups left in progress", "count", len(backupsInProgress))
	for _, backup := range backupsInProgress {
		// Cancellation failures are logged but do not stop recovery.
		if err := s.backupCancelManager.CancelBackup(backup.ID); err != nil {
			s.logger.Error(
				"Failed to cancel backup via context manager",
				"backupId",
				backup.ID,
				"error",
				err,
			)
		}
		backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(backup.DatabaseID)
		if err != nil {
			s.logger.Error("Failed to get backup config by database ID", "error", err)
			continue
		}
		failMessage := "Backup failed due to application restart"
		backup.FailMessage = &failMessage
		backup.Status = backups_core.BackupStatusFailed
		backup.BackupSizeMb = 0
		s.backuperNode.SendBackupNotification(
			backupConfig,
			backup,
			backups_config.NotificationBackupFailed,
			&failMessage,
		)
		if err := s.backupRepository.Save(backup); err != nil {
			return err
		}
	}
	return nil
}
// StartBackup creates an in-progress backup record for the given database and
// assigns it to the least busy backup node. Errors at any stage are logged
// and the attempt is abandoned. isCallNotifier is forwarded with the
// assignment — presumably telling the node whether to emit notifications for
// this backup; confirm against AssignBackupToNode.
func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool) {
	backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(databaseID)
	if err != nil {
		s.logger.Error("Failed to get backup config by database ID", "error", err)
		return
	}
	if backupConfig.StorageID == nil {
		s.logger.Error("Backup config storage ID is nil", "databaseId", databaseID)
		return
	}
	leastBusyNodeID, err := s.calculateLeastBusyNode()
	if err != nil {
		s.logger.Error(
			"Failed to calculate least busy node",
			"databaseId",
			backupConfig.DatabaseID,
			"error",
			err,
		)
		return
	}
	// Persist the backup as in-progress before handing it to a node.
	backup := &backups_core.Backup{
		DatabaseID:   backupConfig.DatabaseID,
		StorageID:    *backupConfig.StorageID,
		Status:       backups_core.BackupStatusInProgress,
		BackupSizeMb: 0,
		CreatedAt:    time.Now().UTC(),
	}
	if err := s.backupRepository.Save(backup); err != nil {
		s.logger.Error(
			"Failed to save backup",
			"databaseId",
			backupConfig.DatabaseID,
			"error",
			err,
		)
		return
	}
	if err := s.nodesRegistry.IncrementBackupsInProgress(leastBusyNodeID.String()); err != nil {
		s.logger.Error(
			"Failed to increment backups in progress",
			"nodeId",
			leastBusyNodeID,
			"backupId",
			backup.ID,
			"error",
			err,
		)
		return
	}
	if err := s.nodesRegistry.AssignBackupToNode(leastBusyNodeID.String(), backup.ID, isCallNotifier); err != nil {
		s.logger.Error(
			"Failed to submit backup",
			"nodeId",
			leastBusyNodeID,
			"backupId",
			backup.ID,
			"error",
			err,
		)
		// Roll back the counter incremented above so the node's load stat
		// does not drift after a failed assignment.
		if decrementErr := s.nodesRegistry.DecrementBackupsInProgress(leastBusyNodeID.String()); decrementErr != nil {
			s.logger.Error(
				"Failed to decrement backups in progress after submit failure",
				"nodeId",
				leastBusyNodeID,
				"error",
				decrementErr,
			)
		}
		return
	}
	// Record the assignment locally so completion handling and dead-node
	// checks can find which backups belong to this node.
	if relation, exists := s.backupToNodeRelations[*leastBusyNodeID]; exists {
		relation.BackupsIDs = append(relation.BackupsIDs, backup.ID)
		s.backupToNodeRelations[*leastBusyNodeID] = relation
	} else {
		s.backupToNodeRelations[*leastBusyNodeID] = BackupToNodeRelation{
			NodeID:     *leastBusyNodeID,
			BackupsIDs: []uuid.UUID{backup.ID},
		}
	}
	s.logger.Info(
		"Successfully triggered scheduled backup",
		"databaseId",
		backupConfig.DatabaseID,
		"backupId",
		backup.ID,
		"nodeId",
		leastBusyNodeID,
	)
}
// GetRemainedBackupTryCount reports how many retry attempts remain for the
// database of lastBackup. It returns 0 when there is no last backup, the
// last backup did not fail, retries are disabled in the config, or the
// recent history already contains the configured maximum of failed attempts.
func (s *BackupsScheduler) GetRemainedBackupTryCount(lastBackup *backups_core.Backup) int {
	if lastBackup == nil || lastBackup.Status != backups_core.BackupStatusFailed {
		return 0
	}
	backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(lastBackup.DatabaseID)
	if err != nil {
		s.logger.Error("Failed to get backup config by database ID", "error", err)
		return 0
	}
	if !backupConfig.IsRetryIfFailed {
		return 0
	}
	maxTries := backupConfig.MaxFailedTriesCount
	recentBackups, err := s.backupRepository.FindByDatabaseIDWithLimit(
		lastBackup.DatabaseID,
		maxTries,
	)
	if err != nil {
		s.logger.Error("Failed to find last backups by database ID", "error", err)
		return 0
	}
	// Count failures among the most recent maxTries backups; the difference
	// is how many attempts are still allowed.
	failedCount := 0
	for _, recent := range recentBackups {
		if recent.Status == backups_core.BackupStatusFailed {
			failedCount++
		}
	}
	return maxTries - failedCount
}
// cleanOldBackups deletes backups older than each backup config's retention
// period: the stored file is removed first (best effort), then the database
// record. Configs with PeriodForever retention are skipped. Per-backup
// failures are logged and processing continues.
func (s *BackupsScheduler) cleanOldBackups() error {
	enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
	if err != nil {
		return err
	}
	// The field encryptor is loop-invariant; fetch it once instead of once
	// per deleted backup.
	encryptor := encryption.GetFieldEncryptor()
	for _, backupConfig := range enabledBackupConfigs {
		backupStorePeriod := backupConfig.StorePeriod
		if backupStorePeriod == period.PeriodForever {
			continue
		}
		storeDuration := backupStorePeriod.ToDuration()
		dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
		oldBackups, err := s.backupRepository.FindBackupsBeforeDate(
			backupConfig.DatabaseID,
			dateBeforeBackupsShouldBeDeleted,
		)
		if err != nil {
			s.logger.Error(
				"Failed to find old backups for database",
				"databaseId",
				backupConfig.DatabaseID,
				"error",
				err,
			)
			continue
		}
		for _, backup := range oldBackups {
			storage, err := s.storageService.GetStorageByID(backup.StorageID)
			if err != nil {
				s.logger.Error(
					"Failed to get storage by ID",
					"storageId",
					backup.StorageID,
					"error",
					err,
				)
				continue
			}
			// Best effort: a failed file deletion is logged but does not
			// block removal of the database record.
			if err := storage.DeleteFile(encryptor, backup.ID); err != nil {
				s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
			}
			if err := s.backupRepository.DeleteByID(backup.ID); err != nil {
				s.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
				continue
			}
			s.logger.Info(
				"Deleted old backup",
				"backupId",
				backup.ID,
				"databaseId",
				backupConfig.DatabaseID,
			)
		}
	}
	return nil
}
// runPendingBackups walks every backup-enabled config and starts a backup
// for each database whose interval has elapsed, or that still has retry
// attempts left after a failure. Per-config errors are logged and the loop
// continues with the next config.
func (s *BackupsScheduler) runPendingBackups() error {
	enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
	if err != nil {
		return err
	}
	for _, backupConfig := range enabledBackupConfigs {
		if backupConfig.BackupInterval == nil {
			continue
		}
		lastBackup, err := s.backupRepository.FindLastByDatabaseID(backupConfig.DatabaseID)
		if err != nil {
			s.logger.Error(
				"Failed to get last backup for database",
				"databaseId",
				backupConfig.DatabaseID,
				"error",
				err,
			)
			continue
		}
		var lastBackupTime *time.Time
		if lastBackup != nil {
			lastBackupTime = &lastBackup.CreatedAt
		}
		remainedTries := s.GetRemainedBackupTryCount(lastBackup)
		intervalElapsed := backupConfig.BackupInterval.ShouldTriggerBackup(time.Now().UTC(), lastBackupTime)
		if !intervalElapsed && remainedTries == 0 {
			continue
		}
		s.logger.Info(
			"Triggering scheduled backup",
			"databaseId",
			backupConfig.DatabaseID,
			"intervalType",
			backupConfig.BackupInterval.Interval,
		)
		s.StartBackup(backupConfig.DatabaseID, remainedTries == 1)
	}
	return nil
}
// calculateLeastBusyNode picks the registered node best suited to run a new
// backup. Nodes without a heartbeat in the last two minutes are skipped; the
// rest are scored by active backups divided by throughput (lower is better),
// with a large penalty multiplier for nodes reporting no throughput. Returns
// an error when no node qualifies.
func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
	nodes, err := s.nodesRegistry.GetAvailableNodes()
	if err != nil {
		return nil, fmt.Errorf("failed to get available nodes: %w", err)
	}
	if len(nodes) == 0 {
		return nil, fmt.Errorf("no nodes available")
	}
	stats, err := s.nodesRegistry.GetBackupNodesStats()
	if err != nil {
		return nil, fmt.Errorf("failed to get backup nodes stats: %w", err)
	}
	activeByNode := make(map[uuid.UUID]int, len(stats))
	for _, stat := range stats {
		activeByNode[stat.ID] = stat.ActiveBackups
	}
	heartbeatCutoff := time.Now().UTC().Add(-2 * time.Minute)
	var chosen *BackupNode
	chosenScore := -1.0
	for i := range nodes {
		candidate := &nodes[i]
		// Skip nodes whose heartbeat is older than the two-minute cutoff.
		if candidate.LastHeartbeat.Before(heartbeatCutoff) {
			continue
		}
		activeBackups := activeByNode[candidate.ID]
		// No-throughput nodes get a heavy penalty so they only win when
		// nothing better is available.
		score := float64(activeBackups) * 1000
		if candidate.ThroughputMBs > 0 {
			score = float64(activeBackups) / float64(candidate.ThroughputMBs)
		}
		if chosen == nil || score < chosenScore {
			chosen = candidate
			chosenScore = score
		}
	}
	if chosen == nil {
		return nil, fmt.Errorf("no suitable nodes available")
	}
	return &chosen.ID, nil
}
// onBackupCompleted handles a completion notification from a backup node:
// it removes the backup from the node's local assignment list and decrements
// the node's in-progress counter in the registry. Notifications for unknown
// nodes or unassigned backups are logged and ignored.
func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUID) {
	nodeID, err := uuid.Parse(nodeIDStr)
	if err != nil {
		s.logger.Error(
			"Failed to parse node ID from completion message",
			"nodeId",
			nodeIDStr,
			"error",
			err,
		)
		return
	}
	relation, exists := s.backupToNodeRelations[nodeID]
	if !exists {
		s.logger.Warn(
			"Received completion for unknown node",
			"nodeId",
			nodeID,
			"backupId",
			backupID,
		)
		return
	}
	// Rebuild the node's assignment list without the completed backup.
	remaining := make([]uuid.UUID, 0, len(relation.BackupsIDs))
	found := false
	for _, assignedID := range relation.BackupsIDs {
		if assignedID == backupID {
			found = true
		} else {
			remaining = append(remaining, assignedID)
		}
	}
	if !found {
		s.logger.Warn(
			"Backup not found in node's backup list",
			"nodeId",
			nodeID,
			"backupId",
			backupID,
		)
		return
	}
	// Drop the map entry entirely once the node has no assigned backups left.
	if len(remaining) == 0 {
		delete(s.backupToNodeRelations, nodeID)
	} else {
		relation.BackupsIDs = remaining
		s.backupToNodeRelations[nodeID] = relation
	}
	if err := s.nodesRegistry.DecrementBackupsInProgress(nodeIDStr); err != nil {
		s.logger.Error(
			"Failed to decrement backups in progress",
			"nodeId",
			nodeID,
			"backupId",
			backupID,
			"error",
			err,
		)
	}
}
// checkDeadNodesAndFailBackups fails every backup assigned to a node whose
// heartbeat is older than two minutes. Each such backup is marked failed
// with an explanatory message, the node's in-progress counter is
// decremented, and the node's assignment entry is removed from the local
// map; the node record itself stays in the registry.
func (s *BackupsScheduler) checkDeadNodesAndFailBackups() error {
	nodes, err := s.nodesRegistry.GetAvailableNodes()
	if err != nil {
		return fmt.Errorf("failed to get available nodes: %w", err)
	}
	// Build the set of nodes considered alive: heartbeat within two minutes.
	aliveNodeIDs := make(map[uuid.UUID]bool)
	now := time.Now().UTC()
	for _, node := range nodes {
		if now.Sub(node.LastHeartbeat) <= 2*time.Minute {
			aliveNodeIDs[node.ID] = true
		}
	}
	for nodeID, relation := range s.backupToNodeRelations {
		if aliveNodeIDs[nodeID] {
			continue
		}
		s.logger.Warn(
			"Node is dead, failing its backups",
			"nodeId",
			nodeID,
			"backupCount",
			len(relation.BackupsIDs),
		)
		for _, backupID := range relation.BackupsIDs {
			backup, err := s.backupRepository.FindByID(backupID)
			if err != nil {
				s.logger.Error(
					"Failed to find backup for dead node",
					"nodeId",
					nodeID,
					"backupId",
					backupID,
					"error",
					err,
				)
				continue
			}
			failMessage := "Backup failed due to node unavailability"
			backup.FailMessage = &failMessage
			backup.Status = backups_core.BackupStatusFailed
			backup.BackupSizeMb = 0
			if err := s.backupRepository.Save(backup); err != nil {
				s.logger.Error(
					"Failed to save failed backup for dead node",
					"nodeId",
					nodeID,
					"backupId",
					backupID,
					"error",
					err,
				)
				continue
			}
			// Keep the registry's load counter in sync with the failed backup.
			if err := s.nodesRegistry.DecrementBackupsInProgress(nodeID.String()); err != nil {
				s.logger.Error(
					"Failed to decrement backups in progress for dead node",
					"nodeId",
					nodeID,
					"backupId",
					backupID,
					"error",
					err,
				)
			}
			s.logger.Info(
				"Failed backup due to dead node",
				"nodeId",
				nodeID,
				"backupId",
				backupID,
			)
		}
		// Deleting the current key while ranging over a map is safe in Go.
		delete(s.backupToNodeRelations, nodeID)
	}
	return nil
}

View File

@@ -1,6 +1,7 @@
package backups
package backuping
import (
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/intervals"
@@ -9,14 +10,21 @@ import (
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/period"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
func Test_RunPendingBackups_WhenLastBackupWasYesterday_CreatesNewBackup(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
// setup data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
@@ -57,16 +65,16 @@ func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
assert.NoError(t, err)
// add old backup
backupRepository.Save(&Backup{
backupRepository.Save(&backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusCompleted,
Status: backups_core.BackupStatusCompleted,
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
})
GetBackupBackgroundService().runPendingBackups()
GetBackupsScheduler().runPendingBackups()
// Wait for backup to complete (runs in goroutine)
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
@@ -80,7 +88,12 @@ func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
time.Sleep(200 * time.Millisecond)
}
func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
func Test_RunPendingBackups_WhenLastBackupWasRecentlyCompleted_SkipsBackup(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
// setup data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
@@ -121,16 +134,16 @@ func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
assert.NoError(t, err)
// add recent backup (1 hour ago)
backupRepository.Save(&Backup{
backupRepository.Save(&backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusCompleted,
Status: backups_core.BackupStatusCompleted,
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
})
GetBackupBackgroundService().runPendingBackups()
GetBackupsScheduler().runPendingBackups()
time.Sleep(100 * time.Millisecond)
@@ -143,7 +156,12 @@ func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
time.Sleep(200 * time.Millisecond)
}
func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T) {
func Test_RunPendingBackups_WhenLastBackupFailedAndRetriesDisabled_SkipsBackup(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
// setup data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
@@ -187,17 +205,17 @@ func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T)
// add failed backup
failMessage := "backup failed"
backupRepository.Save(&Backup{
backupRepository.Save(&backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusFailed,
Status: backups_core.BackupStatusFailed,
FailMessage: &failMessage,
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
})
GetBackupBackgroundService().runPendingBackups()
GetBackupsScheduler().runPendingBackups()
time.Sleep(100 * time.Millisecond)
@@ -210,7 +228,12 @@ func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T)
time.Sleep(200 * time.Millisecond)
}
func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
func Test_RunPendingBackups_WhenLastBackupFailedAndRetriesEnabled_CreatesNewBackup(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
// setup data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
@@ -254,17 +277,17 @@ func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
// add failed backup
failMessage := "backup failed"
backupRepository.Save(&Backup{
backupRepository.Save(&backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusFailed,
Status: backups_core.BackupStatusFailed,
FailMessage: &failMessage,
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
})
GetBackupBackgroundService().runPendingBackups()
GetBackupsScheduler().runPendingBackups()
// Wait for backup to complete (runs in goroutine)
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
@@ -278,7 +301,12 @@ func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
time.Sleep(200 * time.Millisecond)
}
func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *testing.T) {
func Test_RunPendingBackups_WhenFailedBackupsExceedMaxRetries_SkipsBackup(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
// setup data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
@@ -322,19 +350,19 @@ func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *tes
failMessage := "backup failed"
for i := 0; i < 3; i++ {
backupRepository.Save(&Backup{
for range 3 {
backupRepository.Save(&backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusFailed,
Status: backups_core.BackupStatusFailed,
FailMessage: &failMessage,
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
})
}
GetBackupBackgroundService().runPendingBackups()
GetBackupsScheduler().runPendingBackups()
time.Sleep(100 * time.Millisecond)
@@ -347,7 +375,12 @@ func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *tes
time.Sleep(200 * time.Millisecond)
}
func Test_MakeBackgroundBackupWhenBakupsDisabled_BackupSkipped(t *testing.T) {
func Test_RunPendingBackups_WhenBackupsDisabled_SkipsBackup(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
@@ -385,16 +418,16 @@ func Test_MakeBackgroundBackupWhenBakupsDisabled_BackupSkipped(t *testing.T) {
assert.NoError(t, err)
// add old backup that would trigger new backup if enabled
backupRepository.Save(&Backup{
backupRepository.Save(&backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusCompleted,
Status: backups_core.BackupStatusCompleted,
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
})
GetBackupBackgroundService().runPendingBackups()
GetBackupsScheduler().runPendingBackups()
time.Sleep(100 * time.Millisecond)
@@ -405,3 +438,272 @@ func Test_MakeBackgroundBackupWhenBakupsDisabled_BackupSkipped(t *testing.T) {
// Wait for any cleanup operations to complete before defer cleanup runs
time.Sleep(200 * time.Millisecond)
}
// Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegistry
// is an end-to-end scheduler test: a backup is assigned to a mock node, the
// node's heartbeat is aged past the dead-node threshold, and the scheduler's
// dead-node check must fail the backup, decrement the node's counter, and
// keep the node record in the registry.
func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegistry(t *testing.T) {
	cache_utils.ClearAllCache()
	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
	defer func() {
		// Tear down fixtures in dependency order.
		backups, _ := backupRepository.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepository.DeleteByID(backup.ID)
		}
		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		notifiers.RemoveTestNotifier(notifier)
		storages.RemoveTestStorage(storage.ID)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()
	// Enable daily backups for the test database.
	backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
	assert.NoError(t, err)
	timeOfDay := "04:00"
	backupConfig.BackupInterval = &intervals.Interval{
		Interval:  intervals.IntervalDaily,
		TimeOfDay: &timeOfDay,
	}
	backupConfig.IsBackupsEnabled = true
	backupConfig.StorePeriod = period.PeriodWeek
	backupConfig.Storage = storage
	backupConfig.StorageID = &storage.ID
	_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
	assert.NoError(t, err)
	// Register mock node without subscribing to backups (simulates node crash after registration)
	mockNodeID := uuid.New()
	err = CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
	assert.NoError(t, err)
	// Scheduler assigns backup to mock node
	GetBackupsScheduler().StartBackup(database.ID, false)
	time.Sleep(100 * time.Millisecond)
	backups, err := backupRepository.FindByDatabaseID(database.ID)
	assert.NoError(t, err)
	assert.Len(t, backups, 1)
	assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status)
	// Verify Valkey counter was incremented when backup was assigned
	stats, err := nodesRegistry.GetBackupNodesStats()
	assert.NoError(t, err)
	foundStat := false
	for _, stat := range stats {
		if stat.ID == mockNodeID {
			assert.Equal(t, 1, stat.ActiveBackups)
			foundStat = true
			break
		}
	}
	assert.True(t, foundStat, "Node stats should be present")
	// Simulate node death by setting heartbeat older than 2-minute threshold
	oldHeartbeat := time.Now().UTC().Add(-3 * time.Minute)
	err = UpdateNodeHeartbeatDirectly(mockNodeID, 100, oldHeartbeat)
	assert.NoError(t, err)
	// Trigger dead node detection
	err = GetBackupsScheduler().checkDeadNodesAndFailBackups()
	assert.NoError(t, err)
	// Verify backup was failed with appropriate error message
	backups, err = backupRepository.FindByDatabaseID(database.ID)
	assert.NoError(t, err)
	assert.Len(t, backups, 1)
	assert.Equal(t, backups_core.BackupStatusFailed, backups[0].Status)
	assert.NotNil(t, backups[0].FailMessage)
	assert.Contains(t, *backups[0].FailMessage, "node unavailability")
	// Verify Valkey counter was decremented after backup failed
	stats, err = nodesRegistry.GetBackupNodesStats()
	assert.NoError(t, err)
	for _, stat := range stats {
		if stat.ID == mockNodeID {
			assert.Equal(t, 0, stat.ActiveBackups)
		}
	}
	// Node info should still exist in registry (not removed by checkDeadNodesAndFailBackups)
	node, err := GetNodeFromRegistry(mockNodeID)
	assert.NoError(t, err)
	assert.NotNil(t, node)
	assert.Equal(t, mockNodeID, node.ID)
	time.Sleep(200 * time.Millisecond)
}
// Test_CalculateLeastBusyNode_SelectsNodeWithBestScore verifies node
// selection scoring: with equal throughput the node with the fewest active
// backups wins; with different throughput the load-per-throughput ratio
// decides.
func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
	t.Run("Nodes with same throughput", func(t *testing.T) {
		cache_utils.ClearAllCache()
		now := time.Now().UTC()
		firstNodeID := uuid.New()
		secondNodeID := uuid.New()
		thirdNodeID := uuid.New()
		assert.NoError(t, CreateMockNodeInRegistry(firstNodeID, 100, now))
		assert.NoError(t, CreateMockNodeInRegistry(secondNodeID, 100, now))
		assert.NoError(t, CreateMockNodeInRegistry(thirdNodeID, 100, now))
		// addLoad marks count backups as in progress on the given node.
		addLoad := func(nodeID uuid.UUID, count int) {
			for range count {
				assert.NoError(t, nodesRegistry.IncrementBackupsInProgress(nodeID.String()))
			}
		}
		addLoad(firstNodeID, 5)
		addLoad(secondNodeID, 2)
		addLoad(thirdNodeID, 8)
		leastBusyNodeID, err := GetBackupsScheduler().calculateLeastBusyNode()
		assert.NoError(t, err)
		assert.NotNil(t, leastBusyNodeID)
		// Equal throughput: the node with the fewest active backups wins.
		assert.Equal(t, secondNodeID, *leastBusyNodeID)
	})
	t.Run("Nodes with different throughput", func(t *testing.T) {
		cache_utils.ClearAllCache()
		now := time.Now().UTC()
		fastNodeID := uuid.New()
		slowNodeID := uuid.New()
		assert.NoError(t, CreateMockNodeInRegistry(fastNodeID, 100, now))
		assert.NoError(t, CreateMockNodeInRegistry(slowNodeID, 50, now))
		for range 10 {
			assert.NoError(t, nodesRegistry.IncrementBackupsInProgress(fastNodeID.String()))
		}
		assert.NoError(t, nodesRegistry.IncrementBackupsInProgress(slowNodeID.String()))
		leastBusyNodeID, err := GetBackupsScheduler().calculateLeastBusyNode()
		assert.NoError(t, err)
		assert.NotNil(t, leastBusyNodeID)
		// 1/50 beats 10/100, so the slower but idler node is selected.
		assert.Equal(t, slowNodeID, *leastBusyNodeID)
	})
}
// Test_FailBackupsInProgress_WhenSchedulerStarts_CancelsBackupsAndUpdatesStatus
// verifies scheduler startup recovery: all in-progress backups are marked
// failed with a restart message and zeroed size, while completed backups are
// left untouched.
func Test_FailBackupsInProgress_WhenSchedulerStarts_CancelsBackupsAndUpdatesStatus(t *testing.T) {
	cache_utils.ClearAllCache()
	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
	defer func() {
		// Tear down fixtures in dependency order.
		backups, _ := backupRepository.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepository.DeleteByID(backup.ID)
		}
		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		notifiers.RemoveTestNotifier(notifier)
		storages.RemoveTestStorage(storage.ID)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()
	// Enable daily backups for the test database.
	backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
	assert.NoError(t, err)
	timeOfDay := "04:00"
	backupConfig.BackupInterval = &intervals.Interval{
		Interval:  intervals.IntervalDaily,
		TimeOfDay: &timeOfDay,
	}
	backupConfig.IsBackupsEnabled = true
	backupConfig.StorePeriod = period.PeriodWeek
	backupConfig.Storage = storage
	backupConfig.StorageID = &storage.ID
	_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
	assert.NoError(t, err)
	// Create two in-progress backups that should be failed on scheduler restart
	backup1 := &backups_core.Backup{
		DatabaseID:   database.ID,
		StorageID:    storage.ID,
		Status:       backups_core.BackupStatusInProgress,
		BackupSizeMb: 10.5,
		CreatedAt:    time.Now().UTC().Add(-30 * time.Minute),
	}
	err = backupRepository.Save(backup1)
	assert.NoError(t, err)
	backup2 := &backups_core.Backup{
		DatabaseID:   database.ID,
		StorageID:    storage.ID,
		Status:       backups_core.BackupStatusInProgress,
		BackupSizeMb: 5.2,
		CreatedAt:    time.Now().UTC().Add(-15 * time.Minute),
	}
	err = backupRepository.Save(backup2)
	assert.NoError(t, err)
	// Create a completed backup to verify it's not affected by failBackupsInProgress
	completedBackup := &backups_core.Backup{
		DatabaseID:   database.ID,
		StorageID:    storage.ID,
		Status:       backups_core.BackupStatusCompleted,
		BackupSizeMb: 20.0,
		CreatedAt:    time.Now().UTC().Add(-1 * time.Hour),
	}
	err = backupRepository.Save(completedBackup)
	assert.NoError(t, err)
	// Trigger the scheduler's failBackupsInProgress logic
	// This should cancel in-progress backups and mark them as failed
	err = GetBackupsScheduler().failBackupsInProgress()
	assert.NoError(t, err)
	// Verify all backups exist and were processed correctly
	backups, err := backupRepository.FindByDatabaseID(database.ID)
	assert.NoError(t, err)
	assert.Len(t, backups, 3)
	var failedCount int
	var completedCount int
	for _, backup := range backups {
		switch backup.Status {
		case backups_core.BackupStatusFailed:
			failedCount++
			// Verify fail message indicates application restart
			assert.NotNil(t, backup.FailMessage)
			assert.Equal(t, "Backup failed due to application restart", *backup.FailMessage)
			// Verify backup size was reset to 0
			assert.Equal(t, float64(0), backup.BackupSizeMb)
		case backups_core.BackupStatusCompleted:
			completedCount++
		}
	}
	// Verify correct number of backups in each state
	assert.Equal(t, 2, failedCount)
	assert.Equal(t, 1, completedCount)
	time.Sleep(200 * time.Millisecond)
}

View File

@@ -0,0 +1,206 @@
package backuping
import (
"context"
"fmt"
"testing"
"time"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/usecases"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_services "databasus-backend/internal/features/workspaces/services"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// CreateTestRouter builds a gin engine wired with the workspace, membership,
// database, and backup-config controllers needed by the backuping tests.
func CreateTestRouter() *gin.Engine {
	return workspaces_testing.CreateTestRouter(
		workspaces_controllers.GetWorkspaceController(),
		workspaces_controllers.GetMembershipController(),
		databases.GetDatabaseController(),
		backups_config.GetBackupConfigController(),
	)
}
// CreateTestBackuperNode constructs a BackuperNode from the package's shared
// service singletons, with a fresh random node ID and a zero last-heartbeat
// timestamp.
//
// NOTE(review): this is a positional composite literal — the argument order
// must match the BackuperNode field order exactly; update both together.
func CreateTestBackuperNode() *BackuperNode {
	return &BackuperNode{
		databases.GetDatabaseService(),
		encryption.GetFieldEncryptor(),
		workspaces_services.GetWorkspaceService(),
		backupRepository,
		backups_config.GetBackupConfigService(),
		storages.GetStorageService(),
		notifiers.GetNotifierService(),
		backupCancelManager,
		nodesRegistry,
		logger.GetLogger(),
		usecases.GetCreateBackupUsecase(),
		uuid.New(),
		time.Time{},
	}
}
// WaitForBackupCompletion waits for a new backup to be created and completed (or failed)
// for the given database. It checks for backups with count greater than expectedInitialCount.
// WaitForBackupCompletion polls the backup repository until a new backup
// (beyond expectedInitialCount) reaches a terminal state — completed, failed,
// or canceled — or until the timeout elapses. On timeout it only logs; it does
// not fail the test.
func WaitForBackupCompletion(
	t *testing.T,
	databaseID uuid.UUID,
	expectedInitialCount int,
	timeout time.Duration,
) {
	deadline := time.Now().UTC().Add(timeout)
	for time.Now().UTC().Before(deadline) {
		backups, err := backupRepository.FindByDatabaseID(databaseID)
		if err != nil {
			t.Logf("WaitForBackupCompletion: error finding backups: %v", err)
			time.Sleep(50 * time.Millisecond)
			continue
		}
		t.Logf(
			"WaitForBackupCompletion: found %d backups (expected > %d)",
			len(backups),
			expectedInitialCount,
		)
		if len(backups) > expectedInitialCount {
			// The repository returns newest-first; inspect the most recent one.
			newest := backups[0]
			t.Logf("WaitForBackupCompletion: newest backup status: %s", newest.Status)
			switch newest.Status {
			case backups_core.BackupStatusCompleted,
				backups_core.BackupStatusFailed,
				backups_core.BackupStatusCanceled:
				t.Logf(
					"WaitForBackupCompletion: backup finished with status %s",
					newest.Status,
				)
				return
			}
		}
		time.Sleep(50 * time.Millisecond)
	}
	t.Logf("WaitForBackupCompletion: timeout waiting for backup to complete")
}
// StartBackuperNodeForTest starts a BackuperNode in a goroutine for testing.
// The node registers itself in the registry and subscribes to backup assignments.
// Returns a context cancel function that should be deferred to stop the node.
// StartBackuperNodeForTest starts a BackuperNode in a goroutine for testing.
// The node registers itself in the registry and subscribes to backup
// assignments. Returns a cancel function that should be deferred to stop the
// node; the returned function also waits (up to 2s) for the node goroutine to
// exit.
//
// Fix: the original leaked the node goroutine on the registration-timeout
// path — t.Fatalf stops only the test goroutine, so the BackuperNode kept
// running with a live context. We now cancel and drain before failing.
func StartBackuperNodeForTest(t *testing.T, backuperNode *BackuperNode) context.CancelFunc {
	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan struct{})
	go func() {
		backuperNode.Run(ctx)
		close(done)
	}()
	// Poll registry for node presence instead of fixed sleep
	deadline := time.Now().UTC().Add(5 * time.Second)
	for time.Now().UTC().Before(deadline) {
		nodes, err := nodesRegistry.GetAvailableNodes()
		if err == nil {
			for _, node := range nodes {
				if node.ID == backuperNode.nodeID {
					t.Logf("BackuperNode registered in registry: %s", backuperNode.nodeID)
					return func() {
						cancel()
						select {
						case <-done:
							t.Log("BackuperNode stopped gracefully")
						case <-time.After(2 * time.Second):
							t.Log("BackuperNode stop timeout")
						}
					}
				}
			}
		}
		time.Sleep(50 * time.Millisecond)
	}
	// Registration never happened: stop the node goroutine before failing the
	// test so it does not keep running (and heartbeating) after Fatalf.
	cancel()
	select {
	case <-done:
	case <-time.After(2 * time.Second):
	}
	t.Fatalf("BackuperNode failed to register in registry within timeout")
	return nil // unreachable: Fatalf ends the test goroutine
}
// StopBackuperNodeForTest stops the BackuperNode by canceling its context.
// It waits for the node to unregister from the registry.
// StopBackuperNodeForTest cancels the BackuperNode's context and polls the
// registry (up to 2s) until the node has unregistered; on timeout it only
// logs and returns.
func StopBackuperNodeForTest(t *testing.T, cancel context.CancelFunc, backuperNode *BackuperNode) {
	cancel()
	// Wait for node to unregister from registry
	deadline := time.Now().UTC().Add(2 * time.Second)
	for time.Now().UTC().Before(deadline) {
		nodes, err := nodesRegistry.GetAvailableNodes()
		if err == nil {
			stillRegistered := false
			for _, n := range nodes {
				if n.ID == backuperNode.nodeID {
					stillRegistered = true
					break
				}
			}
			if !stillRegistered {
				t.Logf("BackuperNode unregistered from registry: %s", backuperNode.nodeID)
				return
			}
		}
		time.Sleep(50 * time.Millisecond)
	}
	t.Logf("BackuperNode stop completed for %s", backuperNode.nodeID)
}
// CreateMockNodeInRegistry registers a synthetic node with the given
// throughput and heartbeat timestamp directly in the nodes registry.
func CreateMockNodeInRegistry(nodeID uuid.UUID, throughputMBs int, lastHeartbeat time.Time) error {
	return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, BackupNode{
		ID:            nodeID,
		ThroughputMBs: throughputMBs,
		LastHeartbeat: lastHeartbeat,
	})
}
// UpdateNodeHeartbeatDirectly refreshes a node's registry entry with a new
// heartbeat timestamp.
//
// Fix: the body was a byte-for-byte duplicate of CreateMockNodeInRegistry;
// delegate to it so the registry call exists in exactly one place.
func UpdateNodeHeartbeatDirectly(
	nodeID uuid.UUID,
	throughputMBs int,
	lastHeartbeat time.Time,
) error {
	return CreateMockNodeInRegistry(nodeID, throughputMBs, lastHeartbeat)
}
// GetNodeFromRegistry looks up a node by ID among the currently available
// nodes and returns a pointer to a copy of it, or an error if it is absent.
func GetNodeFromRegistry(nodeID uuid.UUID) (*BackupNode, error) {
	nodes, err := nodesRegistry.GetAvailableNodes()
	if err != nil {
		return nil, err
	}
	for i := range nodes {
		if nodes[i].ID == nodeID {
			// Return a copy so callers never alias the registry slice.
			node := nodes[i]
			return &node, nil
		}
	}
	return nil, fmt.Errorf("node not found")
}

View File

@@ -1,9 +1,8 @@
package backups
package backups_cancellation
import (
"context"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/logger"
"log/slog"
"sync"
@@ -12,22 +11,14 @@ import (
const backupCancelChannel = "backup:cancel"
type BackupContextManager struct {
type BackupCancelManager struct {
mu sync.RWMutex
cancelFuncs map[uuid.UUID]context.CancelFunc
pubsub *cache_utils.PubSubManager
logger *slog.Logger
}
func NewBackupContextManager() *BackupContextManager {
return &BackupContextManager{
cancelFuncs: make(map[uuid.UUID]context.CancelFunc),
pubsub: cache_utils.NewPubSubManager(),
logger: logger.GetLogger(),
}
}
func (m *BackupContextManager) StartSubscription() {
func (m *BackupCancelManager) StartSubscription() {
ctx := context.Background()
handler := func(message string) {
@@ -56,14 +47,14 @@ func (m *BackupContextManager) StartSubscription() {
}
}
func (m *BackupContextManager) RegisterBackup(backupID uuid.UUID, cancelFunc context.CancelFunc) {
func (m *BackupCancelManager) RegisterBackup(backupID uuid.UUID, cancelFunc context.CancelFunc) {
m.mu.Lock()
defer m.mu.Unlock()
m.cancelFuncs[backupID] = cancelFunc
m.logger.Debug("Registered backup", "backupID", backupID)
}
func (m *BackupContextManager) CancelBackup(backupID uuid.UUID) error {
func (m *BackupCancelManager) CancelBackup(backupID uuid.UUID) error {
ctx := context.Background()
err := m.pubsub.Publish(ctx, backupCancelChannel, backupID.String())
@@ -76,7 +67,7 @@ func (m *BackupContextManager) CancelBackup(backupID uuid.UUID) error {
return nil
}
func (m *BackupContextManager) UnregisterBackup(backupID uuid.UUID) {
func (m *BackupCancelManager) UnregisterBackup(backupID uuid.UUID) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.cancelFuncs, backupID)

View File

@@ -1,4 +1,4 @@
package backups
package backups_cancellation
import (
"context"
@@ -11,7 +11,7 @@ import (
)
func Test_RegisterBackup_BackupRegisteredSuccessfully(t *testing.T) {
manager := NewBackupContextManager()
manager := backupCancelManager
backupID := uuid.New()
_, cancel := context.WithCancel(context.Background())
@@ -26,7 +26,7 @@ func Test_RegisterBackup_BackupRegisteredSuccessfully(t *testing.T) {
}
func Test_UnregisterBackup_BackupUnregisteredSuccessfully(t *testing.T) {
manager := NewBackupContextManager()
manager := backupCancelManager
backupID := uuid.New()
_, cancel := context.WithCancel(context.Background())
@@ -42,7 +42,7 @@ func Test_UnregisterBackup_BackupUnregisteredSuccessfully(t *testing.T) {
}
func Test_CancelBackup_OnSameInstance_BackupCancelledViaPubSub(t *testing.T) {
manager := NewBackupContextManager()
manager := backupCancelManager
backupID := uuid.New()
ctx, cancel := context.WithCancel(context.Background())
@@ -75,8 +75,8 @@ func Test_CancelBackup_OnSameInstance_BackupCancelledViaPubSub(t *testing.T) {
}
func Test_CancelBackup_FromDifferentInstance_BackupCancelledOnRunningInstance(t *testing.T) {
manager1 := NewBackupContextManager()
manager2 := NewBackupContextManager()
manager1 := backupCancelManager
manager2 := backupCancelManager
backupID := uuid.New()
ctx, cancel := context.WithCancel(context.Background())
@@ -111,7 +111,7 @@ func Test_CancelBackup_FromDifferentInstance_BackupCancelledOnRunningInstance(t
}
func Test_CancelBackup_WhenBackupDoesNotExist_NoErrorReturned(t *testing.T) {
manager := NewBackupContextManager()
manager := backupCancelManager
manager.StartSubscription()
time.Sleep(100 * time.Millisecond)
@@ -122,7 +122,7 @@ func Test_CancelBackup_WhenBackupDoesNotExist_NoErrorReturned(t *testing.T) {
}
func Test_CancelBackup_WithMultipleBackups_AllBackupsCancelled(t *testing.T) {
manager := NewBackupContextManager()
manager := backupCancelManager
numBackups := 5
backupIDs := make([]uuid.UUID, numBackups)
@@ -165,7 +165,7 @@ func Test_CancelBackup_WithMultipleBackups_AllBackupsCancelled(t *testing.T) {
}
func Test_CancelBackup_AfterUnregister_BackupNotCancelled(t *testing.T) {
manager := NewBackupContextManager()
manager := backupCancelManager
backupID := uuid.New()
_, cancel := context.WithCancel(context.Background())

View File

@@ -0,0 +1,25 @@
package backups_cancellation
import (
"context"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/logger"
"sync"
"github.com/google/uuid"
)
// backupCancelManager is the package-level singleton used to register,
// cancel, and unregister in-flight backups across instances via pub/sub.
//
// Fix: use a keyed struct literal so the initializer no longer silently
// breaks if BackupCancelManager's field order changes.
var backupCancelManager = &BackupCancelManager{
	mu:          sync.RWMutex{},
	cancelFuncs: make(map[uuid.UUID]context.CancelFunc),
	pubsub:      cache_utils.NewPubSubManager(),
	logger:      logger.GetLogger(),
}
// GetBackupCancelManager returns the package-level BackupCancelManager
// singleton.
func GetBackupCancelManager() *BackupCancelManager {
	return backupCancelManager
}
// SetupDependencies starts the cancel manager's pub/sub subscription so
// cancellation messages published by any instance are handled here.
func SetupDependencies() {
	backupCancelManager.StartSubscription()
}

View File

@@ -1,11 +1,15 @@
package backups
import (
"context"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
"databasus-backend/internal/features/databases"
users_middleware "databasus-backend/internal/features/users/middleware"
"fmt"
"io"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -170,9 +174,10 @@ func (c *BackupController) CancelBackup(ctx *gin.Context) {
// @Description Generate a token for downloading a backup file (valid for 5 minutes)
// @Tags backups
// @Param id path string true "Backup ID"
// @Success 200 {object} GenerateDownloadTokenResponse
// @Success 200 {object} backups_download.GenerateDownloadTokenResponse
// @Failure 400
// @Failure 401
// @Failure 409 {object} map[string]string "Download already in progress"
// @Router /backups/{id}/download-token [post]
func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
@@ -189,6 +194,15 @@ func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
response, err := c.backupService.GenerateDownloadToken(user, id)
if err != nil {
if err == backups_download.ErrDownloadAlreadyInProgress {
ctx.JSON(
http.StatusConflict,
gin.H{
"error": "Download already in progress for some of backups. Please wait until previous download completed or cancel it",
},
)
return
}
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
@@ -198,14 +212,22 @@ func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
// GetFile
// @Summary Download a backup file
// @Description Download the backup file for the specified backup using a download token
// @Description Download the backup file for the specified backup using a download token.
// @Description
// @Description **Download Concurrency Control:**
// @Description - Only one download per user is allowed at a time
// @Description - If a download is already in progress, returns 409 Conflict
// @Description - Downloads are tracked using cache with 5-second TTL and 3-second heartbeat
// @Description - Browser cancellations automatically release the download lock
// @Description - Server crashes are handled via automatic cache expiry (5 seconds)
// @Tags backups
// @Param id path string true "Backup ID"
// @Param token query string true "Download token"
// @Success 200 {file} file
// @Failure 400
// @Failure 401
// @Failure 500
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 409 {object} map[string]string "Download already in progress"
// @Failure 500 {object} map[string]string
// @Router /backups/{id}/file [get]
func (c *BackupController) GetFile(ctx *gin.Context) {
token := ctx.Query("token")
@@ -214,7 +236,6 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
return
}
// Get backup ID from URL
backupIDParam := ctx.Param("id")
backupID, err := uuid.Parse(backupIDParam)
if err != nil {
@@ -222,13 +243,22 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
return
}
downloadToken, err := c.backupService.ValidateDownloadToken(token)
downloadToken, rateLimiter, err := c.backupService.ValidateDownloadToken(token)
if err != nil {
if err == backups_download.ErrDownloadAlreadyInProgress {
ctx.JSON(
http.StatusConflict,
gin.H{
"error": "download already in progress for this user. Please wait until previous download completed or cancel it",
},
)
return
}
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"})
return
}
// Verify token is for the requested backup
if downloadToken.BackupID != backupID {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"})
return
@@ -238,18 +268,28 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
downloadToken.BackupID,
)
if err != nil {
c.backupService.UnregisterDownload(downloadToken.UserID)
c.backupService.ReleaseDownloadLock(downloadToken.UserID)
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
rateLimitedReader := backups_download.NewRateLimitedReader(fileReader, rateLimiter)
heartbeatCtx, cancelHeartbeat := context.WithCancel(context.Background())
defer func() {
if err := fileReader.Close(); err != nil {
cancelHeartbeat()
c.backupService.UnregisterDownload(downloadToken.UserID)
c.backupService.ReleaseDownloadLock(downloadToken.UserID)
if err := rateLimitedReader.Close(); err != nil {
fmt.Printf("Error closing file reader: %v\n", err)
}
}()
go c.startDownloadHeartbeat(heartbeatCtx, downloadToken.UserID)
filename := c.generateBackupFilename(backup, database)
// Set Content-Length for progress tracking
if backup.BackupSizeMb > 0 {
sizeBytes := int64(backup.BackupSizeMb * 1024 * 1024)
ctx.Header("Content-Length", fmt.Sprintf("%d", sizeBytes))
@@ -261,13 +301,12 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
fmt.Sprintf("attachment; filename=\"%s\"", filename),
)
_, err = io.Copy(ctx.Writer, fileReader)
_, err = io.Copy(ctx.Writer, rateLimitedReader)
if err != nil {
fmt.Printf("Error streaming file: %v\n", err)
return
}
// Write audit log after successful download
c.backupService.WriteAuditLogForDownload(downloadToken.UserID, backup, database)
}
@@ -276,7 +315,7 @@ type MakeBackupRequest struct {
}
func (c *BackupController) generateBackupFilename(
backup *Backup,
backup *backups_core.Backup,
database *databases.Database,
) string {
// Format timestamp as YYYY-MM-DD_HH-mm-ss
@@ -333,3 +372,17 @@ func sanitizeFilename(name string) string {
return string(result)
}
// startDownloadHeartbeat periodically refreshes the user's download lock for
// as long as ctx is alive, keeping the lock from expiring while a download
// streams. It returns when ctx is cancelled.
func (c *BackupController) startDownloadHeartbeat(ctx context.Context, userID uuid.UUID) {
	ticker := time.NewTicker(backups_download.GetDownloadHeartbeatInterval())
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			c.backupService.RefreshDownloadLock(userID)
		}
	}
}

View File

@@ -18,7 +18,8 @@ import (
"databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/download_token"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/postgresql"
@@ -478,7 +479,7 @@ func Test_GenerateDownloadToken_PermissionsEnforced(t *testing.T) {
)
if tt.expectSuccess {
var response GenerateDownloadTokenResponse
var response backups_download.GenerateDownloadTokenResponse
err := json.Unmarshal(testResp.Body, &response)
assert.NoError(t, err)
assert.NotEmpty(t, response.Token)
@@ -499,7 +500,7 @@ func Test_DownloadBackup_WithValidToken_Success(t *testing.T) {
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
// Generate download token
var tokenResponse GenerateDownloadTokenResponse
var tokenResponse backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
@@ -620,7 +621,7 @@ func Test_DownloadBackup_TokenUsedOnce_CannotReuseToken(t *testing.T) {
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
// Generate download token
var tokenResponse GenerateDownloadTokenResponse
var tokenResponse backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
@@ -683,7 +684,7 @@ func Test_DownloadBackup_WithDifferentBackupToken_Unauthorized(t *testing.T) {
backup2 := createTestBackup(database2, owner)
// Generate token for backup1
var tokenResponse GenerateDownloadTokenResponse
var tokenResponse backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
@@ -714,7 +715,7 @@ func Test_DownloadBackup_AuditLogWritten(t *testing.T) {
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
// Generate download token
var tokenResponse GenerateDownloadTokenResponse
var tokenResponse backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
@@ -806,7 +807,7 @@ func Test_DownloadBackup_ProperFilenameForPostgreSQL(t *testing.T) {
backup := createTestBackup(database, owner)
// Generate download token
var tokenResponse GenerateDownloadTokenResponse
var tokenResponse backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
@@ -897,22 +898,22 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
_, err = configService.SaveBackupConfig(config)
assert.NoError(t, err)
backup := &Backup{
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusInProgress,
Status: backups_core.BackupStatusInProgress,
BackupSizeMb: 0,
BackupDurationMs: 0,
CreatedAt: time.Now().UTC(),
}
repo := &BackupRepository{}
repo := &backups_core.BackupRepository{}
err = repo.Save(backup)
assert.NoError(t, err)
// Register a cancellable context for the backup
GetBackupService().backupContextManager.RegisterBackup(backup.ID, func() {})
GetBackupService().backupCancelManager.RegisterBackup(backup.ID, func() {})
resp := test_utils.MakePostRequest(
t,
@@ -949,6 +950,189 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
assert.True(t, foundCancelLog, "Cancel audit log should be created")
}
// Test_ConcurrentDownloadPrevention verifies that a second download for the
// same user is rejected with 409 Conflict while a first download is still
// streaming, and that a fresh token can be generated and used once the first
// download completes and its lock is released.
func Test_ConcurrentDownloadPrevention(t *testing.T) {
	router := createTestRouter()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
	database, backup := createTestDatabaseWithBackups(workspace, owner, router)
	// Two tokens are issued up front; concurrency is enforced at download
	// time, not token-generation time, in this scenario.
	var token1Response backups_download.GenerateDownloadTokenResponse
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
		"Bearer "+owner.Token,
		nil,
		http.StatusOK,
		&token1Response,
	)
	var token2Response backups_download.GenerateDownloadTokenResponse
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
		"Bearer "+owner.Token,
		nil,
		http.StatusOK,
		&token2Response,
	)
	downloadInProgress := make(chan bool, 1)
	downloadComplete := make(chan bool, 1)
	// Start the first download in the background so it holds the lock.
	go func() {
		test_utils.MakeGetRequest(
			t,
			router,
			fmt.Sprintf(
				"/api/v1/backups/%s/file?token=%s",
				backup.ID.String(),
				token1Response.Token,
			),
			"",
			http.StatusOK,
		)
		downloadComplete <- true
	}()
	time.Sleep(50 * time.Millisecond)
	service := GetBackupService()
	// If the first download already finished the race cannot be exercised;
	// bail out rather than assert on a lock that no longer exists.
	if !service.IsDownloadInProgress(owner.UserID) {
		t.Log("Warning: First download completed before we could test concurrency")
		<-downloadComplete
		return
	}
	downloadInProgress <- true
	// Second download while the lock is held must be rejected with 409.
	resp := test_utils.MakeGetRequest(
		t,
		router,
		fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), token2Response.Token),
		"",
		http.StatusConflict,
	)
	var errorResponse map[string]string
	err := json.Unmarshal(resp.Body, &errorResponse)
	assert.NoError(t, err)
	assert.Contains(t, errorResponse["error"], "download already in progress")
	<-downloadComplete
	<-downloadInProgress
	// Small grace period for the lock release to propagate through the cache.
	time.Sleep(100 * time.Millisecond)
	// After completion, token generation and a new download must both succeed.
	var token3Response backups_download.GenerateDownloadTokenResponse
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
		"Bearer "+owner.Token,
		nil,
		http.StatusOK,
		&token3Response,
	)
	test_utils.MakeGetRequest(
		t,
		router,
		fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), token3Response.Token),
		"",
		http.StatusOK,
	)
	t.Log("Database:", database.Name)
	t.Log(
		"Successfully prevented concurrent downloads and allowed subsequent downloads after completion",
	)
}
// Test_GenerateDownloadToken_BlockedWhenDownloadInProgress verifies that
// token generation itself returns 409 Conflict while the user has an active
// download, and succeeds again (with a distinct token) once the download
// finishes.
func Test_GenerateDownloadToken_BlockedWhenDownloadInProgress(t *testing.T) {
	router := createTestRouter()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
	database, backup := createTestDatabaseWithBackups(workspace, owner, router)
	var token1Response backups_download.GenerateDownloadTokenResponse
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
		"Bearer "+owner.Token,
		nil,
		http.StatusOK,
		&token1Response,
	)
	downloadComplete := make(chan bool, 1)
	// Hold the download lock by streaming the file in the background.
	go func() {
		test_utils.MakeGetRequest(
			t,
			router,
			fmt.Sprintf(
				"/api/v1/backups/%s/file?token=%s",
				backup.ID.String(),
				token1Response.Token,
			),
			"",
			http.StatusOK,
		)
		downloadComplete <- true
	}()
	time.Sleep(50 * time.Millisecond)
	service := GetBackupService()
	// If the download already finished there is no lock to test against.
	if !service.IsDownloadInProgress(owner.UserID) {
		t.Log("Warning: First download completed before we could test token generation blocking")
		<-downloadComplete
		return
	}
	// Token generation while the lock is held must return 409.
	resp := test_utils.MakePostRequest(
		t,
		router,
		fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
		"Bearer "+owner.Token,
		nil,
		http.StatusConflict,
	)
	var errorResponse map[string]string
	err := json.Unmarshal(resp.Body, &errorResponse)
	assert.NoError(t, err)
	assert.Contains(t, errorResponse["error"], "download already in progress")
	<-downloadComplete
	// Grace period for the lock release to propagate.
	time.Sleep(100 * time.Millisecond)
	var token2Response backups_download.GenerateDownloadTokenResponse
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
		"Bearer "+owner.Token,
		nil,
		http.StatusOK,
		&token2Response,
	)
	assert.NotEmpty(t, token2Response.Token)
	assert.NotEqual(t, token1Response.Token, token2Response.Token)
	t.Log("Database:", database.Name)
	t.Log(
		"Successfully blocked token generation during download and allowed generation after completion",
	)
}
// createTestRouter is a package-private alias for CreateTestRouter, kept so
// older tests in this file keep compiling unchanged.
func createTestRouter() *gin.Engine {
	return CreateTestRouter()
}
@@ -1038,7 +1222,7 @@ func createTestDatabaseWithBackups(
workspace *workspaces_models.Workspace,
owner *users_dto.SignInResponseDTO,
router *gin.Engine,
) (*databases.Database, *Backup) {
) (*databases.Database, *backups_core.Backup) {
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
@@ -1064,7 +1248,7 @@ func createTestDatabaseWithBackups(
func createTestBackup(
database *databases.Database,
owner *users_dto.SignInResponseDTO,
) *Backup {
) *backups_core.Backup {
userService := users_services.GetUserService()
user, err := userService.GetUserFromToken(owner.Token)
if err != nil {
@@ -1076,17 +1260,17 @@ func createTestBackup(
panic("No storage found for workspace")
}
backup := &Backup{
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storages[0].ID,
Status: BackupStatusCompleted,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10.5,
BackupDurationMs: 1000,
CreatedAt: time.Now().UTC(),
}
repo := &BackupRepository{}
repo := &backups_core.BackupRepository{}
if err := repo.Save(backup); err != nil {
panic(err)
}
@@ -1116,7 +1300,7 @@ func createExpiredDownloadToken(backupID, userID uuid.UUID) string {
}
// Manually update the token to be expired
repo := &download_token.DownloadTokenRepository{}
repo := &backups_download.DownloadTokenRepository{}
downloadToken, err := repo.FindByToken(token)
if err != nil || downloadToken == nil {
panic(fmt.Sprintf("Failed to find generated token: %v", err))
@@ -1130,3 +1314,267 @@ func createExpiredDownloadToken(backupID, userID uuid.UUID) string {
return token
}
// Test_BandwidthThrottling_SingleDownload_Uses75Percent checks that a single
// download registers with the bandwidth manager while streaming and is
// unregistered again once the download completes. (It asserts registration
// counts, not the actual 75% rate — TODO confirm rate coverage elsewhere.)
func Test_BandwidthThrottling_SingleDownload_Uses75Percent(t *testing.T) {
	router := createTestRouter()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
	_, backup := createTestDatabaseWithBackups(workspace, owner, router)
	bandwidthManager := backups_download.GetBandwidthManager()
	// Baseline count: other tests may have downloads in flight.
	initialCount := bandwidthManager.GetActiveDownloadCount()
	var tokenResponse backups_download.GenerateDownloadTokenResponse
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
		"Bearer "+owner.Token,
		nil,
		http.StatusOK,
		&tokenResponse,
	)
	downloadStarted := make(chan bool, 1)
	downloadComplete := make(chan bool, 1)
	go func() {
		test_utils.MakeGetRequest(
			t,
			router,
			fmt.Sprintf(
				"/api/v1/backups/%s/file?token=%s",
				backup.ID.String(),
				tokenResponse.Token,
			),
			"",
			http.StatusOK,
		)
		downloadComplete <- true
	}()
	time.Sleep(50 * time.Millisecond)
	// Only assert the +1 count if we managed to observe the download while it
	// was still active; otherwise the timing was missed and we skip it.
	activeCount := bandwidthManager.GetActiveDownloadCount()
	if activeCount > initialCount {
		downloadStarted <- true
		assert.Equal(t, initialCount+1, activeCount, "Should have one active download")
	}
	<-downloadComplete
	if len(downloadStarted) > 0 {
		<-downloadStarted
	}
	time.Sleep(50 * time.Millisecond)
	finalCount := bandwidthManager.GetActiveDownloadCount()
	assert.Equal(t, initialCount, finalCount, "Download should be unregistered after completion")
}
// Test_BandwidthThrottling_MultipleDownloads_ShareBandwidth runs three
// concurrent downloads by three different users (one lock per user) and
// asserts the bandwidth manager's active-download count returns to its
// baseline after all of them finish.
func Test_BandwidthThrottling_MultipleDownloads_ShareBandwidth(t *testing.T) {
	router := createTestRouter()
	owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
	owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
	owner3 := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner1, router)
	// All three users must be workspace members to download its backups.
	workspaces_testing.AddMemberToWorkspace(
		workspace,
		owner2,
		users_enums.WorkspaceRoleMember,
		owner1.Token,
		router,
	)
	workspaces_testing.AddMemberToWorkspace(
		workspace,
		owner3,
		users_enums.WorkspaceRoleMember,
		owner1.Token,
		router,
	)
	database := createTestDatabase("Test Database", workspace.ID, owner1.Token, router)
	storage := createTestStorage(workspace.ID)
	configService := backups_config.GetBackupConfigService()
	config, err := configService.GetBackupConfigByDbId(database.ID)
	assert.NoError(t, err)
	config.IsBackupsEnabled = true
	config.StorageID = &storage.ID
	config.Storage = storage
	_, err = configService.SaveBackupConfig(config)
	assert.NoError(t, err)
	// One backup per user so each download uses a distinct per-user lock.
	backup1 := createTestBackup(database, owner1)
	backup2 := createTestBackup(database, owner2)
	backup3 := createTestBackup(database, owner3)
	var token1, token2, token3 backups_download.GenerateDownloadTokenResponse
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		fmt.Sprintf("/api/v1/backups/%s/download-token", backup1.ID.String()),
		"Bearer "+owner1.Token,
		nil,
		http.StatusOK,
		&token1,
	)
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		fmt.Sprintf("/api/v1/backups/%s/download-token", backup2.ID.String()),
		"Bearer "+owner2.Token,
		nil,
		http.StatusOK,
		&token2,
	)
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		fmt.Sprintf("/api/v1/backups/%s/download-token", backup3.ID.String()),
		"Bearer "+owner3.Token,
		nil,
		http.StatusOK,
		&token3,
	)
	bandwidthManager := backups_download.GetBandwidthManager()
	initialCount := bandwidthManager.GetActiveDownloadCount()
	complete1 := make(chan bool, 1)
	complete2 := make(chan bool, 1)
	complete3 := make(chan bool, 1)
	// Kick off the three downloads concurrently.
	go func() {
		test_utils.MakeGetRequest(
			t,
			router,
			fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup1.ID.String(), token1.Token),
			"",
			http.StatusOK,
		)
		complete1 <- true
	}()
	go func() {
		test_utils.MakeGetRequest(
			t,
			router,
			fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup2.ID.String(), token2.Token),
			"",
			http.StatusOK,
		)
		complete2 <- true
	}()
	go func() {
		test_utils.MakeGetRequest(
			t,
			router,
			fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup3.ID.String(), token3.Token),
			"",
			http.StatusOK,
		)
		complete3 <- true
	}()
	time.Sleep(100 * time.Millisecond)
	// Wait for all three downloads, then give the manager time to unregister.
	<-complete1
	<-complete2
	<-complete3
	time.Sleep(100 * time.Millisecond)
	finalCount := bandwidthManager.GetActiveDownloadCount()
	assert.Equal(t, initialCount, finalCount, "All downloads should be unregistered")
}
// Test_BandwidthThrottling_DynamicAdjustment verifies that two overlapping,
// staggered backup downloads each register with the shared BandwidthManager
// and are both unregistered after completion: the active-download count must
// return to its pre-test value.
//
// NOTE(review): this test is timing-sensitive (staggered goroutines plus
// fixed sleeps); the exact statement order matters, so the body is kept as-is.
func Test_BandwidthThrottling_DynamicAdjustment(t *testing.T) {
router := createTestRouter()
// Two distinct users so each download gets its own per-user slot.
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner1, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
owner2,
users_enums.WorkspaceRoleMember,
owner1.Token,
router,
)
database := createTestDatabase("Test Database", workspace.ID, owner1.Token, router)
storage := createTestStorage(workspace.ID)
// Enable backups and attach the storage so backup files can be served.
configService := backups_config.GetBackupConfigService()
config, err := configService.GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
config.IsBackupsEnabled = true
config.StorageID = &storage.ID
config.Storage = storage
_, err = configService.SaveBackupConfig(config)
assert.NoError(t, err)
backup1 := createTestBackup(database, owner1)
backup2 := createTestBackup(database, owner2)
// Each user obtains a single-use download token via the API.
var token1, token2 backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/download-token", backup1.ID.String()),
"Bearer "+owner1.Token,
nil,
http.StatusOK,
&token1,
)
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/download-token", backup2.ID.String()),
"Bearer "+owner2.Token,
nil,
http.StatusOK,
&token2,
)
// Snapshot the active count so the final check tolerates other activity.
bandwidthManager := backups_download.GetBandwidthManager()
initialCount := bandwidthManager.GetActiveDownloadCount()
complete1 := make(chan bool, 1)
complete2 := make(chan bool, 1)
go func() {
test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup1.ID.String(), token1.Token),
"",
http.StatusOK,
)
complete1 <- true
}()
// Stagger the second download so it starts while the first is in flight,
// forcing the manager to rebalance rates dynamically.
time.Sleep(50 * time.Millisecond)
go func() {
test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup2.ID.String(), token2.Token),
"",
http.StatusOK,
)
complete2 <- true
}()
<-complete1
<-complete2
// Grace period for deferred unregister calls to run after the handlers return.
time.Sleep(100 * time.Millisecond)
finalCount := bandwidthManager.GetActiveDownloadCount()
assert.Equal(t, initialCount, finalCount, "All downloads completed and unregistered")
}

View File

@@ -1,4 +1,4 @@
package backups
package backups_core
type BackupStatus string

View File

@@ -1,4 +1,4 @@
package backups
package backups_core
import (
"context"

View File

@@ -1,4 +1,4 @@
package backups
package backups_core
import (
backups_config "databasus-backend/internal/features/backups/config"

View File

@@ -1,4 +1,4 @@
package backups
package backups_core
import (
"databasus-backend/internal/storage"

View File

@@ -1,10 +1,11 @@
package backups
import (
"time"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/download_token"
"databasus-backend/internal/features/backups/backups/backuping"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
"databasus-backend/internal/features/backups/backups/usecases"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -16,50 +17,31 @@ import (
"databasus-backend/internal/util/logger"
)
var backupRepository = &BackupRepository{}
var backupRepository = &backups_core.BackupRepository{}
var backupContextManager = NewBackupContextManager()
var backupCancelManager = backups_cancellation.GetBackupCancelManager()
var backupService = &BackupService{
databases.GetDatabaseService(),
storages.GetStorageService(),
backupRepository,
notifiers.GetNotifierService(),
notifiers.GetNotifierService(),
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
usecases.GetCreateBackupUsecase(),
logger.GetLogger(),
[]BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
backupContextManager,
download_token.GetDownloadTokenService(),
}
var backupBackgroundService = &BackupBackgroundService{
backupService,
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
time.Now().UTC(),
logger.GetLogger(),
databaseService: databases.GetDatabaseService(),
storageService: storages.GetStorageService(),
backupRepository: backupRepository,
notifierService: notifiers.GetNotifierService(),
notificationSender: notifiers.GetNotifierService(),
backupConfigService: backups_config.GetBackupConfigService(),
secretKeyService: encryption_secrets.GetSecretKeyService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
createBackupUseCase: usecases.GetCreateBackupUsecase(),
logger: logger.GetLogger(),
backupRemoveListeners: []backups_core.BackupRemoveListener{},
workspaceService: workspaces_services.GetWorkspaceService(),
auditLogService: audit_logs.GetAuditLogService(),
backupCancelManager: backupCancelManager,
downloadTokenService: backups_download.GetDownloadTokenService(),
backupSchedulerService: backuping.GetBackupsScheduler(),
}
var backupController = &BackupController{
backupService,
}
func SetupDependencies() {
backups_config.
GetBackupConfigService().
SetDatabaseStorageChangeListener(backupService)
databases.GetDatabaseService().AddDbRemoveListener(backupService)
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
backupContextManager.StartSubscription()
backupService: backupService,
}
func GetBackupService() *BackupService {
@@ -70,10 +52,11 @@ func GetBackupController() *BackupController {
return backupController
}
func GetBackupBackgroundService() *BackupBackgroundService {
return backupBackgroundService
}
func SetupDependencies() {
backups_config.
GetBackupConfigService().
SetDatabaseStorageChangeListener(backupService)
func GetDownloadTokenBackgroundService() *download_token.DownloadTokenBackgroundService {
return download_token.GetDownloadTokenBackgroundService()
databases.GetDatabaseService().AddDbRemoveListener(backupService)
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
}

View File

@@ -0,0 +1,34 @@
package backups_download
import (
"context"
"log/slog"
"time"
)
// DownloadTokenBackgroundService periodically purges expired download tokens.
// It is wired up in this package's init and driven by Run (see below), which
// is expected to be launched once as a background goroutine by the host.
type DownloadTokenBackgroundService struct {
downloadTokenService *DownloadTokenService // performs the actual token cleanup
logger *slog.Logger // structured logger for start/error events
}
// Run executes the token-cleanup loop until ctx is cancelled. Expired tokens
// are removed once per minute; cleanup failures are logged and the loop keeps
// going. The method blocks, so callers should invoke it in its own goroutine.
func (s *DownloadTokenBackgroundService) Run(ctx context.Context) {
	s.logger.Info("Starting download token cleanup background service")

	// Bail out immediately if the context was already cancelled at startup.
	select {
	case <-ctx.Done():
		return
	default:
	}

	cleanupTicker := time.NewTicker(time.Minute)
	defer cleanupTicker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-cleanupTicker.C:
			err := s.downloadTokenService.CleanExpiredTokens()
			if err != nil {
				s.logger.Error("Failed to clean expired download tokens", "error", err)
			}
		}
	}
}

View File

@@ -0,0 +1,81 @@
package backups_download
import (
"fmt"
"sync"
"github.com/google/uuid"
)
// BandwidthManager divides a fixed download-bandwidth budget evenly across all
// currently active backup downloads. Registering or unregistering a download
// triggers a rebalance that updates every live limiter's rate.
type BandwidthManager struct {
mu sync.RWMutex // guards activeDownloads and the two rate fields
activeDownloads map[uuid.UUID]*activeDownload // keyed by user: one download per user at a time
maxTotalBytesPerSecond int64 // total budget (75% of configured node throughput)
bytesPerSecondPerDownload int64 // current even share; equals the full budget when idle
}
// activeDownload ties a user to the rate limiter throttling their download.
type activeDownload struct {
userID uuid.UUID
rateLimiter *RateLimiter // updated in place when the pool is rebalanced
}
// NewBandwidthManager builds a manager whose aggregate download rate is capped
// at 75% of the node's network throughput, leaving headroom for other traffic.
//
// throughputMBs is the node's total network throughput in megabytes per
// second. Non-positive values fall back to 125 MB/s (~1 Gbit/s), the same
// default this package's init applies when the setting is absent; previously
// a negative value produced a negative cap that silently disabled throttling
// logic (each limiter then fell back to its own hard-coded default).
func NewBandwidthManager(throughputMBs int) *BandwidthManager {
	if throughputMBs <= 0 {
		throughputMBs = 125
	}
	// Use 75% of total throughput
	maxBytes := int64(throughputMBs) * 1024 * 1024 * 75 / 100
	return &BandwidthManager{
		activeDownloads:           make(map[uuid.UUID]*activeDownload),
		maxTotalBytesPerSecond:    maxBytes,
		bytesPerSecondPerDownload: maxBytes,
	}
}
// RegisterDownload reserves a bandwidth slot for userID and returns the rate
// limiter that must throttle that user's download. A user may hold at most
// one slot; a second registration fails until the first is released.
func (bm *BandwidthManager) RegisterDownload(userID uuid.UUID) (*RateLimiter, error) {
	bm.mu.Lock()
	defer bm.mu.Unlock()

	if _, alreadyActive := bm.activeDownloads[userID]; alreadyActive {
		return nil, fmt.Errorf("download already registered for user %s", userID)
	}

	limiter := NewRateLimiter(bm.bytesPerSecondPerDownload)
	bm.activeDownloads[userID] = &activeDownload{
		userID:      userID,
		rateLimiter: limiter,
	}
	bm.recalculateRates()

	return limiter, nil
}

// UnregisterDownload releases userID's slot and rebalances the remaining
// downloads. Releasing a user that holds no slot only triggers a rebalance.
func (bm *BandwidthManager) UnregisterDownload(userID uuid.UUID) {
	bm.mu.Lock()
	defer bm.mu.Unlock()

	delete(bm.activeDownloads, userID)
	bm.recalculateRates()
}

// GetActiveDownloadCount reports how many downloads currently hold a slot.
func (bm *BandwidthManager) GetActiveDownloadCount() int {
	bm.mu.RLock()
	defer bm.mu.RUnlock()

	return len(bm.activeDownloads)
}

// recalculateRates splits the total budget evenly across the active downloads
// and pushes the new share into every live limiter. With no active downloads
// the per-download rate resets to the full budget. Caller must hold bm.mu.
func (bm *BandwidthManager) recalculateRates() {
	count := int64(len(bm.activeDownloads))
	if count == 0 {
		bm.bytesPerSecondPerDownload = bm.maxTotalBytesPerSecond
		return
	}

	share := bm.maxTotalBytesPerSecond / count
	bm.bytesPerSecondPerDownload = share
	for _, d := range bm.activeDownloads {
		d.rateLimiter.UpdateRate(share)
	}
}

View File

@@ -0,0 +1,150 @@
package backups_download
import (
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
// A single download should receive the entire budget: 75% of the configured
// throughput, both before and after registration.
func Test_BandwidthManager_RegisterSingleDownload(t *testing.T) {
	manager := NewBandwidthManager(100)
	wantRate := int64(100 * 1024 * 1024 * 75 / 100)

	assert.Equal(t, wantRate, manager.maxTotalBytesPerSecond)
	assert.Equal(t, wantRate, manager.bytesPerSecondPerDownload)

	user := uuid.New()
	limiter, err := manager.RegisterDownload(user)
	assert.NoError(t, err)
	assert.NotNil(t, limiter)

	// The sole active download keeps the full budget.
	assert.Equal(t, 1, manager.GetActiveDownloadCount())
	assert.Equal(t, wantRate, manager.bytesPerSecondPerDownload)
	assert.Equal(t, wantRate, limiter.bytesPerSecond)
}
// Test_BandwidthManager_RegisterMultipleDownloads_BandwidthShared verifies
// that each new registration splits the total budget evenly across all active
// downloads and that existing limiters are updated in place.
//
// Fix: the original asserted rateLimiter2.bytesPerSecond twice in a row
// (copy-paste duplicate); the redundant assertion is removed.
func Test_BandwidthManager_RegisterMultipleDownloads_BandwidthShared(t *testing.T) {
	throughputMBs := 100
	manager := NewBandwidthManager(throughputMBs)
	maxBytes := int64(100 * 1024 * 1024 * 75 / 100)

	// First download owns the whole budget.
	user1 := uuid.New()
	rateLimiter1, err := manager.RegisterDownload(user1)
	assert.NoError(t, err)
	assert.Equal(t, maxBytes, rateLimiter1.bytesPerSecond)

	// Second download: budget is halved for both limiters.
	user2 := uuid.New()
	rateLimiter2, err := manager.RegisterDownload(user2)
	assert.NoError(t, err)
	expectedPerDownload := maxBytes / 2
	assert.Equal(t, expectedPerDownload, manager.bytesPerSecondPerDownload)
	assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
	assert.Equal(t, expectedPerDownload, rateLimiter2.bytesPerSecond)

	// Third download: budget is split three ways across all limiters.
	user3 := uuid.New()
	rateLimiter3, err := manager.RegisterDownload(user3)
	assert.NoError(t, err)
	expectedPerDownload = maxBytes / 3
	assert.Equal(t, expectedPerDownload, manager.bytesPerSecondPerDownload)
	assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
	assert.Equal(t, expectedPerDownload, rateLimiter2.bytesPerSecond)
	assert.Equal(t, expectedPerDownload, rateLimiter3.bytesPerSecond)
	assert.Equal(t, 3, manager.GetActiveDownloadCount())
}
// Removing downloads one by one must rebalance the remaining limiters until
// the last one regains the full budget, and an empty pool resets the
// per-download rate to the full budget as well.
func Test_BandwidthManager_UnregisterDownload_BandwidthRebalanced(t *testing.T) {
	manager := NewBandwidthManager(100)
	maxBytes := int64(100 * 1024 * 1024 * 75 / 100)

	user1 := uuid.New()
	rl1, _ := manager.RegisterDownload(user1)
	user2 := uuid.New()
	_, _ = manager.RegisterDownload(user2)
	user3 := uuid.New()
	rl3, _ := manager.RegisterDownload(user3)

	assert.Equal(t, 3, manager.GetActiveDownloadCount())
	share := maxBytes / 3
	assert.Equal(t, share, rl1.bytesPerSecond)

	// Drop the middle download: the two survivors split the budget.
	manager.UnregisterDownload(user2)
	assert.Equal(t, 2, manager.GetActiveDownloadCount())
	share = maxBytes / 2
	assert.Equal(t, share, manager.bytesPerSecondPerDownload)
	assert.Equal(t, share, rl1.bytesPerSecond)
	assert.Equal(t, share, rl3.bytesPerSecond)

	// Only one left: it gets everything back.
	manager.UnregisterDownload(user1)
	assert.Equal(t, 1, manager.GetActiveDownloadCount())
	assert.Equal(t, maxBytes, manager.bytesPerSecondPerDownload)
	assert.Equal(t, maxBytes, rl3.bytesPerSecond)

	// Empty pool: per-download rate resets to the full budget.
	manager.UnregisterDownload(user3)
	assert.Equal(t, 0, manager.GetActiveDownloadCount())
	assert.Equal(t, maxBytes, manager.bytesPerSecondPerDownload)
}
// A second registration for the same user must be rejected while the first
// download is still active.
func Test_BandwidthManager_RegisterDuplicateUser_ReturnsError(t *testing.T) {
	manager := NewBandwidthManager(100)
	user := uuid.New()

	_, firstErr := manager.RegisterDownload(user)
	assert.NoError(t, firstErr)

	_, secondErr := manager.RegisterDownload(user)
	assert.Error(t, secondErr)
	assert.Contains(t, secondErr.Error(), "download already registered")
}
// A fresh limiter starts with a full bucket (2x the rate), so a request for
// half a second's worth of tokens must be served without noticeable delay.
func Test_RateLimiter_TokenBucketBasic(t *testing.T) {
	rate := int64(1024 * 1024)
	limiter := NewRateLimiter(rate)

	assert.Equal(t, rate, limiter.bytesPerSecond)
	assert.Equal(t, rate*2, limiter.bucketSize)

	began := time.Now()
	limiter.Wait(512 * 1024)
	took := time.Since(began)
	assert.Less(t, took, 100*time.Millisecond)
}
// UpdateRate must replace the refill rate and resize the bucket to twice the
// new rate.
func Test_RateLimiter_UpdateRate(t *testing.T) {
	limiter := NewRateLimiter(1024 * 1024)
	assert.Equal(t, int64(1024*1024), limiter.bytesPerSecond)

	doubled := int64(2 * 1024 * 1024)
	limiter.UpdateRate(doubled)

	assert.Equal(t, doubled, limiter.bytesPerSecond)
	assert.Equal(t, doubled*2, limiter.bucketSize)
}
// Test_RateLimiter_ThrottlesCorrectly drains the bucket to zero and then
// requests half a second's worth of tokens, so Wait should block roughly
// 500ms while tokens accrue at bytesPerSec.
// NOTE(review): wall-clock assertions (400-700ms window) tolerate scheduler
// jitter but can still flake on heavily loaded CI hosts; kept byte-identical
// because the timing envelope is the behavior under test.
func Test_RateLimiter_ThrottlesCorrectly(t *testing.T) {
bytesPerSec := int64(1024 * 1024)
limiter := NewRateLimiter(bytesPerSec)
// Empty the bucket directly so Wait must refill before returning.
limiter.availableTokens = 0
start := time.Now()
limiter.Wait(bytesPerSec / 2)
elapsed := time.Since(start)
assert.GreaterOrEqual(t, elapsed, 400*time.Millisecond)
assert.LessOrEqual(t, elapsed, 700*time.Millisecond)
}

View File

@@ -0,0 +1,48 @@
package backups_download
import (
"databasus-backend/internal/config"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/logger"
)
// Package-level singletons for the download feature; wired once in init below.
var downloadTokenRepository = &DownloadTokenRepository{}
// downloadTracker enforces one active download per user via the Valkey cache.
var downloadTracker = NewDownloadTracker(cache_utils.GetValkeyClient())
var bandwidthManager *BandwidthManager
var downloadTokenService *DownloadTokenService
var downloadTokenBackgroundService *DownloadTokenBackgroundService
// init wires the download singletons from environment configuration.
// NOTE(review): only throughput == 0 is defaulted here; a negative configured
// value would pass through to NewBandwidthManager unguarded — confirm the env
// loader rejects negatives, or guard here as well.
func init() {
env := config.GetEnv()
throughputMBs := env.NodeNetworkThroughputMBs
if throughputMBs == 0 {
// Default to 125 MB/s (~1 Gbit/s) when the throughput is not configured.
throughputMBs = 125
}
bandwidthManager = NewBandwidthManager(throughputMBs)
downloadTokenService = &DownloadTokenService{
downloadTokenRepository,
logger.GetLogger(),
downloadTracker,
bandwidthManager,
}
downloadTokenBackgroundService = &DownloadTokenBackgroundService{
downloadTokenService,
logger.GetLogger(),
}
}
// GetDownloadTokenService returns the package-wide token service singleton.
func GetDownloadTokenService() *DownloadTokenService {
return downloadTokenService
}
// GetDownloadTokenBackgroundService returns the cleanup loop singleton.
func GetDownloadTokenBackgroundService() *DownloadTokenBackgroundService {
return downloadTokenBackgroundService
}
// GetBandwidthManager returns the shared download bandwidth manager.
func GetBandwidthManager() *BandwidthManager {
return bandwidthManager
}

View File

@@ -0,0 +1,9 @@
package backups_download
import "github.com/google/uuid"
// GenerateDownloadTokenResponse is the JSON payload returned by the
// download-token endpoint; the token is later passed as a query parameter
// when fetching the backup file.
type GenerateDownloadTokenResponse struct {
Token string `json:"token"` // single-use download token
Filename string `json:"filename"` // suggested filename for the backup file
BackupID uuid.UUID `json:"backupId"` // backup the token grants access to
}

View File

@@ -1,4 +1,4 @@
package download_token
package backups_download
import (
"time"

View File

@@ -0,0 +1,101 @@
package backups_download
import (
"io"
"sync"
"time"
)
type RateLimiter struct {
mu sync.Mutex
bytesPerSecond int64
bucketSize int64
availableTokens float64
lastRefill time.Time
}
func NewRateLimiter(bytesPerSecond int64) *RateLimiter {
if bytesPerSecond <= 0 {
bytesPerSecond = 1024 * 1024 * 100
}
return &RateLimiter{
bytesPerSecond: bytesPerSecond,
bucketSize: bytesPerSecond * 2,
availableTokens: float64(bytesPerSecond * 2),
lastRefill: time.Now().UTC(),
}
}
func (rl *RateLimiter) UpdateRate(bytesPerSecond int64) {
rl.mu.Lock()
defer rl.mu.Unlock()
if bytesPerSecond <= 0 {
bytesPerSecond = 1024 * 1024 * 100
}
rl.bytesPerSecond = bytesPerSecond
rl.bucketSize = bytesPerSecond * 2
if rl.availableTokens > float64(rl.bucketSize) {
rl.availableTokens = float64(rl.bucketSize)
}
}
func (rl *RateLimiter) Wait(bytes int64) {
rl.mu.Lock()
defer rl.mu.Unlock()
for {
now := time.Now().UTC()
elapsed := now.Sub(rl.lastRefill).Seconds()
tokensToAdd := elapsed * float64(rl.bytesPerSecond)
rl.availableTokens += tokensToAdd
if rl.availableTokens > float64(rl.bucketSize) {
rl.availableTokens = float64(rl.bucketSize)
}
rl.lastRefill = now
if rl.availableTokens >= float64(bytes) {
rl.availableTokens -= float64(bytes)
return
}
tokensNeeded := float64(bytes) - rl.availableTokens
waitTime := time.Duration(tokensNeeded/float64(rl.bytesPerSecond)*1000) * time.Millisecond
if waitTime < time.Millisecond {
waitTime = time.Millisecond
}
rl.mu.Unlock()
time.Sleep(waitTime)
rl.mu.Lock()
}
}
type RateLimitedReader struct {
reader io.ReadCloser
rateLimiter *RateLimiter
}
func NewRateLimitedReader(reader io.ReadCloser, limiter *RateLimiter) *RateLimitedReader {
return &RateLimitedReader{
reader: reader,
rateLimiter: limiter,
}
}
func (r *RateLimitedReader) Read(p []byte) (n int, err error) {
n, err = r.reader.Read(p)
if n > 0 {
r.rateLimiter.Wait(int64(n))
}
return n, err
}
func (r *RateLimitedReader) Close() error {
return r.reader.Close()
}

View File

@@ -1,4 +1,4 @@
package download_token
package backups_download
import (
"crypto/rand"

View File

@@ -0,0 +1,105 @@
package backups_download
import (
"errors"
"log/slog"
"time"
"github.com/google/uuid"
)
// DownloadTokenService issues and consumes single-use backup download tokens,
// coordinating the per-user download lock (Valkey) and the bandwidth manager.
type DownloadTokenService struct {
repository *DownloadTokenRepository // token persistence
logger *slog.Logger
downloadTracker *DownloadTracker // distributed one-download-per-user lock
bandwidthManager *BandwidthManager // in-process bandwidth slot accounting
}
// Generate creates a single-use token for backupID, valid for 5 minutes.
// It refuses when the user already has a download in progress, so parallel
// token generation surfaces an error instead of a second download.
// NOTE(review): the in-progress check here and the lock acquisition in
// ValidateAndConsume are separate steps — confirm the ValidateAndConsume
// lock is the authoritative guard against races between the two.
func (s *DownloadTokenService) Generate(backupID, userID uuid.UUID) (string, error) {
if s.downloadTracker.IsDownloadInProgress(userID) {
return "", ErrDownloadAlreadyInProgress
}
token := GenerateSecureToken()
downloadToken := &DownloadToken{
Token: token,
BackupID: backupID,
UserID: userID,
// Tokens are short-lived; expired ones are purged by the background service.
ExpiresAt: time.Now().UTC().Add(5 * time.Minute),
Used: false,
}
if err := s.repository.Create(downloadToken); err != nil {
return "", err
}
s.logger.Info("Generated download token", "backupId", backupID, "userId", userID)
return token, nil
}
// ValidateAndConsume checks a token (exists, unused, unexpired), acquires the
// user's download lock and a bandwidth slot, marks the token used, and returns
// the token plus the rate limiter for the download. On bandwidth registration
// failure the already-acquired download lock is rolled back. The order
// lock -> bandwidth -> mark-used is deliberate; do not reorder.
func (s *DownloadTokenService) ValidateAndConsume(
token string,
) (*DownloadToken, *RateLimiter, error) {
dt, err := s.repository.FindByToken(token)
if err != nil {
return nil, nil, err
}
if dt == nil {
return nil, nil, errors.New("invalid token")
}
if dt.Used {
return nil, nil, errors.New("token already used")
}
if time.Now().UTC().After(dt.ExpiresAt) {
return nil, nil, errors.New("token expired")
}
if err := s.downloadTracker.AcquireDownloadLock(dt.UserID); err != nil {
return nil, nil, err
}
rateLimiter, err := s.bandwidthManager.RegisterDownload(dt.UserID)
if err != nil {
// Roll back the lock so the user is not stuck unable to download.
s.downloadTracker.ReleaseDownloadLock(dt.UserID)
return nil, nil, err
}
dt.Used = true
// Best effort: the download proceeds even if persisting Used fails.
// NOTE(review): on Update failure the token stays unused in the DB and could
// be replayed after the lock is released — confirm this is acceptable.
if err := s.repository.Update(dt); err != nil {
s.logger.Error("Failed to mark token as used", "error", err)
}
s.logger.Info("Token validated and consumed", "backupId", dt.BackupID, "userId", dt.UserID)
return dt, rateLimiter, nil
}
// RefreshDownloadLock re-arms the user's download lock (heartbeat while
// streaming, so the short-TTL lock does not expire mid-download).
func (s *DownloadTokenService) RefreshDownloadLock(userID uuid.UUID) {
s.downloadTracker.RefreshDownloadLock(userID)
}
// ReleaseDownloadLock frees the user's distributed download lock.
func (s *DownloadTokenService) ReleaseDownloadLock(userID uuid.UUID) {
s.downloadTracker.ReleaseDownloadLock(userID)
s.logger.Info("Released download lock", "userId", userID)
}
// IsDownloadInProgress reports whether the user currently holds the lock.
func (s *DownloadTokenService) IsDownloadInProgress(userID uuid.UUID) bool {
return s.downloadTracker.IsDownloadInProgress(userID)
}
// UnregisterDownload releases the user's bandwidth slot (in-process only;
// the distributed lock is released separately via ReleaseDownloadLock).
func (s *DownloadTokenService) UnregisterDownload(userID uuid.UUID) {
s.bandwidthManager.UnregisterDownload(userID)
s.logger.Info("Unregistered from bandwidth manager", "userId", userID)
}
// CleanExpiredTokens deletes all tokens whose expiry is in the past; called
// periodically by the background service.
func (s *DownloadTokenService) CleanExpiredTokens() error {
now := time.Now().UTC()
if err := s.repository.DeleteExpired(now); err != nil {
return err
}
s.logger.Debug("Cleaned expired download tokens")
return nil
}

View File

@@ -0,0 +1,66 @@
package backups_download
import (
cache_utils "databasus-backend/internal/util/cache"
"errors"
"time"
"github.com/google/uuid"
"github.com/valkey-io/valkey-go"
)
const (
downloadLockPrefix = "backup_download_lock:"
// downloadLockTTL is the lock lifetime; presumably applied by cache_utils
// when the key is set — TODO confirm, it is not referenced in this file.
downloadLockTTL = 5 * time.Second
downloadLockValue = "1"
// Heartbeat interval used by download handlers; shorter than the TTL so a
// live download keeps its lock, while a crashed one expires quickly.
downloadHeartbeatDelay = 3 * time.Second
)
var (
ErrDownloadAlreadyInProgress = errors.New("download already in progress for this user")
)
// DownloadTracker is a Valkey-backed, per-user download lock shared across
// nodes in a multinode setup.
type DownloadTracker struct {
cache *cache_utils.CacheUtil[string]
}
// NewDownloadTracker builds a tracker whose keys are namespaced by
// downloadLockPrefix.
func NewDownloadTracker(client valkey.Client) *DownloadTracker {
return &DownloadTracker{
cache: cache_utils.NewCacheUtil[string](client, downloadLockPrefix),
}
}
// AcquireDownloadLock takes the user's lock, failing if it is already held.
// NOTE(review): this is a non-atomic check-then-set (Get followed by Set);
// two concurrent acquires on different nodes can both succeed. A SET NX
// operation on the cache util would make this a real mutual exclusion —
// confirm whether the current behavior is acceptable.
func (t *DownloadTracker) AcquireDownloadLock(userID uuid.UUID) error {
key := userID.String()
existingLock := t.cache.Get(key)
if existingLock != nil {
return ErrDownloadAlreadyInProgress
}
value := downloadLockValue
t.cache.Set(key, &value)
return nil
}
// RefreshDownloadLock re-sets the key, extending the lock's lifetime.
func (t *DownloadTracker) RefreshDownloadLock(userID uuid.UUID) {
key := userID.String()
value := downloadLockValue
t.cache.Set(key, &value)
}
// ReleaseDownloadLock removes the user's lock key.
func (t *DownloadTracker) ReleaseDownloadLock(userID uuid.UUID) {
key := userID.String()
t.cache.Invalidate(key)
}
// IsDownloadInProgress reports whether the user's lock key currently exists.
func (t *DownloadTracker) IsDownloadInProgress(userID uuid.UUID) bool {
key := userID.String()
existingLock := t.cache.Get(key)
return existingLock != nil
}
// GetDownloadHeartbeatInterval exposes the heartbeat cadence to callers that
// refresh the lock while streaming.
func GetDownloadHeartbeatInterval() time.Duration {
return downloadHeartbeatDelay
}

View File

@@ -1,32 +0,0 @@
package download_token
import (
"databasus-backend/internal/config"
"log/slog"
"time"
)
// Pre-refactor background cleanup service (this file is deleted by the
// commit; the replacement is context-driven). Kept documented for the diff.
type DownloadTokenBackgroundService struct {
downloadTokenService *DownloadTokenService
logger *slog.Logger
}
// Run loops forever, cleaning expired tokens once per minute, and exits when
// the global shutdown flag is set. Because it sleeps a full minute between
// shutdown checks, shutdown latency can be up to one minute.
func (s *DownloadTokenBackgroundService) Run() {
s.logger.Info("Starting download token cleanup background service")
if config.IsShouldShutdown() {
return
}
for {
if config.IsShouldShutdown() {
return
}
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
s.logger.Error("Failed to clean expired download tokens", "error", err)
}
time.Sleep(1 * time.Minute)
}
}

View File

@@ -1,25 +0,0 @@
package download_token
import (
"databasus-backend/internal/util/logger"
)
// Pre-refactor wiring for the download_token package (file deleted by the
// commit; superseded by init-based wiring in backups_download).
var downloadTokenRepository = &DownloadTokenRepository{}
// Positional construction: repository, then logger.
var downloadTokenService = &DownloadTokenService{
downloadTokenRepository,
logger.GetLogger(),
}
var downloadTokenBackgroundService = &DownloadTokenBackgroundService{
downloadTokenService,
logger.GetLogger(),
}
// GetDownloadTokenService returns the package-wide token service singleton.
func GetDownloadTokenService() *DownloadTokenService {
return downloadTokenService
}
// GetDownloadTokenBackgroundService returns the cleanup service singleton.
func GetDownloadTokenBackgroundService() *DownloadTokenBackgroundService {
return downloadTokenBackgroundService
}

View File

@@ -1,69 +0,0 @@
package download_token
import (
"errors"
"log/slog"
"time"
"github.com/google/uuid"
)
// Pre-refactor token service (file deleted by the commit). It only persisted
// and validated tokens; it had no download lock or bandwidth accounting, so
// nothing prevented parallel downloads by one user.
type DownloadTokenService struct {
repository *DownloadTokenRepository
logger *slog.Logger
}
// Generate creates a single-use download token valid for 5 minutes. Unlike
// the replacement, it does not reject users with a download in progress.
func (s *DownloadTokenService) Generate(backupID, userID uuid.UUID) (string, error) {
token := GenerateSecureToken()
downloadToken := &DownloadToken{
Token: token,
BackupID: backupID,
UserID: userID,
ExpiresAt: time.Now().UTC().Add(5 * time.Minute),
Used: false,
}
if err := s.repository.Create(downloadToken); err != nil {
return "", err
}
s.logger.Info("Generated download token", "backupId", backupID, "userId", userID)
return token, nil
}
// ValidateAndConsume checks the token (exists, unused, unexpired) and marks
// it used. Persisting the Used flag is best effort: a failed Update is only
// logged and the download proceeds.
func (s *DownloadTokenService) ValidateAndConsume(token string) (*DownloadToken, error) {
dt, err := s.repository.FindByToken(token)
if err != nil {
return nil, err
}
if dt == nil {
return nil, errors.New("invalid token")
}
if dt.Used {
return nil, errors.New("token already used")
}
if time.Now().UTC().After(dt.ExpiresAt) {
return nil, errors.New("token expired")
}
dt.Used = true
if err := s.repository.Update(dt); err != nil {
s.logger.Error("Failed to mark token as used", "error", err)
}
s.logger.Info("Token validated and consumed", "backupId", dt.BackupID)
return dt, nil
}
// CleanExpiredTokens deletes all tokens whose expiry is in the past.
func (s *DownloadTokenService) CleanExpiredTokens() error {
now := time.Now().UTC()
if err := s.repository.DeleteExpired(now); err != nil {
return err
}
s.logger.Debug("Cleaned expired download tokens")
return nil
}

View File

@@ -1,10 +1,9 @@
package backups
import (
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/encryption"
"io"
"github.com/google/uuid"
)
type GetBackupsRequest struct {
@@ -14,23 +13,17 @@ type GetBackupsRequest struct {
}
type GetBackupsResponse struct {
Backups []*Backup `json:"backups"`
Total int64 `json:"total"`
Limit int `json:"limit"`
Offset int `json:"offset"`
Backups []*backups_core.Backup `json:"backups"`
Total int64 `json:"total"`
Limit int `json:"limit"`
Offset int `json:"offset"`
}
type GenerateDownloadTokenResponse struct {
Token string `json:"token"`
Filename string `json:"filename"`
BackupID uuid.UUID `json:"backupId"`
}
type decryptionReaderCloser struct {
type DecryptionReaderCloser struct {
*encryption.DecryptionReader
baseReader io.ReadCloser
BaseReader io.ReadCloser
}
func (r *decryptionReaderCloser) Close() error {
return r.baseReader.Close()
func (r *DecryptionReaderCloser) Close() error {
return r.BaseReader.Close()
}

View File

@@ -1,18 +1,17 @@
package backups
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"log/slog"
"slices"
"strings"
"time"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/download_token"
"databasus-backend/internal/features/backups/backups/backuping"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
"databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -29,26 +28,27 @@ import (
type BackupService struct {
databaseService *databases.DatabaseService
storageService *storages.StorageService
backupRepository *BackupRepository
backupRepository *backups_core.BackupRepository
notifierService *notifiers.NotifierService
notificationSender NotificationSender
notificationSender backups_core.NotificationSender
backupConfigService *backups_config.BackupConfigService
secretKeyService *encryption_secrets.SecretKeyService
fieldEncryptor util_encryption.FieldEncryptor
createBackupUseCase CreateBackupUsecase
createBackupUseCase backups_core.CreateBackupUsecase
logger *slog.Logger
backupRemoveListeners []BackupRemoveListener
backupRemoveListeners []backups_core.BackupRemoveListener
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
backupContextManager *BackupContextManager
downloadTokenService *download_token.DownloadTokenService
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
backupCancelManager *backups_cancellation.BackupCancelManager
downloadTokenService *backups_download.DownloadTokenService
backupSchedulerService *backuping.BackupsScheduler
}
func (s *BackupService) AddBackupRemoveListener(listener BackupRemoveListener) {
func (s *BackupService) AddBackupRemoveListener(listener backups_core.BackupRemoveListener) {
s.backupRemoveListeners = append(s.backupRemoveListeners, listener)
}
@@ -91,7 +91,7 @@ func (s *BackupService) MakeBackupWithAuth(
return errors.New("insufficient permissions to create backup for this database")
}
go s.MakeBackup(databaseID, true)
s.backupSchedulerService.StartBackup(databaseID, true)
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Backup manually initiated for database: %s", database.Name),
@@ -175,7 +175,7 @@ func (s *BackupService) DeleteBackup(
return errors.New("insufficient permissions to delete backup for this database")
}
if backup.Status == BackupStatusInProgress {
if backup.Status == backups_core.BackupStatusInProgress {
return errors.New("backup is in progress")
}
@@ -192,260 +192,7 @@ func (s *BackupService) DeleteBackup(
return s.deleteBackup(backup)
}
func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
database, err := s.databaseService.GetDatabaseByID(databaseID)
if err != nil {
s.logger.Error("Failed to get database by ID", "error", err)
return
}
lastBackup, err := s.backupRepository.FindLastByDatabaseID(databaseID)
if err != nil {
s.logger.Error("Failed to find last backup by database ID", "error", err)
return
}
if lastBackup != nil && lastBackup.Status == BackupStatusInProgress {
s.logger.Error("Backup is in progress")
return
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(databaseID)
if err != nil {
s.logger.Error("Failed to get backup config by database ID", "error", err)
return
}
if backupConfig.StorageID == nil {
s.logger.Error("Backup config storage ID is not defined")
return
}
storage, err := s.storageService.GetStorageByID(*backupConfig.StorageID)
if err != nil {
s.logger.Error("Failed to get storage by ID", "error", err)
return
}
backup := &Backup{
DatabaseID: databaseID,
StorageID: storage.ID,
Status: BackupStatusInProgress,
BackupSizeMb: 0,
CreatedAt: time.Now().UTC(),
}
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("Failed to save backup", "error", err)
return
}
start := time.Now().UTC()
backupProgressListener := func(
completedMBs float64,
) {
backup.BackupSizeMb = completedMBs
backup.BackupDurationMs = time.Since(start).Milliseconds()
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("Failed to update backup progress", "error", err)
}
}
ctx, cancel := context.WithCancel(context.Background())
s.backupContextManager.RegisterBackup(backup.ID, cancel)
defer s.backupContextManager.UnregisterBackup(backup.ID)
backupMetadata, err := s.createBackupUseCase.Execute(
ctx,
backup.ID,
backupConfig,
database,
storage,
backupProgressListener,
)
if err != nil {
errMsg := err.Error()
// Check if backup was cancelled (not due to shutdown)
isCancelled := strings.Contains(errMsg, "backup cancelled") ||
strings.Contains(errMsg, "context canceled") ||
errors.Is(err, context.Canceled)
isShutdown := strings.Contains(errMsg, "shutdown")
if isCancelled && !isShutdown {
backup.Status = BackupStatusCanceled
backup.BackupDurationMs = time.Since(start).Milliseconds()
backup.BackupSizeMb = 0
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("Failed to save cancelled backup", "error", err)
}
// Delete partial backup from storage
storage, storageErr := s.storageService.GetStorageByID(backup.StorageID)
if storageErr == nil {
if deleteErr := storage.DeleteFile(s.fieldEncryptor, backup.ID); deleteErr != nil {
s.logger.Error(
"Failed to delete partial backup file",
"backupId",
backup.ID,
"error",
deleteErr,
)
}
}
return
}
backup.FailMessage = &errMsg
backup.Status = BackupStatusFailed
backup.BackupDurationMs = time.Since(start).Milliseconds()
backup.BackupSizeMb = 0
if updateErr := s.databaseService.SetBackupError(databaseID, errMsg); updateErr != nil {
s.logger.Error(
"Failed to update database last backup time",
"databaseId",
databaseID,
"error",
updateErr,
)
}
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("Failed to save backup", "error", err)
}
s.SendBackupNotification(
backupConfig,
backup,
backups_config.NotificationBackupFailed,
&errMsg,
)
return
}
backup.Status = BackupStatusCompleted
backup.BackupDurationMs = time.Since(start).Milliseconds()
// Update backup with encryption metadata if provided
if backupMetadata != nil {
backup.EncryptionSalt = backupMetadata.EncryptionSalt
backup.EncryptionIV = backupMetadata.EncryptionIV
backup.Encryption = backupMetadata.Encryption
}
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("Failed to save backup", "error", err)
return
}
// Update database last backup time
now := time.Now().UTC()
if updateErr := s.databaseService.SetLastBackupTime(databaseID, now); updateErr != nil {
s.logger.Error(
"Failed to update database last backup time",
"databaseId",
databaseID,
"error",
updateErr,
)
}
if backup.Status != BackupStatusCompleted && !isLastTry {
return
}
s.SendBackupNotification(
backupConfig,
backup,
backups_config.NotificationBackupSuccess,
nil,
)
}
// SendBackupNotification delivers a backup status notification (success or
// failure) to every notifier attached to the backup's database, provided the
// backup config has opted in to this notification type.
//
// errorMessage, when non-nil, is used verbatim as the notification body;
// otherwise a summary with the backup size and duration is generated.
// Lookup failures for the database or workspace are silently ignored — there
// is nobody to notify about a notification failure.
func (s *BackupService) SendBackupNotification(
	backupConfig *backups_config.BackupConfig,
	backup *Backup,
	notificationType backups_config.BackupNotificationType,
	errorMessage *string,
) {
	// Whether this notification type is delivered at all is decided by the
	// config, not by the individual notifier — check once up front instead
	// of re-evaluating inside the notifier loop.
	if !slices.Contains(backupConfig.SendNotificationsOn, notificationType) {
		return
	}
	database, err := s.databaseService.GetDatabaseByID(backupConfig.DatabaseID)
	if err != nil {
		return
	}
	workspace, err := s.workspaceService.GetWorkspaceByID(*database.WorkspaceID)
	if err != nil {
		return
	}
	// Title and message are identical for every notifier — build them once.
	title := ""
	switch notificationType {
	case backups_config.NotificationBackupFailed:
		title = fmt.Sprintf(
			"❌ Backup failed for database \"%s\" (workspace \"%s\")",
			database.Name,
			workspace.Name,
		)
	case backups_config.NotificationBackupSuccess:
		title = fmt.Sprintf(
			"✅ Backup completed for database \"%s\" (workspace \"%s\")",
			database.Name,
			workspace.Name,
		)
	}
	message := ""
	if errorMessage != nil {
		message = *errorMessage
	} else {
		// Show MB below 1024 MB, GB above, always with two decimals.
		var sizeStr string
		if backup.BackupSizeMb < 1024 {
			sizeStr = fmt.Sprintf("%.2f MB", backup.BackupSizeMb)
		} else {
			sizeGB := backup.BackupSizeMb / 1024
			sizeStr = fmt.Sprintf("%.2f GB", sizeGB)
		}
		// Format duration as "Xm Ys"; sub-second remainder is dropped.
		totalMs := backup.BackupDurationMs
		minutes := totalMs / (1000 * 60)
		seconds := (totalMs % (1000 * 60)) / 1000
		durationStr := fmt.Sprintf("%dm %ds", minutes, seconds)
		message = fmt.Sprintf(
			"Backup completed successfully in %s.\nCompressed backup size: %s",
			durationStr,
			sizeStr,
		)
	}
	for _, notifier := range database.Notifiers {
		// SendNotification is called synchronously, so taking the address
		// of the loop variable here is safe on all Go versions.
		s.notificationSender.SendNotification(
			&notifier,
			title,
			message,
		)
	}
}
func (s *BackupService) GetBackup(backupID uuid.UUID) (*Backup, error) {
func (s *BackupService) GetBackup(backupID uuid.UUID) (*backups_core.Backup, error) {
return s.backupRepository.FindByID(backupID)
}
@@ -475,11 +222,11 @@ func (s *BackupService) CancelBackup(
return errors.New("insufficient permissions to cancel backup for this database")
}
if backup.Status != BackupStatusInProgress {
if backup.Status != backups_core.BackupStatusInProgress {
return errors.New("backup is not in progress")
}
if err := s.backupContextManager.CancelBackup(backupID); err != nil {
if err := s.backupCancelManager.CancelBackup(backupID); err != nil {
return err
}
@@ -499,7 +246,7 @@ func (s *BackupService) CancelBackup(
func (s *BackupService) GetBackupFile(
user *users_models.User,
backupID uuid.UUID,
) (io.ReadCloser, *Backup, *databases.Database, error) {
) (io.ReadCloser, *backups_core.Backup, *databases.Database, error) {
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
return nil, nil, nil, err
@@ -545,7 +292,7 @@ func (s *BackupService) GetBackupFile(
return reader, backup, database, nil
}
func (s *BackupService) deleteBackup(backup *Backup) error {
func (s *BackupService) deleteBackup(backup *backups_core.Backup) error {
for _, listener := range s.backupRemoveListeners {
if err := listener.OnBeforeBackupRemove(backup); err != nil {
return err
@@ -571,7 +318,7 @@ func (s *BackupService) deleteBackup(backup *Backup) error {
func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
dbBackupsInProgress, err := s.backupRepository.FindByDatabaseIdAndStatus(
databaseID,
BackupStatusInProgress,
backups_core.BackupStatusInProgress,
)
if err != nil {
return err
@@ -680,16 +427,16 @@ func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, erro
s.logger.Info("Returning encrypted backup with decryption", "backupId", backupID)
return &decryptionReaderCloser{
decryptionReader,
fileReader,
return &DecryptionReaderCloser{
DecryptionReader: decryptionReader,
BaseReader: fileReader,
}, nil
}
func (s *BackupService) GenerateDownloadToken(
user *users_models.User,
backupID uuid.UUID,
) (*GenerateDownloadTokenResponse, error) {
) (*backups_download.GenerateDownloadTokenResponse, error) {
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
return nil, err
@@ -725,20 +472,22 @@ func (s *BackupService) GenerateDownloadToken(
database.WorkspaceID,
)
return &GenerateDownloadTokenResponse{
return &backups_download.GenerateDownloadTokenResponse{
Token: token,
Filename: filename,
BackupID: backupID,
}, nil
}
func (s *BackupService) ValidateDownloadToken(token string) (*download_token.DownloadToken, error) {
func (s *BackupService) ValidateDownloadToken(
token string,
) (*backups_download.DownloadToken, *backups_download.RateLimiter, error) {
return s.downloadTokenService.ValidateAndConsume(token)
}
func (s *BackupService) GetBackupFileWithoutAuth(
backupID uuid.UUID,
) (io.ReadCloser, *Backup, *databases.Database, error) {
) (io.ReadCloser, *backups_core.Backup, *databases.Database, error) {
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
return nil, nil, nil, err
@@ -759,7 +508,7 @@ func (s *BackupService) GetBackupFileWithoutAuth(
func (s *BackupService) WriteAuditLogForDownload(
userID uuid.UUID,
backup *Backup,
backup *backups_core.Backup,
database *databases.Database,
) {
s.auditLogService.WriteAuditLog(
@@ -773,8 +522,24 @@ func (s *BackupService) WriteAuditLogForDownload(
)
}
// RefreshDownloadLock extends the download lock held by the given user,
// delegating to the download token service.
func (s *BackupService) RefreshDownloadLock(userID uuid.UUID) {
	s.downloadTokenService.RefreshDownloadLock(userID)
}

// ReleaseDownloadLock releases the given user's download lock, delegating to
// the download token service.
func (s *BackupService) ReleaseDownloadLock(userID uuid.UUID) {
	s.downloadTokenService.ReleaseDownloadLock(userID)
}

// IsDownloadInProgress reports whether the download token service currently
// tracks an in-flight backup download for the given user.
func (s *BackupService) IsDownloadInProgress(userID uuid.UUID) bool {
	return s.downloadTokenService.IsDownloadInProgress(userID)
}

// UnregisterDownload removes the given user's download registration from the
// download token service.
func (s *BackupService) UnregisterDownload(userID uuid.UUID) {
	s.downloadTokenService.UnregisterDownload(userID)
}
func (s *BackupService) generateBackupFilename(
backup *Backup,
backup *backups_core.Backup,
database *databases.Database,
) string {
timestamp := backup.CreatedAt.Format("2006-01-02_15-04-05")

View File

@@ -4,6 +4,7 @@ import (
"testing"
"time"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
@@ -58,9 +59,9 @@ func WaitForBackupCompletion(
newestBackup := backups[0]
t.Logf("WaitForBackupCompletion: newest backup status: %s", newestBackup.Status)
if newestBackup.Status == BackupStatusCompleted ||
newestBackup.Status == BackupStatusFailed ||
newestBackup.Status == BackupStatusCanceled {
if newestBackup.Status == backups_core.BackupStatusCompleted ||
newestBackup.Status == backups_core.BackupStatusFailed ||
newestBackup.Status == backups_core.BackupStatusCanceled {
t.Logf(
"WaitForBackupCompletion: backup finished with status %s",
newestBackup.Status,

View File

@@ -515,11 +515,13 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
hasProcess := false
hasAllPrivileges := false
escapedDB := strings.ReplaceAll(database, "_", "\\_")
dbPattern := regexp.MustCompile(
fmt.Sprintf("(?i)ON\\s+[`'\"]?(%s|\\*)[`'\"]?\\.\\*", regexp.QuoteMeta(escapedDB)),
dbPatternStr := fmt.Sprintf(
`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
regexp.QuoteMeta(database),
)
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\.\*`)
dbPattern := regexp.MustCompile(dbPatternStr)
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)
allPrivilegesPattern := regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`)
for rows.Next() {
var grant string
@@ -527,23 +529,26 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
return "", fmt.Errorf("failed to scan grant: %w", err)
}
if regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`).MatchString(grant) {
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
hasAllPrivileges = true
}
isRelevantGrant := globalPattern.MatchString(grant) || dbPattern.MatchString(grant)
if allPrivilegesPattern.MatchString(grant) && isRelevantGrant {
hasAllPrivileges = true
}
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
if isRelevantGrant {
for _, priv := range backupPrivileges {
if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
privPattern := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(priv) + `\b`)
if privPattern.MatchString(grant) {
detectedPrivileges[priv] = true
}
}
}
if globalPattern.MatchString(grant) &&
regexp.MustCompile(`(?i)\bPROCESS\b`).MatchString(grant) {
hasProcess = true
if globalPattern.MatchString(grant) {
processPattern := regexp.MustCompile(`(?i)\bPROCESS\b`)
if processPattern.MatchString(grant) {
hasProcess = true
}
}
}

View File

@@ -537,6 +537,163 @@ func Test_ReadOnlyUser_CannotDropOrAlterTables(t *testing.T) {
dropUserSafe(container.DB, username)
}
// Verifies that TestConnection succeeds for a user holding only
// database-scoped SELECT/SHOW VIEW privileges plus a global PROCESS grant,
// across all supported MariaDB versions. This is the minimal privilege set
// the grant-detection logic must accept.
func Test_TestConnection_DatabaseSpecificPrivilegesWithGlobalProcess_Success(t *testing.T) {
	env := config.GetEnv()

	// One sub-test per supported MariaDB version; ports come from test env.
	cases := []struct {
		name    string
		version tools.MariadbVersion
		port    string
	}{
		{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
		{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
		{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
		{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
		{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
		{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
		{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
		{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
		{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
		{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
		{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			container := connectToMariadbContainer(t, tc.port, tc.version)
			defer container.DB.Close()

			// Create a table with data so the connection check has
			// something it must be able to SELECT from.
			_, err := container.DB.Exec(`DROP TABLE IF EXISTS privilege_test`)
			assert.NoError(t, err)
			_, err = container.DB.Exec(`CREATE TABLE privilege_test (
				id INT AUTO_INCREMENT PRIMARY KEY,
				data VARCHAR(255) NOT NULL
			)`)
			assert.NoError(t, err)
			_, err = container.DB.Exec(`INSERT INTO privilege_test (data) VALUES ('test1')`)
			assert.NoError(t, err)

			// Test-only Sprintf SQL: identifiers are generated, not user input.
			specificUsername := fmt.Sprintf("spec_%s", uuid.New().String()[:8])
			specificPassword := "specificpass123"
			_, err = container.DB.Exec(fmt.Sprintf(
				"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
				specificUsername,
				specificPassword,
			))
			assert.NoError(t, err)
			// Database-scoped read privileges only…
			_, err = container.DB.Exec(fmt.Sprintf(
				"GRANT SELECT, SHOW VIEW ON %s.* TO '%s'@'%%'",
				container.Database,
				specificUsername,
			))
			assert.NoError(t, err)
			// …plus PROCESS, which MariaDB only allows globally (*.*).
			_, err = container.DB.Exec(fmt.Sprintf(
				"GRANT PROCESS ON *.* TO '%s'@'%%'",
				specificUsername,
			))
			assert.NoError(t, err)
			_, err = container.DB.Exec("FLUSH PRIVILEGES")
			assert.NoError(t, err)
			defer dropUserSafe(container.DB, specificUsername)

			// Connect as the restricted user and expect success.
			mariadbModel := &MariadbDatabase{
				Version:  tc.version,
				Host:     container.Host,
				Port:     container.Port,
				Username: specificUsername,
				Password: specificPassword,
				Database: &container.Database,
				IsHttps:  false,
			}
			logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
			err = mariadbModel.TestConnection(logger, nil, uuid.New())
			assert.NoError(t, err)
		})
	}
}
// Verifies that TestConnection succeeds when the database name contains
// underscores. Underscore is a LIKE wildcard in MySQL/MariaDB grant
// patterns, so this guards the grant-matching regex against treating it
// specially.
func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
	env := config.GetEnv()
	container := connectToMariadbContainer(t, env.TestMariadb1011Port, tools.MariadbVersion1011)
	defer container.DB.Close()

	// Create a dedicated database whose name contains underscores.
	underscoreDbName := "test_db_name"
	_, err := container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
	assert.NoError(t, err)
	_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
	assert.NoError(t, err)
	defer func() {
		_, _ = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
	}()

	// Open a second connection scoped to the new database.
	underscoreDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
		container.Username, container.Password, container.Host, container.Port, underscoreDbName)
	underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
	assert.NoError(t, err)
	defer underscoreDB.Close()

	// Give the SELECT-privilege check a non-empty table to inspect.
	_, err = underscoreDB.Exec(`
		CREATE TABLE underscore_test (
			id INT AUTO_INCREMENT PRIMARY KEY,
			data VARCHAR(255) NOT NULL
		)
	`)
	assert.NoError(t, err)
	_, err = underscoreDB.Exec(`INSERT INTO underscore_test (data) VALUES ('test1')`)
	assert.NoError(t, err)

	// Test-only Sprintf SQL: identifiers are generated, not user input.
	underscoreUsername := fmt.Sprintf("under%s", uuid.New().String()[:8])
	underscorePassword := "underscorepass123"
	_, err = underscoreDB.Exec(fmt.Sprintf(
		"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
		underscoreUsername,
		underscorePassword,
	))
	assert.NoError(t, err)
	// Grant read access scoped to the underscore-named database only.
	_, err = underscoreDB.Exec(fmt.Sprintf(
		"GRANT SELECT, SHOW VIEW ON `%s`.* TO '%s'@'%%'",
		underscoreDbName,
		underscoreUsername,
	))
	assert.NoError(t, err)
	_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
	assert.NoError(t, err)
	defer dropUserSafe(underscoreDB, underscoreUsername)

	// Connect as the restricted user and expect success.
	mariadbModel := &MariadbDatabase{
		Version:  tools.MariadbVersion1011,
		Host:     container.Host,
		Port:     container.Port,
		Username: underscoreUsername,
		Password: underscorePassword,
		Database: &underscoreDbName,
		IsHttps:  false,
	}
	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
	err = mariadbModel.TestConnection(logger, nil, uuid.New())
	assert.NoError(t, err)
}
type MariadbContainer struct {
Host string
Port int

View File

@@ -486,11 +486,13 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
hasProcess := false
hasAllPrivileges := false
escapedDB := strings.ReplaceAll(database, "_", "\\_")
dbPattern := regexp.MustCompile(
fmt.Sprintf("(?i)ON\\s+[`'\"]?(%s|\\*)[`'\"]?\\.\\*", regexp.QuoteMeta(escapedDB)),
dbPatternStr := fmt.Sprintf(
`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
regexp.QuoteMeta(database),
)
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\.\*`)
dbPattern := regexp.MustCompile(dbPatternStr)
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)
allPrivilegesPattern := regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`)
for rows.Next() {
var grant string
@@ -498,23 +500,26 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
return "", fmt.Errorf("failed to scan grant: %w", err)
}
if regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`).MatchString(grant) {
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
hasAllPrivileges = true
}
isRelevantGrant := globalPattern.MatchString(grant) || dbPattern.MatchString(grant)
if allPrivilegesPattern.MatchString(grant) && isRelevantGrant {
hasAllPrivileges = true
}
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
if isRelevantGrant {
for _, priv := range backupPrivileges {
if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
privPattern := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(priv) + `\b`)
if privPattern.MatchString(grant) {
detectedPrivileges[priv] = true
}
}
}
if globalPattern.MatchString(grant) &&
regexp.MustCompile(`(?i)\bPROCESS\b`).MatchString(grant) {
hasProcess = true
if globalPattern.MatchString(grant) {
processPattern := regexp.MustCompile(`(?i)\bPROCESS\b`)
if processPattern.MatchString(grant) {
hasProcess = true
}
}
}

View File

@@ -518,6 +518,162 @@ func Test_ReadOnlyUser_CannotDropOrAlterTables(t *testing.T) {
assert.NoError(t, err)
}
// Verifies that TestConnection succeeds for a user holding only
// database-scoped SELECT/SHOW VIEW privileges plus a global PROCESS grant,
// across all supported MySQL versions. This is the minimal privilege set
// the grant-detection logic must accept.
func Test_TestConnection_DatabaseSpecificPrivilegesWithGlobalProcess_Success(t *testing.T) {
	env := config.GetEnv()

	// One sub-test per supported MySQL version; ports come from test env.
	cases := []struct {
		name    string
		version tools.MysqlVersion
		port    string
	}{
		{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
		{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
		{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
		{"MySQL 9", tools.MysqlVersion9, env.TestMysql90Port},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			container := connectToMysqlContainer(t, tc.port, tc.version)
			defer container.DB.Close()

			// Create a table with data so the connection check has
			// something it must be able to SELECT from.
			_, err := container.DB.Exec(`DROP TABLE IF EXISTS privilege_test`)
			assert.NoError(t, err)
			_, err = container.DB.Exec(`CREATE TABLE privilege_test (
				id INT AUTO_INCREMENT PRIMARY KEY,
				data VARCHAR(255) NOT NULL
			)`)
			assert.NoError(t, err)
			_, err = container.DB.Exec(`INSERT INTO privilege_test (data) VALUES ('test1')`)
			assert.NoError(t, err)

			// Test-only Sprintf SQL: identifiers are generated, not user input.
			specificUsername := fmt.Sprintf("specific_%s", uuid.New().String()[:8])
			specificPassword := "specificpass123"
			_, err = container.DB.Exec(fmt.Sprintf(
				"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
				specificUsername,
				specificPassword,
			))
			assert.NoError(t, err)
			// Database-scoped read privileges only…
			_, err = container.DB.Exec(fmt.Sprintf(
				"GRANT SELECT, SHOW VIEW ON %s.* TO '%s'@'%%'",
				container.Database,
				specificUsername,
			))
			assert.NoError(t, err)
			// …plus PROCESS, which MySQL only allows globally (*.*).
			_, err = container.DB.Exec(fmt.Sprintf(
				"GRANT PROCESS ON *.* TO '%s'@'%%'",
				specificUsername,
			))
			assert.NoError(t, err)
			_, err = container.DB.Exec("FLUSH PRIVILEGES")
			assert.NoError(t, err)
			defer func() {
				_, _ = container.DB.Exec(
					fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", specificUsername),
				)
			}()

			// Connect as the restricted user and expect success.
			mysqlModel := &MysqlDatabase{
				Version:  tc.version,
				Host:     container.Host,
				Port:     container.Port,
				Username: specificUsername,
				Password: specificPassword,
				Database: &container.Database,
				IsHttps:  false,
			}
			logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
			err = mysqlModel.TestConnection(logger, nil, uuid.New())
			assert.NoError(t, err)
		})
	}
}
// Verifies that TestConnection succeeds when the database name contains
// underscores. Underscore is a LIKE wildcard in MySQL grant patterns, so
// this guards the grant-matching regex against treating it specially.
func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
	env := config.GetEnv()
	container := connectToMysqlContainer(t, env.TestMysql80Port, tools.MysqlVersion80)
	defer container.DB.Close()

	// Create a dedicated database whose name contains underscores.
	underscoreDbName := "test_db_name"
	_, err := container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
	assert.NoError(t, err)
	_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
	assert.NoError(t, err)
	defer func() {
		_, _ = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
	}()

	// Open a second connection scoped to the new database.
	underscoreDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
		container.Username, container.Password, container.Host, container.Port, underscoreDbName)
	underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
	assert.NoError(t, err)
	defer underscoreDB.Close()

	// Give the SELECT-privilege check a non-empty table to inspect.
	_, err = underscoreDB.Exec(`
		CREATE TABLE underscore_test (
			id INT AUTO_INCREMENT PRIMARY KEY,
			data VARCHAR(255) NOT NULL
		)
	`)
	assert.NoError(t, err)
	_, err = underscoreDB.Exec(`INSERT INTO underscore_test (data) VALUES ('test1')`)
	assert.NoError(t, err)

	// Test-only Sprintf SQL: identifiers are generated, not user input.
	underscoreUsername := fmt.Sprintf("under_%s", uuid.New().String()[:8])
	underscorePassword := "underscorepass123"
	_, err = underscoreDB.Exec(fmt.Sprintf(
		"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
		underscoreUsername,
		underscorePassword,
	))
	assert.NoError(t, err)
	// Grant read access scoped to the underscore-named database only.
	_, err = underscoreDB.Exec(fmt.Sprintf(
		"GRANT SELECT, SHOW VIEW ON `%s`.* TO '%s'@'%%'",
		underscoreDbName,
		underscoreUsername,
	))
	assert.NoError(t, err)
	_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
	assert.NoError(t, err)
	defer func() {
		_, _ = underscoreDB.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", underscoreUsername))
	}()

	// Connect as the restricted user and expect success.
	mysqlModel := &MysqlDatabase{
		Version:  tools.MysqlVersion80,
		Host:     container.Host,
		Port:     container.Port,
		Username: underscoreUsername,
		Password: underscorePassword,
		Database: &underscoreDbName,
		IsHttps:  false,
	}
	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
	err = mysqlModel.TestConnection(logger, nil, uuid.New())
	assert.NoError(t, err)
}
type MysqlContainer struct {
Host string
Port int

View File

@@ -761,12 +761,15 @@ func checkBackupPermissions(
// Check SELECT privilege on at least one table (if tables exist)
// Use pg_tables from pg_catalog which shows all tables regardless of user privileges
var tableCount int
if len(includeSchemas) > 0 {
// Check only tables in the specified schemas
err = conn.QueryRow(ctx, `
SELECT COUNT(*)
FROM pg_catalog.pg_tables t
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
AND t.schemaname NOT LIKE 'pg_temp_%'
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
AND t.schemaname = ANY($1::text[])
`, includeSchemas).Scan(&tableCount)
} else {
@@ -775,6 +778,8 @@ func checkBackupPermissions(
SELECT COUNT(*)
FROM pg_catalog.pg_tables t
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
AND t.schemaname NOT LIKE 'pg_temp_%'
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
`).Scan(&tableCount)
}
@@ -785,12 +790,15 @@ func checkBackupPermissions(
if tableCount > 0 {
// Check if user has SELECT on at least one of these tables
var selectableTableCount int
if len(includeSchemas) > 0 {
// Check only tables in the specified schemas
err = conn.QueryRow(ctx, `
SELECT COUNT(*)
FROM pg_catalog.pg_tables t
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
AND t.schemaname NOT LIKE 'pg_temp_%'
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
AND t.schemaname = ANY($1::text[])
AND has_table_privilege(current_user, quote_ident(t.schemaname) || '.' || quote_ident(t.tablename), 'SELECT')
`, includeSchemas).Scan(&selectableTableCount)
@@ -800,6 +808,8 @@ func checkBackupPermissions(
SELECT COUNT(*)
FROM pg_catalog.pg_tables t
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
AND t.schemaname NOT LIKE 'pg_temp_%'
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
AND has_table_privilege(current_user, quote_ident(t.schemaname) || '.' || quote_ident(t.tablename), 'SELECT')
`).Scan(&selectableTableCount)
}

View File

@@ -1,7 +1,7 @@
package healthcheck_attempt
import (
"databasus-backend/internal/config"
"context"
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
"log/slog"
"time"
@@ -13,18 +13,19 @@ type HealthcheckAttemptBackgroundService struct {
logger *slog.Logger
}
func (s *HealthcheckAttemptBackgroundService) Run() {
func (s *HealthcheckAttemptBackgroundService) Run(ctx context.Context) {
// first healthcheck immediately
s.checkDatabases()
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for range ticker.C {
if config.IsShouldShutdown() {
break
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.checkDatabases()
}
s.checkDatabases()
}
}

View File

@@ -675,6 +675,10 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
WebhookNotifier: &webhook_notifier.WebhookNotifier{
WebhookURL: "https://webhook.example.com/test",
WebhookMethod: webhook_notifier.WebhookMethodPOST,
Headers: []webhook_notifier.WebhookHeader{
{Key: "Authorization", Value: "Bearer my-secret-token"},
{Key: "X-Custom-Header", Value: "custom-value"},
},
},
}
},
@@ -687,14 +691,40 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
WebhookNotifier: &webhook_notifier.WebhookNotifier{
WebhookURL: "https://webhook.example.com/updated",
WebhookMethod: webhook_notifier.WebhookMethodGET,
Headers: []webhook_notifier.WebhookHeader{
{Key: "Authorization", Value: "Bearer updated-token"},
},
},
}
},
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
// No sensitive data to verify for webhook
assert.NotEmpty(
t,
notifier.WebhookNotifier.WebhookURL,
"WebhookURL should be visible",
)
// Verify header values are encrypted in DB
assert.True(
t,
isEncrypted(notifier.WebhookNotifier.Headers[0].Value),
"Header value should be encrypted in DB",
)
decrypted := decryptField(
t,
notifier.ID,
notifier.WebhookNotifier.Headers[0].Value,
)
assert.Equal(t, "Bearer updated-token", decrypted)
},
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
// No sensitive data to hide for webhook
assert.NotEmpty(
t,
notifier.WebhookNotifier.WebhookURL,
"WebhookURL should be visible",
)
for _, header := range notifier.WebhookNotifier.Headers {
assert.Empty(t, header.Value, "Header value should be hidden")
}
},
},
}
@@ -905,7 +935,7 @@ func Test_CreateNotifier_AllSensitiveFieldsEncryptedInDB(t *testing.T) {
},
},
{
name: "Webhook Notifier - WebhookURL encrypted",
name: "Webhook Notifier - Header values encrypted, URL not encrypted",
createNotifier: func(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
@@ -914,17 +944,48 @@ func Test_CreateNotifier_AllSensitiveFieldsEncryptedInDB(t *testing.T) {
WebhookNotifier: &webhook_notifier.WebhookNotifier{
WebhookURL: "https://webhook.example.com/test456",
WebhookMethod: webhook_notifier.WebhookMethodPOST,
Headers: []webhook_notifier.WebhookHeader{
{Key: "Authorization", Value: "Bearer secret-token-12345"},
{Key: "X-API-Key", Value: "api-key-67890"},
},
},
}
},
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
assert.True(
assert.False(
t,
isEncrypted(notifier.WebhookNotifier.WebhookURL),
"WebhookURL should be encrypted",
"WebhookURL should NOT be encrypted",
)
decrypted := decryptField(t, notifier.ID, notifier.WebhookNotifier.WebhookURL)
assert.Equal(t, "https://webhook.example.com/test456", decrypted)
assert.Equal(
t,
"https://webhook.example.com/test456",
notifier.WebhookNotifier.WebhookURL,
)
assert.True(
t,
isEncrypted(notifier.WebhookNotifier.Headers[0].Value),
"Header value should be encrypted",
)
decrypted1 := decryptField(
t,
notifier.ID,
notifier.WebhookNotifier.Headers[0].Value,
)
assert.Equal(t, "Bearer secret-token-12345", decrypted1)
assert.True(
t,
isEncrypted(notifier.WebhookNotifier.Headers[1].Value),
"Header value should be encrypted",
)
decrypted2 := decryptField(
t,
notifier.ID,
notifier.WebhookNotifier.Headers[1].Value,
)
assert.Equal(t, "api-key-67890", decrypted2)
},
},
}

View File

@@ -21,6 +21,10 @@ type WebhookHeader struct {
Value string `json:"value"`
}
// Before, WebhookURL, BodyTemplate and HeadersJSON were all considered
// sensitive data, and that was causing issues. Now only header values are
// considered sensitive data, but we still try to decrypt the webhook URL and
// body template for backward compatibility.
type WebhookNotifier struct {
NotifierID uuid.UUID `json:"notifierId" gorm:"primaryKey;column:notifier_id"`
WebhookURL string `json:"webhookUrl" gorm:"not null;column:webhook_url"`
@@ -58,6 +62,20 @@ func (t *WebhookNotifier) AfterFind(_ *gorm.DB) error {
}
}
encryptor := encryption.GetFieldEncryptor()
if t.WebhookURL != "" {
if decrypted, err := encryptor.Decrypt(t.NotifierID, t.WebhookURL); err == nil {
t.WebhookURL = decrypted
}
}
if t.BodyTemplate != nil && *t.BodyTemplate != "" {
if decrypted, err := encryptor.Decrypt(t.NotifierID, *t.BodyTemplate); err == nil {
t.BodyTemplate = &decrypted
}
}
return nil
}
@@ -79,22 +97,24 @@ func (t *WebhookNotifier) Send(
heading string,
message string,
) error {
webhookURL, err := encryptor.Decrypt(t.NotifierID, t.WebhookURL)
if err != nil {
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
if err := t.decryptHeadersForSending(encryptor); err != nil {
return err
}
switch t.WebhookMethod {
case WebhookMethodGET:
return t.sendGET(webhookURL, heading, message, logger)
return t.sendGET(t.WebhookURL, heading, message, logger)
case WebhookMethodPOST:
return t.sendPOST(webhookURL, heading, message, logger)
return t.sendPOST(t.WebhookURL, heading, message, logger)
default:
return fmt.Errorf("unsupported webhook method: %s", t.WebhookMethod)
}
}
// HideSensitiveData blanks out every header value so encrypted secrets are
// never exposed to API consumers; the webhook URL itself stays visible.
func (t *WebhookNotifier) HideSensitiveData() {
	for idx := 0; idx < len(t.Headers); idx++ {
		t.Headers[idx].Value = ""
	}
}
func (t *WebhookNotifier) Update(incoming *WebhookNotifier) {
@@ -105,14 +125,15 @@ func (t *WebhookNotifier) Update(incoming *WebhookNotifier) {
}
func (t *WebhookNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if t.WebhookURL != "" {
encrypted, err := encryptor.Encrypt(t.NotifierID, t.WebhookURL)
for i := range t.Headers {
if t.Headers[i].Value != "" {
encrypted, err := encryptor.Encrypt(t.NotifierID, t.Headers[i].Value)
if err != nil {
return fmt.Errorf("failed to encrypt header value: %w", err)
}
if err != nil {
return fmt.Errorf("failed to encrypt webhook URL: %w", err)
t.Headers[i].Value = encrypted
}
t.WebhookURL = encrypted
}
return nil
@@ -241,3 +262,15 @@ func escapeJSONString(s string) string {
return string(b[1 : len(b)-1])
}
// decryptHeadersForSending replaces each non-empty header value with its
// decrypted form in place so the headers can be used on the outgoing request.
// Decryption failures are deliberately ignored and the stored value kept
// as-is — presumably to tolerate legacy rows whose header values were never
// encrypted (TODO confirm against migration history).
// NOTE(review): the function currently always returns nil; the error result
// exists only for call-site symmetry.
func (t *WebhookNotifier) decryptHeadersForSending(encryptor encryption.FieldEncryptor) error {
	for i := range t.Headers {
		if t.Headers[i].Value != "" {
			if decrypted, err := encryptor.Decrypt(t.NotifierID, t.Headers[i].Value); err == nil {
				t.Headers[i].Value = decrypted
			}
		}
	}
	return nil
}

View File

@@ -1,6 +1,7 @@
package restores
import (
"context"
"databasus-backend/internal/features/restores/enums"
"log/slog"
)
@@ -10,7 +11,7 @@ type RestoreBackgroundService struct {
logger *slog.Logger
}
func (s *RestoreBackgroundService) Run() {
func (s *RestoreBackgroundService) Run(ctx context.Context) {
if err := s.failRestoresInProgress(); err != nil {
s.logger.Error("Failed to fail restores in progress", "error", err)
panic(err)

View File

@@ -19,6 +19,7 @@ import (
"databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/mysql"
@@ -274,7 +275,7 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
var backup *backups.Backup
var backup *backups_core.Backup
var request RestoreBackupRequest
if tc.dbType == databases.DatabaseTypePostgres {
@@ -321,7 +322,7 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
}
// Set huge backup size (10 TB) that would fail disk validation if checked
repo := &backups.BackupRepository{}
repo := &backups_core.BackupRepository{}
backup.BackupSizeMb = 10485760.0
err := repo.Save(backup)
assert.NoError(t, err)
@@ -368,7 +369,7 @@ func createTestDatabaseWithBackupForRestore(
workspace *workspaces_models.Workspace,
owner *users_dto.SignInResponseDTO,
router *gin.Engine,
) (*databases.Database, *backups.Backup) {
) (*databases.Database, *backups_core.Backup) {
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
@@ -504,7 +505,7 @@ func createTestStorage(workspaceID uuid.UUID) *storages.Storage {
func createTestBackup(
database *databases.Database,
owner *users_dto.SignInResponseDTO,
) *backups.Backup {
) *backups_core.Backup {
fieldEncryptor := util_encryption.GetFieldEncryptor()
userService := users_services.GetUserService()
user, err := userService.GetUserFromToken(owner.Token)
@@ -517,17 +518,17 @@ func createTestBackup(
panic("No storage found for workspace")
}
backup := &backups.Backup{
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storages[0].ID,
Status: backups.BackupStatusCompleted,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10.5,
BackupDurationMs: 1000,
CreatedAt: time.Now().UTC(),
}
repo := &backups.BackupRepository{}
repo := &backups_core.BackupRepository{}
if err := repo.Save(backup); err != nil {
panic(err)
}

View File

@@ -1,7 +1,7 @@
package models
import (
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/restores/enums"
"time"
@@ -13,7 +13,7 @@ type Restore struct {
Status enums.RestoreStatus `json:"status" gorm:"column:status;type:text;not null"`
BackupID uuid.UUID `json:"backupId" gorm:"column:backup_id;type:uuid;not null"`
Backup *backups.Backup
Backup *backups_core.Backup
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`

View File

@@ -3,6 +3,7 @@ package restores
import (
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/disk"
@@ -36,7 +37,7 @@ type RestoreService struct {
diskService *disk.DiskService
}
func (s *RestoreService) OnBeforeBackupRemove(backup *backups.Backup) error {
func (s *RestoreService) OnBeforeBackupRemove(backup *backups_core.Backup) error {
restores, err := s.restoreRepository.FindByBackupID(backup.ID)
if err != nil {
return err
@@ -153,10 +154,10 @@ func (s *RestoreService) RestoreBackupWithAuth(
}
func (s *RestoreService) RestoreBackup(
backup *backups.Backup,
backup *backups_core.Backup,
requestDTO RestoreBackupRequest,
) error {
if backup.Status != backups.BackupStatusCompleted {
if backup.Status != backups_core.BackupStatusCompleted {
return errors.New("backup is not completed")
}
@@ -370,7 +371,7 @@ func (s *RestoreService) validateVersionCompatibility(
}
func (s *RestoreService) validateDiskSpace(
backup *backups.Backup,
backup *backups_core.Backup,
requestDTO RestoreBackupRequest,
) error {
// Only validate disk space for PostgreSQL when file-based restore is needed:

View File

@@ -18,7 +18,7 @@ import (
"github.com/klauspost/compress/zstd"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -40,7 +40,7 @@ func (uc *RestoreMariadbBackupUsecase) Execute(
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
) error {
if originalDB.Type != databases.DatabaseTypeMariadb {
@@ -99,7 +99,7 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
mariadbBin string,
args []string,
password string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
mdbConfig *mariadbtypes.MariadbDatabase,
) error {
@@ -163,7 +163,7 @@ func (uc *RestoreMariadbBackupUsecase) executeMariadbRestore(
args []string,
myCnfFile string,
backupReader io.ReadCloser,
backup *backups.Backup,
backup *backups_core.Backup,
) error {
fullArgs := append([]string{"--defaults-file=" + myCnfFile}, args...)
@@ -226,7 +226,7 @@ func (uc *RestoreMariadbBackupUsecase) executeMariadbRestore(
func (uc *RestoreMariadbBackupUsecase) setupDecryption(
reader io.Reader,
backup *backups.Backup,
backup *backups_core.Backup,
) (io.Reader, error) {
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
return nil, fmt.Errorf("backup is encrypted but missing encryption metadata")

View File

@@ -14,7 +14,7 @@ import (
"time"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -40,7 +40,7 @@ func (uc *RestoreMongodbBackupUsecase) Execute(
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
) error {
if originalDB.Type != databases.DatabaseTypeMongodb {
@@ -124,7 +124,7 @@ func (uc *RestoreMongodbBackupUsecase) buildMongorestoreArgs(
func (uc *RestoreMongodbBackupUsecase) restoreFromStorage(
mongorestoreBin string,
args []string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
) error {
ctx, cancel := context.WithTimeout(context.Background(), restoreTimeout)
@@ -166,7 +166,7 @@ func (uc *RestoreMongodbBackupUsecase) executeMongoRestore(
mongorestoreBin string,
args []string,
backupReader io.ReadCloser,
backup *backups.Backup,
backup *backups_core.Backup,
) error {
cmd := exec.CommandContext(ctx, mongorestoreBin, args...)
@@ -231,7 +231,7 @@ func (uc *RestoreMongodbBackupUsecase) executeMongoRestore(
func (uc *RestoreMongodbBackupUsecase) setupDecryption(
reader io.Reader,
backup *backups.Backup,
backup *backups_core.Backup,
) (io.Reader, error) {
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
return nil, errors.New("encrypted backup missing salt or IV")

View File

@@ -18,7 +18,7 @@ import (
"github.com/klauspost/compress/zstd"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -40,7 +40,7 @@ func (uc *RestoreMysqlBackupUsecase) Execute(
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
) error {
if originalDB.Type != databases.DatabaseTypeMysql {
@@ -98,7 +98,7 @@ func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
mysqlBin string,
args []string,
password string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
myConfig *mysqltypes.MysqlDatabase,
) error {
@@ -154,7 +154,7 @@ func (uc *RestoreMysqlBackupUsecase) executeMysqlRestore(
args []string,
myCnfFile string,
backupReader io.ReadCloser,
backup *backups.Backup,
backup *backups_core.Backup,
) error {
fullArgs := append([]string{"--defaults-file=" + myCnfFile}, args...)
@@ -217,7 +217,7 @@ func (uc *RestoreMysqlBackupUsecase) executeMysqlRestore(
func (uc *RestoreMysqlBackupUsecase) setupDecryption(
reader io.Reader,
backup *backups.Backup,
backup *backups_core.Backup,
) (io.Reader, error) {
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
return nil, fmt.Errorf("backup is encrypted but missing encryption metadata")

View File

@@ -15,7 +15,7 @@ import (
"time"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -39,7 +39,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
isExcludeExtensions bool,
) error {
@@ -86,7 +86,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
func (uc *RestorePostgresqlBackupUsecase) restoreCustomType(
originalDB *databases.Database,
pgBin string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
pg *pgtypes.PostgresqlDatabase,
isExcludeExtensions bool,
@@ -113,7 +113,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreCustomType(
func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
originalDB *databases.Database,
pgBin string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
pg *pgtypes.PostgresqlDatabase,
) error {
@@ -321,7 +321,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
func (uc *RestorePostgresqlBackupUsecase) restoreViaFile(
originalDB *databases.Database,
pgBin string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
pg *pgtypes.PostgresqlDatabase,
isExcludeExtensions bool,
@@ -371,7 +371,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
pgBin string,
args []string,
password string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
pgConfig *pgtypes.PostgresqlDatabase,
isExcludeExtensions bool,
@@ -469,7 +469,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
// downloadBackupToTempFile downloads backup data from storage to a temporary file
func (uc *RestorePostgresqlBackupUsecase) downloadBackupToTempFile(
ctx context.Context,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
) (string, func(), error) {
// Create temporary directory for backup data

View File

@@ -3,7 +3,7 @@ package usecases
import (
"errors"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/restores/models"
@@ -26,7 +26,7 @@ func (uc *RestoreBackupUsecase) Execute(
restore models.Restore,
originalDB *databases.Database,
restoringToDB *databases.Database,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
isExcludeExtensions bool,
) error {

View File

@@ -1,13 +1,14 @@
package system_healthcheck
import (
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/features/backups/backups/backuping"
"databasus-backend/internal/features/disk"
)
var healthcheckService = &HealthcheckService{
disk.GetDiskService(),
backups.GetBackupBackgroundService(),
backuping.GetBackupsScheduler(),
backuping.GetBackuperNode(),
}
var healthcheckController = &HealthcheckController{
healthcheckService,

View File

@@ -1,7 +1,8 @@
package system_healthcheck
import (
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups/backuping"
"databasus-backend/internal/features/disk"
"databasus-backend/internal/storage"
"errors"
@@ -9,7 +10,8 @@ import (
type HealthcheckService struct {
diskService *disk.DiskService
backupBackgroundService *backups.BackupBackgroundService
backupBackgroundService *backuping.BackupsScheduler
backuperNode *backuping.BackuperNode
}
func (s *HealthcheckService) IsHealthy() error {
@@ -29,8 +31,16 @@ func (s *HealthcheckService) IsHealthy() error {
return errors.New("cannot connect to the database")
}
if !s.backupBackgroundService.IsBackupsWorkerRunning() {
return errors.New("backups are not running for more than 5 minutes")
if config.GetEnv().IsPrimaryNode {
if !s.backupBackgroundService.IsSchedulerRunning() {
return errors.New("backups are not running for more than 5 minutes")
}
}
if config.GetEnv().IsBackupNode {
if !s.backuperNode.IsBackuperRunning() {
return errors.New("backuper node is not running for more than 5 minutes")
}
}
return nil

View File

@@ -17,7 +17,7 @@ import (
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
mariadbtypes "databasus-backend/internal/features/databases/databases/mariadb"
@@ -189,7 +189,7 @@ func testMariadbBackupRestoreForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_mariadb"
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -286,7 +286,7 @@ func testMariadbBackupRestoreWithEncryptionForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)
newDBName := "restoreddb_mariadb_encrypted"
@@ -394,7 +394,7 @@ func testMariadbBackupRestoreWithReadOnlyUserForVersion(
createBackupViaAPI(t, router, updatedDatabase.ID, user.Token)
backup := waitForBackupCompletion(t, router, updatedDatabase.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_mariadb_readonly"
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))

View File

@@ -19,7 +19,7 @@ import (
"go.mongodb.org/mongo-driver/mongo/options"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
mongodbtypes "databasus-backend/internal/features/databases/databases/mongodb"
@@ -161,7 +161,7 @@ func testMongodbBackupRestoreForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_mongodb_" + uuid.New().String()[:8]
@@ -239,7 +239,7 @@ func testMongodbBackupRestoreWithEncryptionForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)
newDBName := "restoreddb_mongodb_enc_" + uuid.New().String()[:8]
@@ -328,7 +328,7 @@ func testMongodbBackupRestoreWithReadOnlyUserForVersion(
createBackupViaAPI(t, router, updatedDatabase.ID, user.Token)
backup := waitForBackupCompletion(t, router, updatedDatabase.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_mongodb_ro_" + uuid.New().String()[:8]

View File

@@ -17,7 +17,7 @@ import (
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
mysqltypes "databasus-backend/internal/features/databases/databases/mysql"
@@ -164,7 +164,7 @@ func testMysqlBackupRestoreForVersion(t *testing.T, mysqlVersion tools.MysqlVers
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_mysql"
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -261,7 +261,7 @@ func testMysqlBackupRestoreWithEncryptionForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)
newDBName := "restoreddb_mysql_encrypted"
@@ -369,7 +369,7 @@ func testMysqlBackupRestoreWithReadOnlyUserForVersion(
createBackupViaAPI(t, router, updatedDatabase.ID, user.Token)
backup := waitForBackupCompletion(t, router, updatedDatabase.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_mysql_readonly"
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))

View File

@@ -18,6 +18,7 @@ import (
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
pgtypes "databasus-backend/internal/features/databases/databases/postgresql"
@@ -190,7 +191,7 @@ func Test_BackupAndRestoreSupabase_PublicSchemaOnly_RestoreIsSuccessful(t *testi
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
_, err = supabaseDB.Exec(fmt.Sprintf(`DELETE FROM public.%s`, tableName))
assert.NoError(t, err)
@@ -410,7 +411,7 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string, cp
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := fmt.Sprintf("restoreddb_%s_cpu%d_%s", pgVersion, cpuCount, uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -527,7 +528,7 @@ func testSchemaSelectionAllSchemasForVersion(t *testing.T, pgVersion string, por
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := fmt.Sprintf("restored_all_schemas_%s_%s", pgVersion, uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -655,7 +656,7 @@ func testBackupRestoreWithExcludeExtensionsForVersion(t *testing.T, pgVersion st
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := fmt.Sprintf("restored_exclude_ext_%s_%s", pgVersion, uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -789,7 +790,7 @@ func testBackupRestoreWithoutExcludeExtensionsForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := fmt.Sprintf("restored_with_ext_%s_%s", pgVersion, uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -928,7 +929,7 @@ func testBackupRestoreWithReadOnlyUserForVersion(t *testing.T, pgVersion string,
createBackupViaAPI(t, router, updatedDatabase.ID, user.Token)
backup := waitForBackupCompletion(t, router, updatedDatabase.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := fmt.Sprintf("restoreddb_readonly_%s", uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -1048,7 +1049,7 @@ func testSchemaSelectionOnlySpecifiedSchemasForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := fmt.Sprintf("restored_specific_schemas_%s_%s", pgVersion, uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -1161,7 +1162,7 @@ func testBackupRestoreWithEncryptionForVersion(t *testing.T, pgVersion string, p
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)
newDBName := fmt.Sprintf("restoreddb_encrypted_%s", uuid.New().String()[:8])
@@ -1242,7 +1243,7 @@ func waitForBackupCompletion(
databaseID uuid.UUID,
token string,
timeout time.Duration,
) *backups.Backup {
) *backups_core.Backup {
startTime := time.Now()
pollInterval := 500 * time.Millisecond
@@ -1263,10 +1264,10 @@ func waitForBackupCompletion(
if len(response.Backups) > 0 {
backup := response.Backups[0]
if backup.Status == backups.BackupStatusCompleted {
if backup.Status == backups_core.BackupStatusCompleted {
return backup
}
if backup.Status == backups.BackupStatusFailed {
if backup.Status == backups_core.BackupStatusFailed {
failMsg := "unknown error"
if backup.FailMessage != nil {
failMsg = *backup.FailMessage

View File

@@ -0,0 +1,22 @@
package tests
import (
"os"
"testing"
"databasus-backend/internal/features/backups/backups/backuping"
cache_utils "databasus-backend/internal/util/cache"
)
// TestMain sets up shared infrastructure for every test in this package:
// it clears the cache, then keeps a backuper node running for the whole
// test binary and tears it down again before the process exits.
func TestMain(m *testing.M) {
	cache_utils.ClearAllCache()

	node := backuping.CreateTestBackuperNode()
	stop := backuping.StartBackuperNodeForTest(&testing.T{}, node)

	code := m.Run()

	backuping.StopBackuperNodeForTest(&testing.T{}, stop, node)
	os.Exit(code)
}

View File

@@ -2,13 +2,12 @@ package users_controllers
import (
users_services "databasus-backend/internal/features/users/services"
"golang.org/x/time/rate"
cache_utils "databasus-backend/internal/util/cache"
)
var userController = &UserController{
users_services.GetUserService(),
rate.NewLimiter(rate.Limit(3), 3), // 3 rps with 3 burst
cache_utils.NewRateLimiter(cache_utils.GetValkeyClient()),
}
var settingsController = &SettingsController{

View File

@@ -14,7 +14,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"golang.org/x/time/rate"
)
func Test_AdminLifecycleE2E_CompletesSuccessfully(t *testing.T) {
@@ -185,7 +184,6 @@ func createUserTestRouter() *gin.Engine {
// Register protected routes with auth middleware
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
GetUserController().RegisterProtectedRoutes(protected.(*gin.RouterGroup))
GetUserController().SetSignInLimiter(rate.NewLimiter(rate.Limit(100), 100))
// Setup audit log service
users_services.GetUserService().SetAuditLogWriter(&AuditLogWriterStub{})

View File

@@ -3,20 +3,21 @@ package users_controllers
import (
"errors"
"net/http"
"time"
"databasus-backend/internal/config"
user_dto "databasus-backend/internal/features/users/dto"
users_errors "databasus-backend/internal/features/users/errors"
user_middleware "databasus-backend/internal/features/users/middleware"
users_services "databasus-backend/internal/features/users/services"
cache_utils "databasus-backend/internal/util/cache"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
)
type UserController struct {
userService *users_services.UserService
signinLimiter *rate.Limiter
userService *users_services.UserService
rateLimiter *cache_utils.RateLimiter
}
func (c *UserController) RegisterRoutes(router *gin.RouterGroup) {
@@ -39,10 +40,6 @@ func (c *UserController) RegisterProtectedRoutes(router *gin.RouterGroup) {
router.POST("/users/invite", c.InviteUser)
}
func (c *UserController) SetSignInLimiter(limiter *rate.Limiter) {
c.signinLimiter = limiter
}
// SignUp
// @Summary Register a new user
// @Description Register a new user with email and password
@@ -81,8 +78,14 @@ func (c *UserController) SignUp(ctx *gin.Context) {
// @Failure 429 {object} map[string]string "Rate limit exceeded"
// @Router /users/signin [post]
func (c *UserController) SignIn(ctx *gin.Context) {
// We use rate limiter to prevent brute force attacks
if !c.signinLimiter.Allow() {
var request user_dto.SignInRequestDTO
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"})
return
}
allowed, _ := c.rateLimiter.CheckLimit(request.Email, "signin", 10, 1*time.Minute)
if !allowed {
ctx.JSON(
http.StatusTooManyRequests,
gin.H{"error": "Rate limit exceeded. Please try again later."},
@@ -90,12 +93,6 @@ func (c *UserController) SignIn(ctx *gin.Context) {
return
}
var request user_dto.SignInRequestDTO
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"})
return
}
response, err := c.userService.SignIn(&request)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})

View File

@@ -1146,3 +1146,48 @@ func Test_GoogleOAuth_WithInvitedUser_ActivatesUser(t *testing.T) {
assert.Equal(t, email, response.Email)
assert.False(t, response.IsNewUser)
}
func Test_SignIn_WithExcessiveAttempts_RateLimitEnforced(t *testing.T) {
router := createUserTestRouter()
email := "ratelimit" + uuid.New().String() + "@example.com"
password := "testpassword123"
// Create a user first
signupRequest := users_dto.SignUpRequestDTO{
Email: email,
Password: password,
Name: "Rate Limit Test User",
}
test_utils.MakePostRequest(t, router, "/api/v1/users/signup", "", signupRequest, http.StatusOK)
// Make 10 sign-in attempts (should succeed)
for range 10 {
signinRequest := users_dto.SignInRequestDTO{
Email: email,
Password: password,
}
test_utils.MakePostRequest(
t,
router,
"/api/v1/users/signin",
"",
signinRequest,
http.StatusOK,
)
}
// 11th attempt should be rate limited
signinRequest := users_dto.SignInRequestDTO{
Email: email,
Password: password,
}
resp := test_utils.MakePostRequest(
t,
router,
"/api/v1/users/signin",
"",
signinRequest,
http.StatusTooManyRequests,
)
assert.Contains(t, string(resp.Body), "Rate limit exceeded")
}

View File

@@ -41,6 +41,10 @@ func getCache() valkey.Client {
return valkeyClient
}
func GetValkeyClient() valkey.Client {
return getCache()
}
func TestCacheConnection() {
// Get Valkey client from cache package
client := getCache()

View File

@@ -0,0 +1,85 @@
package cache_utils
import (
"context"
"fmt"
"time"
"github.com/google/uuid"
"github.com/valkey-io/valkey-go"
)
// RateLimiter implements a sliding-window rate limiter backed by a
// Valkey (Redis-compatible) cache. Each request is stored as its own
// expiring key, so the number of live keys for a prefix is the number
// of requests made within the window.
type RateLimiter struct {
	client valkey.Client
}

// NewRateLimiter returns a RateLimiter that stores its counters in the
// given Valkey client.
func NewRateLimiter(client valkey.Client) *RateLimiter {
	return &RateLimiter{
		client: client,
	}
}
// CheckLimit records one request for the given identifier/endpoint pair and
// reports whether the caller is still within the sliding-window limit of
// maxRequests per windowDuration.
//
// Each request is written as a unique key (ratelimit:<endpoint>:<identifier>:<uuid>)
// expiring with the window; the live keys under the prefix are then counted.
// Because this request's key is written before counting, the caller is allowed
// while count <= maxRequests.
//
// On cache errors the limiter fails open: it returns allowed=true together
// with the error, so an unavailable cache does not lock users out.
func (r *RateLimiter) CheckLimit(
	identifier string,
	endpoint string,
	maxRequests int,
	windowDuration time.Duration,
) (bool, error) {
	requestID := uuid.New().String()
	keyPrefix := fmt.Sprintf("ratelimit:%s:%s", endpoint, identifier)
	fullKey := fmt.Sprintf("%s:%s", keyPrefix, requestID)

	ctx, cancel := context.WithTimeout(context.Background(), DefaultCacheTimeout)
	defer cancel()

	// SET ... EX requires a TTL of at least 1 second. A sub-second window
	// would truncate to EX 0 and make the command fail, so floor it at 1.
	ttlSeconds := int64(windowDuration.Seconds())
	if ttlSeconds < 1 {
		ttlSeconds = 1
	}

	// Record this request as its own key that expires with the window.
	setCmd := r.client.B().
		Set().
		Key(fullKey).
		Value("1").
		ExSeconds(ttlSeconds).
		Build()
	if err := r.client.Do(ctx, setCmd).Error(); err != nil {
		return true, fmt.Errorf("failed to set rate limit key: %w", err)
	}

	// Count all live request keys in the current window for this identifier.
	count, err := r.countKeys(keyPrefix)
	if err != nil {
		return true, fmt.Errorf("failed to count rate limit keys: %w", err)
	}

	return count <= maxRequests, nil
}
// countKeys returns how many live keys start with keyPrefix followed by ":".
// It walks the keyspace with cursor-based SCAN so no single blocking command
// is issued against the cache.
func (r *RateLimiter) countKeys(keyPrefix string) (int, error) {
	pattern := keyPrefix + ":*"

	var (
		cursor uint64
		total  int
	)
	for {
		// Fresh timeout per SCAN page; cancel immediately after the call
		// instead of deferring, since this loop may run many iterations.
		ctx, cancel := context.WithTimeout(context.Background(), DefaultCacheTimeout)
		scanCmd := r.client.B().Scan().Cursor(cursor).Match(pattern).Count(100).Build()
		result := r.client.Do(ctx, scanCmd)
		cancel()

		if err := result.Error(); err != nil {
			return 0, err
		}

		entry, err := result.AsScanEntry()
		if err != nil {
			return 0, err
		}

		total += len(entry.Elements)
		// A zero cursor from the server means the iteration is complete.
		if entry.Cursor == 0 {
			return total, nil
		}
		cursor = entry.Cursor
	}
}

View File

@@ -112,6 +112,12 @@ export function EditWebhookNotifierComponent({ notifier, setNotifier, setUnsaved
</span>
</div>
{notifier.id && (
<div className="mb-1 text-xs text-orange-700">
*Saved headers hidden for security reasons
</div>
)}
<div className="w-full max-w-[500px]">
{headers.map((header: WebhookHeader, index: number) => (
<div key={index} className="mb-1 flex items-center gap-2">
@@ -204,11 +210,12 @@ export function EditWebhookNotifierComponent({ notifier, setNotifier, setUnsaved
<div className="text-xs font-semibold text-gray-500 dark:text-gray-400">
Headers:
</div>
{headers
.filter((h) => h.key)
.map((h, i) => (
<div key={i} className="text-xs">
{h.key}: {h.value || '(empty)'}
{h.key}: {h.value || '(hidden)'}
</div>
))}
</div>

View File

@@ -29,7 +29,7 @@ export function ShowWebhookNotifierComponent({ notifier }: Props) {
.filter((h: WebhookHeader) => h.key)
.map((h: WebhookHeader, i: number) => (
<div key={i} className="text-gray-600">
<span className="font-medium">{h.key}:</span> {h.value || '(empty)'}
<span className="font-medium">{h.key}:</span> {h.value || '(hidden)'}
</div>
))}
</div>