FEATURE (backups): Add support for multinode Databasus setup

This commit is contained in:
Rostislav Dugin
2026-01-14 07:24:37 +03:00
parent 54b9e67656
commit 80f1174ecd
52 changed files with 3386 additions and 882 deletions

View File

@@ -15,6 +15,9 @@ import (
"databasus-backend/internal/config"
"databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/features/backups/backups/backuping"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_download "databasus-backend/internal/features/backups/backups/download"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/disk"
@@ -54,16 +57,23 @@ func main() {
log := logger.GetLogger()
cache_utils.TestCacheConnection()
err := cache_utils.ClearAllCache()
if err != nil {
log.Error("Failed to clear cache", "error", err)
os.Exit(1)
if config.GetEnv().IsPrimaryNode {
err := cache_utils.ClearAllCache()
if err != nil {
log.Error("Failed to clear cache", "error", err)
os.Exit(1)
}
}
runMigrations(log)
if config.GetEnv().IsPrimaryNode {
runMigrations(log)
} else {
log.Info("Skipping migrations (IS_PRIMARY_NODE is false)")
}
// create directories that used for backups and restore
err = files_utils.EnsureDirectories([]string{
err := files_utils.EnsureDirectories([]string{
config.GetEnv().TempFolder,
config.GetEnv().DataFolder,
})
@@ -104,7 +114,9 @@ func main() {
enableCors(ginApp)
setUpRoutes(ginApp)
setUpDependencies()
runBackgroundTasks(log)
mountFrontend(ginApp)
startServerWithGracefulShutdown(log, ginApp)
@@ -227,35 +239,64 @@ func setUpDependencies() {
notifiers.SetupDependencies()
storages.SetupDependencies()
backups_config.SetupDependencies()
backups_cancellation.SetupDependencies()
}
func runBackgroundTasks(log *slog.Logger) {
log.Info("Preparing to run background tasks...")
err := files_utils.CleanFolder(config.GetEnv().TempFolder)
if err != nil {
log.Error("Failed to clean temp folder", "error", err)
// Create context that will be cancelled on shutdown
ctx, cancel := context.WithCancel(context.Background())
// Set up signal handling for graceful shutdown
quit := make(chan os.Signal, 1)
signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
go func() {
<-quit
log.Info("Shutdown signal received, cancelling all background tasks")
cancel()
}()
if config.GetEnv().IsPrimaryNode {
log.Info("Starting primary node background tasks...")
err := files_utils.CleanFolder(config.GetEnv().TempFolder)
if err != nil {
log.Error("Failed to clean temp folder", "error", err)
}
go runWithPanicLogging(log, "backup background service", func() {
backuping.GetBackupsScheduler().Run(ctx)
})
go runWithPanicLogging(log, "restore background service", func() {
restores.GetRestoreBackgroundService().Run(ctx)
})
go runWithPanicLogging(log, "healthcheck attempt background service", func() {
healthcheck_attempt.GetHealthcheckAttemptBackgroundService().Run(ctx)
})
go runWithPanicLogging(log, "audit log cleanup background service", func() {
audit_logs.GetAuditLogBackgroundService().Run(ctx)
})
go runWithPanicLogging(log, "download token cleanup background service", func() {
backups_download.GetDownloadTokenBackgroundService().Run(ctx)
})
} else {
log.Info("Skipping primary node tasks as not primary node")
}
go runWithPanicLogging(log, "backup background service", func() {
backups.GetBackupBackgroundService().Run()
})
if config.GetEnv().IsBackupNode {
log.Info("Starting backup node background tasks...")
go runWithPanicLogging(log, "restore background service", func() {
restores.GetRestoreBackgroundService().Run()
})
go runWithPanicLogging(log, "healthcheck attempt background service", func() {
healthcheck_attempt.GetHealthcheckAttemptBackgroundService().Run()
})
go runWithPanicLogging(log, "audit log cleanup background service", func() {
audit_logs.GetAuditLogBackgroundService().Run()
})
go runWithPanicLogging(log, "download token cleanup background service", func() {
backups.GetDownloadTokenBackgroundService().Run()
})
go runWithPanicLogging(log, "backup node", func() {
backuping.GetBackuperNode().Run(ctx)
})
} else {
log.Info("Skipping backup node tasks as not backup node")
}
}
func runWithPanicLogging(log *slog.Logger, serviceName string, fn func()) {

View File

@@ -9,6 +9,7 @@ import (
"strings"
"sync"
"github.com/google/uuid"
"github.com/ilyakaznacheev/cleanenv"
"github.com/joho/godotenv"
)
@@ -29,6 +30,12 @@ type EnvVariables struct {
MariadbInstallDir string `env:"MARIADB_INSTALL_DIR"`
MongodbInstallDir string `env:"MONGODB_INSTALL_DIR"`
NodeID string
IsManyNodesMode bool `env:"IS_MANY_NODES_MODE"`
IsPrimaryNode bool `env:"IS_PRIMARY_NODE"`
IsBackupNode bool `env:"IS_BACKUP_NODE"`
NodeNetworkThroughputMBs int `env:"NODE_NETWORK_THROUGHPUT_MBPS"`
DataFolder string
TempFolder string
SecretKeyPath string
@@ -196,6 +203,16 @@ func loadEnvVariables() {
env.MongodbInstallDir = filepath.Join(backendRoot, "tools", "mongodb")
tools.VerifyMongodbInstallation(log, env.EnvMode, env.MongodbInstallDir)
env.NodeID = uuid.New().String()
if env.NodeNetworkThroughputMBs == 0 {
env.NodeNetworkThroughputMBs = 125 // 1 Gbit/s
}
if !env.IsManyNodesMode {
env.IsPrimaryNode = true
env.IsBackupNode = true
}
// Valkey
if env.ValkeyHost == "" {
log.Error("VALKEY_HOST is empty")

View File

@@ -1,7 +1,7 @@
package audit_logs
import (
"databasus-backend/internal/config"
"context"
"log/slog"
"time"
)
@@ -11,23 +11,25 @@ type AuditLogBackgroundService struct {
logger *slog.Logger
}
func (s *AuditLogBackgroundService) Run() {
func (s *AuditLogBackgroundService) Run(ctx context.Context) {
s.logger.Info("Starting audit log cleanup background service")
if config.IsShouldShutdown() {
if ctx.Err() != nil {
return
}
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for {
if config.IsShouldShutdown() {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.cleanOldAuditLogs(); err != nil {
s.logger.Error("Failed to clean old audit logs", "error", err)
}
}
if err := s.cleanOldAuditLogs(); err != nil {
s.logger.Error("Failed to clean old audit logs", "error", err)
}
time.Sleep(1 * time.Hour)
}
}

View File

@@ -1,254 +0,0 @@
package backups
import (
"databasus-backend/internal/config"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/storages"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/period"
"log/slog"
"time"
)
type BackupBackgroundService struct {
backupService *BackupService
backupRepository *BackupRepository
backupConfigService *backups_config.BackupConfigService
storageService *storages.StorageService
lastBackupTime time.Time
logger *slog.Logger
}
// Run is the main loop of the single-node backup background service.
// On startup it fails any backups left "in progress" by a previous process,
// then once per minute cleans expired backups and triggers pending scheduled
// backups until shutdown is requested.
func (s *BackupBackgroundService) Run() {
	s.lastBackupTime = time.Now().UTC()
	// Stale in-progress backups cannot resume after a restart; fail them so
	// state is consistent before scheduling new work. A failure here is fatal.
	if err := s.failBackupsInProgress(); err != nil {
		s.logger.Error("Failed to fail backups in progress", "error", err)
		panic(err)
	}
	if config.IsShouldShutdown() {
		return
	}
	for {
		if config.IsShouldShutdown() {
			return
		}
		// Errors are logged but do not stop the loop; the next tick retries.
		if err := s.cleanOldBackups(); err != nil {
			s.logger.Error("Failed to clean old backups", "error", err)
		}
		if err := s.runPendingBackups(); err != nil {
			s.logger.Error("Failed to run pending backups", "error", err)
		}
		// Recorded so IsBackupsWorkerRunning can detect a stuck worker.
		s.lastBackupTime = time.Now().UTC()
		time.Sleep(1 * time.Minute)
	}
}
// IsBackupsWorkerRunning reports whether the worker loop has made progress
// within the last five minutes.
func (s *BackupBackgroundService) IsBackupsWorkerRunning() bool {
	staleThreshold := time.Now().UTC().Add(-5 * time.Minute)
	return s.lastBackupTime.After(staleThreshold)
}
// failBackupsInProgress marks every backup still in BackupStatusInProgress as
// failed (they cannot survive a process restart) and sends a failure
// notification for each. Returns the first repository save error, aborting
// the remaining backups.
func (s *BackupBackgroundService) failBackupsInProgress() error {
	backupsInProgress, err := s.backupRepository.FindByStatus(BackupStatusInProgress)
	if err != nil {
		return err
	}
	for _, backup := range backupsInProgress {
		// Missing config is non-fatal: skip this backup and continue.
		backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(backup.DatabaseID)
		if err != nil {
			s.logger.Error("Failed to get backup config by database ID", "error", err)
			continue
		}
		failMessage := "Backup failed due to application restart"
		backup.FailMessage = &failMessage
		backup.Status = BackupStatusFailed
		backup.BackupSizeMb = 0
		// NOTE(review): notification is sent before the status is persisted;
		// if Save fails the notified state and stored state diverge — confirm
		// this ordering is intended.
		s.backupService.SendBackupNotification(
			backupConfig,
			backup,
			backups_config.NotificationBackupFailed,
			&failMessage,
		)
		if err := s.backupRepository.Save(backup); err != nil {
			return err
		}
	}
	return nil
}
// cleanOldBackups deletes backups older than each enabled config's retention
// period, removing both the stored file and the database record. Per-backup
// errors are logged and skipped; only the initial config lookup error aborts.
func (s *BackupBackgroundService) cleanOldBackups() error {
	enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
	if err != nil {
		return err
	}
	for _, backupConfig := range enabledBackupConfigs {
		backupStorePeriod := backupConfig.StorePeriod
		// "Forever" retention: nothing to delete for this database.
		if backupStorePeriod == period.PeriodForever {
			continue
		}
		storeDuration := backupStorePeriod.ToDuration()
		dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
		oldBackups, err := s.backupRepository.FindBackupsBeforeDate(
			backupConfig.DatabaseID,
			dateBeforeBackupsShouldBeDeleted,
		)
		if err != nil {
			s.logger.Error(
				"Failed to find old backups for database",
				"databaseId",
				backupConfig.DatabaseID,
				"error",
				err,
			)
			continue
		}
		for _, backup := range oldBackups {
			storage, err := s.storageService.GetStorageByID(backup.StorageID)
			if err != nil {
				s.logger.Error(
					"Failed to get storage by ID",
					"storageId",
					backup.StorageID,
					"error",
					err,
				)
				continue
			}
			encryptor := encryption.GetFieldEncryptor()
			// File deletion failure is tolerated: the record is still removed
			// below, so an orphaned file may remain in storage.
			err = storage.DeleteFile(encryptor, backup.ID)
			if err != nil {
				s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
			}
			if err := s.backupRepository.DeleteByID(backup.ID); err != nil {
				s.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
				continue
			}
			s.logger.Info(
				"Deleted old backup",
				"backupId",
				backup.ID,
				"databaseId",
				backupConfig.DatabaseID,
			)
		}
	}
	return nil
}
// runPendingBackups triggers a backup (in a new goroutine) for every enabled
// config whose interval is due, or which still has failed-retry attempts
// remaining. Per-database lookup errors are logged and skipped.
func (s *BackupBackgroundService) runPendingBackups() error {
	enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
	if err != nil {
		return err
	}
	for _, backupConfig := range enabledBackupConfigs {
		// No interval configured means manual-only backups for this database.
		if backupConfig.BackupInterval == nil {
			continue
		}
		lastBackup, err := s.backupRepository.FindLastByDatabaseID(backupConfig.DatabaseID)
		if err != nil {
			s.logger.Error(
				"Failed to get last backup for database",
				"databaseId",
				backupConfig.DatabaseID,
				"error",
				err,
			)
			continue
		}
		var lastBackupTime *time.Time
		if lastBackup != nil {
			lastBackupTime = &lastBackup.CreatedAt
		}
		remainedBackupTryCount := s.GetRemainedBackupTryCount(lastBackup)
		// Trigger either on schedule or while failed-retry budget remains.
		if backupConfig.BackupInterval.ShouldTriggerBackup(time.Now().UTC(), lastBackupTime) ||
			remainedBackupTryCount > 0 {
			s.logger.Info(
				"Triggering scheduled backup",
				"databaseId",
				backupConfig.DatabaseID,
				"intervalType",
				backupConfig.BackupInterval.Interval,
			)
			// Fire-and-forget; MakeBackup persists its own outcome. The second
			// argument marks the final retry attempt (notify on failure).
			go s.backupService.MakeBackup(backupConfig.DatabaseID, remainedBackupTryCount == 1)
			s.logger.Info(
				"Successfully triggered scheduled backup",
				"databaseId",
				backupConfig.DatabaseID,
			)
		}
	}
	return nil
}
// GetRemainedBackupTryCount returns the number of remaining backup tries for a given backup.
// If the backup is not failed or the backup config does not allow retries, it returns 0.
// If the backup is failed and the backup config allows retries, it returns the number of remaining tries.
func (s *BackupBackgroundService) GetRemainedBackupTryCount(lastBackup *Backup) int {
	// No backup yet: nothing to retry.
	if lastBackup == nil {
		return 0
	}
	// Only failed backups are eligible for retries.
	if lastBackup.Status != BackupStatusFailed {
		return 0
	}
	backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(lastBackup.DatabaseID)
	if err != nil {
		s.logger.Error("Failed to get backup config by database ID", "error", err)
		return 0
	}
	if !backupConfig.IsRetryIfFailed {
		return 0
	}
	maxFailedTriesCount := backupConfig.MaxFailedTriesCount
	lastBackups, err := s.backupRepository.FindByDatabaseIDWithLimit(
		lastBackup.DatabaseID,
		maxFailedTriesCount,
	)
	if err != nil {
		s.logger.Error("Failed to find last backups by database ID", "error", err)
		return 0
	}
	// Count failures directly instead of materializing a filtered slice just
	// to take its length.
	failedCount := 0
	for _, backup := range lastBackups {
		if backup.Status == BackupStatusFailed {
			failedCount++
		}
	}
	// NOTE(review): failures within the window need not be consecutive; a
	// mixed success/failure history still consumes the retry budget — confirm
	// this is the intended semantics.
	return maxFailedTriesCount - failedCount
}

View File

@@ -0,0 +1,344 @@
package backuping
import (
"context"
"databasus-backend/internal/config"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/storages"
workspaces_services "databasus-backend/internal/features/workspaces/services"
util_encryption "databasus-backend/internal/util/encryption"
"errors"
"fmt"
"log/slog"
"slices"
"strings"
"time"
"github.com/google/uuid"
)
// BackuperNode is a worker that executes backups assigned to it by the
// scheduler through the shared BackupNodesRegistry.
type BackuperNode struct {
	databaseService     *databases.DatabaseService
	fieldEncryptor      util_encryption.FieldEncryptor
	workspaceService    *workspaces_services.WorkspaceService
	backupRepository    *backups_core.BackupRepository
	backupConfigService *backups_config.BackupConfigService
	storageService      *storages.StorageService
	notificationSender  backups_core.NotificationSender
	backupCancelManager *backups_cancellation.BackupCancelManager
	nodesRegistry       *BackupNodesRegistry
	logger              *slog.Logger
	createBackupUseCase backups_core.CreateBackupUsecase
	// nodeID identifies this node in the registry.
	nodeID uuid.UUID
	// lastHeartbeat is the most recent heartbeat time; used by
	// IsBackuperRunning as a liveness signal.
	lastHeartbeat time.Time
}
// Run registers this node in the registry, subscribes to backup assignments,
// and heartbeats every 15 seconds until ctx is cancelled, at which point the
// node unsubscribes and unregisters itself. Registration or subscription
// failure at startup is fatal.
func (n *BackuperNode) Run(ctx context.Context) {
	n.lastHeartbeat = time.Now().UTC()
	throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
	backupNode := BackupNode{
		ID:            n.nodeID,
		ThroughputMBs: throughputMBs,
		LastHeartbeat: time.Now().UTC(),
	}
	// Initial registration: without it the scheduler never assigns work here.
	if err := n.nodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
		n.logger.Error("Failed to register node in registry", "error", err)
		panic(err)
	}
	// Runs the backup, then publishes completion so the scheduler can release
	// the assignment (published regardless of backup outcome).
	backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
		n.MakeBackup(backupID, isCallNotifier)
		if err := n.nodesRegistry.PublishBackupCompletion(n.nodeID.String(), backupID); err != nil {
			n.logger.Error(
				"Failed to publish backup completion",
				"error",
				err,
				"backupID",
				backupID,
			)
		}
	}
	if err := n.nodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID.String(), backupHandler); err != nil {
		n.logger.Error("Failed to subscribe to backup assignments", "error", err)
		panic(err)
	}
	defer func() {
		if err := n.nodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
			n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
		}
	}()
	ticker := time.NewTicker(15 * time.Second)
	defer ticker.Stop()
	n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
	for {
		select {
		case <-ctx.Done():
			n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
			if err := n.nodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
				n.logger.Error("Failed to unregister node from registry", "error", err)
			}
			return
		case <-ticker.C:
			n.sendHeartbeat(&backupNode)
		}
	}
}
// IsBackuperRunning reports whether this node has heartbeated within the
// last five minutes.
func (n *BackuperNode) IsBackuperRunning() bool {
	cutoff := time.Now().UTC().Add(-5 * time.Minute)
	return n.lastHeartbeat.After(cutoff)
}
// MakeBackup executes the backup identified by backupID: it loads the backup
// record, database, config, and storage, runs the create-backup use case with
// progress reporting and cancellation support, then persists the resulting
// status (completed / failed / canceled) and sends notifications. All errors
// are handled internally; the method never returns an error.
func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
	backup, err := n.backupRepository.FindByID(backupID)
	if err != nil {
		n.logger.Error("Failed to get backup by ID", "backupId", backupID, "error", err)
		return
	}
	databaseID := backup.DatabaseID
	database, err := n.databaseService.GetDatabaseByID(databaseID)
	if err != nil {
		n.logger.Error("Failed to get database by ID", "databaseId", databaseID, "error", err)
		return
	}
	backupConfig, err := n.backupConfigService.GetBackupConfigByDbId(databaseID)
	if err != nil {
		n.logger.Error("Failed to get backup config by database ID", "error", err)
		return
	}
	if backupConfig.StorageID == nil {
		n.logger.Error("Backup config storage ID is not defined")
		return
	}
	storage, err := n.storageService.GetStorageByID(*backupConfig.StorageID)
	if err != nil {
		n.logger.Error("Failed to get storage by ID", "error", err)
		return
	}
	start := time.Now().UTC()
	// Persists running size/duration so progress is visible while the backup
	// streams to storage.
	backupProgressListener := func(
		completedMBs float64,
	) {
		backup.BackupSizeMb = completedMBs
		backup.BackupDurationMs = time.Since(start).Milliseconds()
		if err := n.backupRepository.Save(backup); err != nil {
			n.logger.Error("Failed to update backup progress", "error", err)
		}
	}
	// Register the cancel func so a user-initiated cancellation can abort the
	// in-flight use case.
	ctx, cancel := context.WithCancel(context.Background())
	n.backupCancelManager.RegisterBackup(backup.ID, cancel)
	defer n.backupCancelManager.UnregisterBackup(backup.ID)
	backupMetadata, err := n.createBackupUseCase.Execute(
		ctx,
		backup.ID,
		backupConfig,
		database,
		storage,
		backupProgressListener,
	)
	if err != nil {
		errMsg := err.Error()
		// Check if backup was cancelled (not due to shutdown)
		// NOTE(review): cancellation is detected by substring matching on the
		// error text plus errors.Is — fragile if wrapped messages change.
		isCancelled := strings.Contains(errMsg, "backup cancelled") ||
			strings.Contains(errMsg, "context canceled") ||
			errors.Is(err, context.Canceled)
		isShutdown := strings.Contains(errMsg, "shutdown")
		if isCancelled && !isShutdown {
			backup.Status = backups_core.BackupStatusCanceled
			backup.BackupDurationMs = time.Since(start).Milliseconds()
			backup.BackupSizeMb = 0
			if err := n.backupRepository.Save(backup); err != nil {
				n.logger.Error("Failed to save cancelled backup", "error", err)
			}
			// Delete partial backup from storage
			storage, storageErr := n.storageService.GetStorageByID(backup.StorageID)
			if storageErr == nil {
				if deleteErr := storage.DeleteFile(n.fieldEncryptor, backup.ID); deleteErr != nil {
					n.logger.Error(
						"Failed to delete partial backup file",
						"backupId",
						backup.ID,
						"error",
						deleteErr,
					)
				}
			}
			return
		}
		// Genuine failure (or shutdown): persist failed status and notify.
		// NOTE(review): isCallNotifier is not consulted on this path — the
		// failure notification is always sent; confirm that is intended.
		backup.FailMessage = &errMsg
		backup.Status = backups_core.BackupStatusFailed
		backup.BackupDurationMs = time.Since(start).Milliseconds()
		backup.BackupSizeMb = 0
		if updateErr := n.databaseService.SetBackupError(databaseID, errMsg); updateErr != nil {
			n.logger.Error(
				"Failed to update database last backup time",
				"databaseId",
				databaseID,
				"error",
				updateErr,
			)
		}
		if err := n.backupRepository.Save(backup); err != nil {
			n.logger.Error("Failed to save backup", "error", err)
		}
		n.SendBackupNotification(
			backupConfig,
			backup,
			backups_config.NotificationBackupFailed,
			&errMsg,
		)
		return
	}
	backup.Status = backups_core.BackupStatusCompleted
	backup.BackupDurationMs = time.Since(start).Milliseconds()
	// Update backup with encryption metadata if provided
	if backupMetadata != nil {
		backup.EncryptionSalt = backupMetadata.EncryptionSalt
		backup.EncryptionIV = backupMetadata.EncryptionIV
		backup.Encryption = backupMetadata.Encryption
	}
	if err := n.backupRepository.Save(backup); err != nil {
		n.logger.Error("Failed to save backup", "error", err)
		return
	}
	// Update database last backup time
	now := time.Now().UTC()
	if updateErr := n.databaseService.SetLastBackupTime(databaseID, now); updateErr != nil {
		n.logger.Error(
			"Failed to update database last backup time",
			"databaseId",
			databaseID,
			"error",
			updateErr,
		)
	}
	// NOTE(review): Status was just set to Completed above, so this condition
	// is always false and the success notification is sent even when
	// isCallNotifier is false — likely the intended check was
	// `!isCallNotifier { return }`; confirm.
	if backup.Status != backups_core.BackupStatusCompleted && !isCallNotifier {
		return
	}
	n.SendBackupNotification(
		backupConfig,
		backup,
		backups_config.NotificationBackupSuccess,
		nil,
	)
}
// SendBackupNotification notifies every notifier attached to the backup's
// database about a backup success or failure, provided the backup config has
// opted in to the given notification type. Lookup errors abort silently.
func (n *BackuperNode) SendBackupNotification(
	backupConfig *backups_config.BackupConfig,
	backup *backups_core.Backup,
	notificationType backups_config.BackupNotificationType,
	errorMessage *string,
) {
	// The opt-in check does not depend on the notifier — bail out once
	// instead of re-checking it on every loop iteration.
	if !slices.Contains(backupConfig.SendNotificationsOn, notificationType) {
		return
	}
	database, err := n.databaseService.GetDatabaseByID(backupConfig.DatabaseID)
	if err != nil {
		return
	}
	workspace, err := n.workspaceService.GetWorkspaceByID(*database.WorkspaceID)
	if err != nil {
		return
	}
	// Title and message are identical for every notifier, so build them once.
	title := ""
	switch notificationType {
	case backups_config.NotificationBackupFailed:
		title = fmt.Sprintf(
			"❌ Backup failed for database \"%s\" (workspace \"%s\")",
			database.Name,
			workspace.Name,
		)
	case backups_config.NotificationBackupSuccess:
		title = fmt.Sprintf(
			"✅ Backup completed for database \"%s\" (workspace \"%s\")",
			database.Name,
			workspace.Name,
		)
	}
	message := ""
	if errorMessage != nil {
		message = *errorMessage
	} else {
		// Report size in MB below 1024 MB, otherwise in GB.
		var sizeStr string
		if backup.BackupSizeMb < 1024 {
			sizeStr = fmt.Sprintf("%.2f MB", backup.BackupSizeMb)
		} else {
			sizeGB := backup.BackupSizeMb / 1024
			sizeStr = fmt.Sprintf("%.2f GB", sizeGB)
		}
		// Format duration as "Xm Ys" (milliseconds are truncated).
		totalMs := backup.BackupDurationMs
		minutes := totalMs / (1000 * 60)
		seconds := (totalMs % (1000 * 60)) / 1000
		durationStr := fmt.Sprintf("%dm %ds", minutes, seconds)
		message = fmt.Sprintf(
			"Backup completed successfully in %s.\nCompressed backup size: %s",
			durationStr,
			sizeStr,
		)
	}
	for _, notifier := range database.Notifiers {
		n.notificationSender.SendNotification(
			&notifier,
			title,
			message,
		)
	}
}
// sendHeartbeat refreshes the node's liveness both locally (for
// IsBackuperRunning) and in the shared registry.
func (n *BackuperNode) sendHeartbeat(backupNode *BackupNode) {
	// Capture a single timestamp so the local heartbeat, the node record,
	// and the registry entry all agree exactly.
	now := time.Now().UTC()
	n.lastHeartbeat = now
	backupNode.LastHeartbeat = now
	if err := n.nodesRegistry.HearthbeatNodeInRegistry(now, *backupNode); err != nil {
		n.logger.Error("Failed to send heartbeat", "error", err)
	}
}

View File

@@ -1,4 +1,4 @@
package backups
package backuping
import (
"context"
@@ -8,17 +8,15 @@ import (
"time"
common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_services "databasus-backend/internal/features/workspaces/services"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
cache_utils "databasus-backend/internal/util/cache"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
@@ -26,6 +24,7 @@ import (
)
func Test_BackupExecuted_NotificationSent(t *testing.T) {
cache_utils.ClearAllCache()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
@@ -50,23 +49,19 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
t.Run("BackupFailed_FailNotificationSent", func(t *testing.T) {
mockNotificationSender := &MockNotificationSender{}
backupService := &BackupService{
databases.GetDatabaseService(),
storages.GetStorageService(),
backupRepository,
notifiers.GetNotifierService(),
mockNotificationSender,
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
&CreateFailedBackupUsecase{},
logger.GetLogger(),
[]BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
nil,
NewBackupContextManager(),
nil,
backuperNode := CreateTestBackuperNode()
backuperNode.notificationSender = mockNotificationSender
backuperNode.createBackupUseCase = &CreateFailedBackupUsecase{}
// Create a backup record directly that will be looked up by MakeBackup
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err := backupRepository.Save(backup)
assert.NoError(t, err)
// Set up expectations
mockNotificationSender.On("SendNotification",
@@ -79,7 +74,7 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
}),
).Once()
backupService.MakeBackup(database.ID, true)
backuperNode.MakeBackup(backup.ID, true)
// Verify all expectations were met
mockNotificationSender.AssertExpectations(t)
@@ -87,6 +82,19 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
t.Run("BackupSuccess_SuccessNotificationSent", func(t *testing.T) {
mockNotificationSender := &MockNotificationSender{}
backuperNode := CreateTestBackuperNode()
backuperNode.notificationSender = mockNotificationSender
backuperNode.createBackupUseCase = &CreateSuccessBackupUsecase{}
// Create a backup record directly that will be looked up by MakeBackup
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err := backupRepository.Save(backup)
assert.NoError(t, err)
// Set up expectations
mockNotificationSender.On("SendNotification",
@@ -99,25 +107,7 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
}),
).Once()
backupService := &BackupService{
databases.GetDatabaseService(),
storages.GetStorageService(),
backupRepository,
notifiers.GetNotifierService(),
mockNotificationSender,
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
&CreateSuccessBackupUsecase{},
logger.GetLogger(),
[]BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
nil,
NewBackupContextManager(),
nil,
}
backupService.MakeBackup(database.ID, true)
backuperNode.MakeBackup(backup.ID, true)
// Verify all expectations were met
mockNotificationSender.AssertExpectations(t)
@@ -125,23 +115,19 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
t.Run("BackupSuccess_VerifyNotificationContent", func(t *testing.T) {
mockNotificationSender := &MockNotificationSender{}
backupService := &BackupService{
databases.GetDatabaseService(),
storages.GetStorageService(),
backupRepository,
notifiers.GetNotifierService(),
mockNotificationSender,
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
&CreateSuccessBackupUsecase{},
logger.GetLogger(),
[]BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
nil,
NewBackupContextManager(),
nil,
backuperNode := CreateTestBackuperNode()
backuperNode.notificationSender = mockNotificationSender
backuperNode.createBackupUseCase = &CreateSuccessBackupUsecase{}
// Create a backup record directly that will be looked up by MakeBackup
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err := backupRepository.Save(backup)
assert.NoError(t, err)
// capture arguments
var capturedNotifier *notifiers.Notifier
@@ -158,7 +144,7 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
capturedMessage = args.Get(2).(string)
}).Once()
backupService.MakeBackup(database.ID, true)
backuperNode.MakeBackup(backup.ID, true)
// Verify expectations were met
mockNotificationSender.AssertExpectations(t)

View File

@@ -0,0 +1,77 @@
package backuping
import (
"databasus-backend/internal/config"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/usecases"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
workspaces_services "databasus-backend/internal/features/workspaces/services"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
"time"
"github.com/google/uuid"
)
var backupRepository = &backups_core.BackupRepository{}
var backupCancelManager = backups_cancellation.GetBackupCancelManager()
var nodesRegistry = &BackupNodesRegistry{
cache_utils.GetValkeyClient(),
logger.GetLogger(),
cache_utils.DefaultCacheTimeout,
cache_utils.NewPubSubManager(),
cache_utils.NewPubSubManager(),
}
// getNodeID parses this process's node ID from the environment config,
// panicking at startup if it is not a valid UUID.
func getNodeID() uuid.UUID {
	parsed, err := uuid.Parse(config.GetEnv().NodeID)
	if err != nil {
		logger.GetLogger().Error("Failed to parse node ID from config", "error", err)
		panic(err)
	}
	return parsed
}
// backuperNode is the package-level worker singleton. Fields are named
// explicitly so this literal does not silently break (or misassign
// dependencies) if the BackuperNode struct layout changes.
var backuperNode = &BackuperNode{
	databaseService:     databases.GetDatabaseService(),
	fieldEncryptor:      encryption.GetFieldEncryptor(),
	workspaceService:    workspaces_services.GetWorkspaceService(),
	backupRepository:    backupRepository,
	backupConfigService: backups_config.GetBackupConfigService(),
	storageService:      storages.GetStorageService(),
	notificationSender:  notifiers.GetNotifierService(),
	backupCancelManager: backupCancelManager,
	nodesRegistry:       nodesRegistry,
	logger:              logger.GetLogger(),
	createBackupUseCase: usecases.GetCreateBackupUsecase(),
	nodeID:              getNodeID(),
	lastHeartbeat:       time.Time{},
}
var backupsScheduler = &BackupsScheduler{
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
backupCancelManager,
nodesRegistry,
time.Now().UTC(),
logger.GetLogger(),
make(map[uuid.UUID]BackupToNodeRelation),
backuperNode,
}
// GetBackupsScheduler returns the package-level scheduler singleton.
func GetBackupsScheduler() *BackupsScheduler {
	return backupsScheduler
}

// GetBackuperNode returns the package-level backup worker singleton.
func GetBackuperNode() *BackuperNode {
	return backuperNode
}

View File

@@ -0,0 +1,34 @@
package backuping
import (
"time"
"github.com/google/uuid"
)
// BackupNode describes a worker node as stored in the registry.
type BackupNode struct {
	ID uuid.UUID `json:"id"`
	// ThroughputMBs is the node's configured network throughput in MB/s.
	ThroughputMBs int       `json:"throughputMBs"`
	LastHeartbeat time.Time `json:"lastHeartbeat"`
}

// BackupNodeStats is a node's current load snapshot.
type BackupNodeStats struct {
	ID            uuid.UUID `json:"id"`
	ActiveBackups int       `json:"activeBackups"`
}

// BackupSubmitMessage is published by the scheduler to assign a backup to a
// specific node.
type BackupSubmitMessage struct {
	NodeID         string `json:"nodeId"`
	BackupID       string `json:"backupId"`
	IsCallNotifier bool   `json:"isCallNotifier"`
}

// BackupCompletionMessage is published by a node when a backup finishes.
type BackupCompletionMessage struct {
	NodeID   string `json:"nodeId"`
	BackupID string `json:"backupId"`
}

// BackupToNodeRelation tracks which backups are currently assigned to a node.
type BackupToNodeRelation struct {
	NodeID     uuid.UUID   `json:"nodeId"`
	BackupsIDs []uuid.UUID `json:"backupsIds"`
}

View File

@@ -1,4 +1,4 @@
package backups
package backuping
import (
"databasus-backend/internal/features/notifiers"

View File

@@ -0,0 +1,448 @@
package backuping
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"time"
cache_utils "databasus-backend/internal/util/cache"
"github.com/google/uuid"
"github.com/valkey-io/valkey-go"
)
const (
nodeInfoKeyPrefix = "node:"
nodeInfoKeySuffix = ":info"
nodeActiveBackupsPrefix = "node:"
nodeActiveBackupsSuffix = ":active_backups"
backupSubmitChannel = "backup:submit"
backupCompletionChannel = "backup:completion"
)
// BackupNodesRegistry helps to sync backups scheduler and backup nodes.
// Features:
// - Track node availability and load level
// - Assign from scheduler to node backups needed to be processed
// - Notify scheduler from node about backup completion
type BackupNodesRegistry struct {
	client valkey.Client
	logger *slog.Logger
	// timeout bounds each Valkey round-trip.
	timeout time.Duration
	// pubsubBackups carries backup assignments (scheduler -> node);
	// pubsubCompletions carries completion events (node -> scheduler).
	pubsubBackups     *cache_utils.PubSubManager
	pubsubCompletions *cache_utils.PubSubManager
}
// GetAvailableNodes scans Valkey for every registered node info key and
// returns the decoded nodes. Keys that fail to unmarshal are logged and
// skipped rather than failing the whole call.
func (r *BackupNodesRegistry) GetAvailableNodes() ([]BackupNode, error) {
	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
	defer cancel()
	var allKeys []string
	cursor := uint64(0)
	pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix
	// SCAN iterates incrementally; cursor 0 signals the end of the keyspace.
	for {
		result := r.client.Do(
			ctx,
			r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
		)
		if result.Error() != nil {
			return nil, fmt.Errorf("failed to scan node keys: %w", result.Error())
		}
		scanResult, err := result.AsScanEntry()
		if err != nil {
			return nil, fmt.Errorf("failed to parse scan result: %w", err)
		}
		allKeys = append(allKeys, scanResult.Elements...)
		cursor = scanResult.Cursor
		if cursor == 0 {
			break
		}
	}
	if len(allKeys) == 0 {
		return []BackupNode{}, nil
	}
	// Fetch all values in one pipelined round-trip.
	keyDataMap, err := r.pipelineGetKeys(allKeys)
	if err != nil {
		return nil, fmt.Errorf("failed to pipeline get node keys: %w", err)
	}
	var nodes []BackupNode
	for key, data := range keyDataMap {
		var node BackupNode
		if err := json.Unmarshal(data, &node); err != nil {
			r.logger.Warn("Failed to unmarshal node data", "key", key, "error", err)
			continue
		}
		nodes = append(nodes, node)
	}
	return nodes, nil
}
// GetBackupNodesStats scans Valkey for per-node active-backup counters and
// returns one stats entry per node. Unparseable counters are logged and
// skipped.
func (r *BackupNodesRegistry) GetBackupNodesStats() ([]BackupNodeStats, error) {
	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
	defer cancel()
	var allKeys []string
	cursor := uint64(0)
	pattern := nodeActiveBackupsPrefix + "*" + nodeActiveBackupsSuffix
	// SCAN iterates incrementally; cursor 0 signals the end of the keyspace.
	for {
		result := r.client.Do(
			ctx,
			r.client.B().Scan().Cursor(cursor).Match(pattern).Count(100).Build(),
		)
		if result.Error() != nil {
			return nil, fmt.Errorf("failed to scan active backups keys: %w", result.Error())
		}
		scanResult, err := result.AsScanEntry()
		if err != nil {
			return nil, fmt.Errorf("failed to parse scan result: %w", err)
		}
		allKeys = append(allKeys, scanResult.Elements...)
		cursor = scanResult.Cursor
		if cursor == 0 {
			break
		}
	}
	if len(allKeys) == 0 {
		return []BackupNodeStats{}, nil
	}
	// Fetch all counters in one pipelined round-trip.
	keyDataMap, err := r.pipelineGetKeys(allKeys)
	if err != nil {
		return nil, fmt.Errorf("failed to pipeline get active backups keys: %w", err)
	}
	var stats []BackupNodeStats
	for key, data := range keyDataMap {
		// The node ID is embedded in the key between the prefix and suffix.
		nodeID := r.extractNodeIDFromKey(key, nodeActiveBackupsPrefix, nodeActiveBackupsSuffix)
		count, err := r.parseIntFromBytes(data)
		if err != nil {
			r.logger.Warn("Failed to parse active backups count", "key", key, "error", err)
			continue
		}
		stat := BackupNodeStats{
			ID:            nodeID,
			ActiveBackups: int(count),
		}
		stats = append(stats, stat)
	}
	return stats, nil
}
// IncrementBackupsInProgress bumps the active-backups counter for the node.
func (r *BackupNodesRegistry) IncrementBackupsInProgress(nodeID string) error {
	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
	defer cancel()
	counterKey := nodeActiveBackupsPrefix + nodeID + nodeActiveBackupsSuffix
	if reply := r.client.Do(ctx, r.client.B().Incr().Key(counterKey).Build()); reply.Error() != nil {
		return fmt.Errorf(
			"failed to increment backups in progress for node %s: %w",
			nodeID,
			reply.Error(),
		)
	}
	return nil
}
// DecrementBackupsInProgress decrements the node's active-backups counter.
// If the counter drops below zero (more decrements than increments), it is
// reset to zero so scheduling decisions never see a negative load.
func (r *BackupNodesRegistry) DecrementBackupsInProgress(nodeID string) error {
	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
	defer cancel()
	key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID, nodeActiveBackupsSuffix)
	result := r.client.Do(ctx, r.client.B().Decr().Key(key).Build())
	if result.Error() != nil {
		return fmt.Errorf(
			"failed to decrement backups in progress for node %s: %w",
			nodeID,
			result.Error(),
		)
	}
	newValue, err := result.AsInt64()
	if err != nil {
		return fmt.Errorf("failed to parse decremented value for node %s: %w", nodeID, err)
	}
	if newValue < 0 {
		setCtx, setCancel := context.WithTimeout(context.Background(), r.timeout)
		resetResult := r.client.Do(setCtx, r.client.B().Set().Key(key).Value("0").Build())
		setCancel()
		// The reset was previously fire-and-forget; surface failures so a
		// counter stuck below zero is at least visible in logs.
		if resetResult.Error() != nil {
			r.logger.Warn(
				"Failed to reset negative active backups counter",
				"nodeID", nodeID,
				"error", resetResult.Error(),
			)
		}
		r.logger.Warn("Active backups counter went below 0, reset to 0", "nodeID", nodeID)
	}
	return nil
}
// HearthbeatNodeInRegistry upserts the node's registry entry under its
// node-info key, stamping LastHeartbeat with the provided time before
// serializing the node to JSON.
// NOTE(review): this SET carries no expiration option, so entries persist
// until UnregisterNodeFromRegistry is called — confirm stale-node reaping
// happens elsewhere if a node dies without unregistering.
func (r *BackupNodesRegistry) HearthbeatNodeInRegistry(now time.Time, backupNode BackupNode) error {
	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
	defer cancel()
	// Stamp the heartbeat on the local copy; the caller's value is untouched.
	backupNode.LastHeartbeat = now
	data, err := json.Marshal(backupNode)
	if err != nil {
		return fmt.Errorf("failed to marshal backup node: %w", err)
	}
	key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, backupNode.ID.String(), nodeInfoKeySuffix)
	result := r.client.Do(
		ctx,
		r.client.B().Set().Key(key).Value(string(data)).Build(),
	)
	if result.Error() != nil {
		return fmt.Errorf("failed to register node %s: %w", backupNode.ID, result.Error())
	}
	return nil
}
// UnregisterNodeFromRegistry deletes both the node-info entry and the
// active-backups counter for the given node with a single DEL.
func (r *BackupNodesRegistry) UnregisterNodeFromRegistry(backupNode BackupNode) error {
	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
	defer cancel()
	nodeID := backupNode.ID.String()
	infoKey := nodeInfoKeyPrefix + nodeID + nodeInfoKeySuffix
	counterKey := nodeActiveBackupsPrefix + nodeID + nodeActiveBackupsSuffix
	reply := r.client.Do(
		ctx,
		r.client.B().Del().Key(infoKey, counterKey).Build(),
	)
	if reply.Error() != nil {
		return fmt.Errorf("failed to unregister node %s: %w", backupNode.ID, reply.Error())
	}
	r.logger.Info("Unregistered node from registry", "nodeID", backupNode.ID)
	return nil
}
// AssignBackupToNode publishes a backup-submit message addressed to
// targetNodeID so that node picks up the backup.
func (r *BackupNodesRegistry) AssignBackupToNode(
	targetNodeID string,
	backupID uuid.UUID,
	isCallNotifier bool,
) error {
	payload, err := json.Marshal(BackupSubmitMessage{
		NodeID:         targetNodeID,
		BackupID:       backupID.String(),
		IsCallNotifier: isCallNotifier,
	})
	if err != nil {
		return fmt.Errorf("failed to marshal backup submit message: %w", err)
	}
	if err := r.pubsubBackups.Publish(context.Background(), backupSubmitChannel, string(payload)); err != nil {
		return fmt.Errorf("failed to publish backup submit message: %w", err)
	}
	return nil
}
// SubscribeNodeForBackupsAssignment listens on the backup-submit channel and
// invokes handler for every message addressed to nodeID. Messages for other
// nodes are silently skipped; malformed JSON and unparseable backup IDs are
// logged and dropped.
func (r *BackupNodesRegistry) SubscribeNodeForBackupsAssignment(
	nodeID string,
	handler func(backupID uuid.UUID, isCallNotifier bool),
) error {
	onMessage := func(raw string) {
		var msg BackupSubmitMessage
		if err := json.Unmarshal([]byte(raw), &msg); err != nil {
			r.logger.Warn("Failed to unmarshal backup submit message", "error", err)
			return
		}
		// Every subscriber sees every submit message; only act on our own.
		if msg.NodeID != nodeID {
			return
		}
		backupID, err := uuid.Parse(msg.BackupID)
		if err != nil {
			r.logger.Warn(
				"Failed to parse backup ID from message",
				"backupId",
				msg.BackupID,
				"error",
				err,
			)
			return
		}
		handler(backupID, msg.IsCallNotifier)
	}
	if err := r.pubsubBackups.Subscribe(context.Background(), backupSubmitChannel, onMessage); err != nil {
		return fmt.Errorf("failed to subscribe to backup submit channel: %w", err)
	}
	r.logger.Info("Subscribed to backup submit channel", "nodeID", nodeID)
	return nil
}
// UnsubscribeNodeForBackupsAssignments closes the backup-submit subscription.
func (r *BackupNodesRegistry) UnsubscribeNodeForBackupsAssignments() error {
	if err := r.pubsubBackups.Close(); err != nil {
		return fmt.Errorf("failed to unsubscribe from backup submit channel: %w", err)
	}
	r.logger.Info("Unsubscribed from backup submit channel")
	return nil
}
// PublishBackupCompletion broadcasts that nodeID finished the given backup.
func (r *BackupNodesRegistry) PublishBackupCompletion(nodeID string, backupID uuid.UUID) error {
	payload, err := json.Marshal(BackupCompletionMessage{
		NodeID:   nodeID,
		BackupID: backupID.String(),
	})
	if err != nil {
		return fmt.Errorf("failed to marshal backup completion message: %w", err)
	}
	if err := r.pubsubCompletions.Publish(context.Background(), backupCompletionChannel, string(payload)); err != nil {
		return fmt.Errorf("failed to publish backup completion message: %w", err)
	}
	return nil
}
// SubscribeForBackupsCompletions listens on the backup-completion channel and
// invokes handler with the reporting node's ID and the completed backup's ID.
// Malformed JSON and unparseable backup IDs are logged and dropped.
func (r *BackupNodesRegistry) SubscribeForBackupsCompletions(
	handler func(nodeID string, backupID uuid.UUID),
) error {
	onMessage := func(raw string) {
		var msg BackupCompletionMessage
		if err := json.Unmarshal([]byte(raw), &msg); err != nil {
			r.logger.Warn("Failed to unmarshal backup completion message", "error", err)
			return
		}
		backupID, err := uuid.Parse(msg.BackupID)
		if err != nil {
			r.logger.Warn(
				"Failed to parse backup ID from completion message",
				"backupId",
				msg.BackupID,
				"error",
				err,
			)
			return
		}
		handler(msg.NodeID, backupID)
	}
	if err := r.pubsubCompletions.Subscribe(context.Background(), backupCompletionChannel, onMessage); err != nil {
		return fmt.Errorf("failed to subscribe to backup completion channel: %w", err)
	}
	r.logger.Info("Subscribed to backup completion channel")
	return nil
}
// UnsubscribeForBackupsCompletions closes the backup-completion subscription.
func (r *BackupNodesRegistry) UnsubscribeForBackupsCompletions() error {
	if err := r.pubsubCompletions.Close(); err != nil {
		return fmt.Errorf("failed to unsubscribe from backup completion channel: %w", err)
	}
	r.logger.Info("Unsubscribed from backup completion channel")
	return nil
}
// extractNodeIDFromKey strips prefix and suffix from a registry key and parses
// the remainder as a UUID; uuid.Nil is returned (with a warning) on failure.
func (r *BackupNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uuid.UUID {
	raw := strings.TrimSuffix(strings.TrimPrefix(key, prefix), suffix)
	id, err := uuid.Parse(raw)
	if err != nil {
		r.logger.Warn("Failed to parse node ID from key", "key", key, "error", err)
		return uuid.Nil
	}
	return id
}
// pipelineGetKeys GETs all keys in one DoMulti pipeline and returns a map of
// key to raw value. Keys that error or cannot be read as bytes are logged and
// omitted from the result rather than failing the whole batch.
func (r *BackupNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, error) {
	out := make(map[string][]byte, len(keys))
	if len(keys) == 0 {
		return out, nil
	}
	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
	defer cancel()
	cmds := make([]valkey.Completed, len(keys))
	for i, key := range keys {
		cmds[i] = r.client.B().Get().Key(key).Build()
	}
	for i, reply := range r.client.DoMulti(ctx, cmds...) {
		if reply.Error() != nil {
			r.logger.Warn("Failed to get key in pipeline", "key", keys[i], "error", reply.Error())
			continue
		}
		value, err := reply.AsBytes()
		if err != nil {
			r.logger.Warn("Failed to parse key data in pipeline", "key", keys[i], "error", err)
			continue
		}
		out[keys[i]] = value
	}
	return out, nil
}
// parseIntFromBytes decodes a decimal integer from a raw cache value
// (the active-backups counters are stored as plain numeric strings).
func (r *BackupNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
	var value int64
	if _, err := fmt.Sscanf(string(data), "%d", &value); err != nil {
		return 0, fmt.Errorf("failed to parse integer from bytes: %w", err)
	}
	return value, nil
}

View File

@@ -0,0 +1,904 @@
package backuping
import (
"context"
"testing"
"time"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/logger"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
// Verifies that a heartbeat registers the node and that it becomes visible
// via GetAvailableNodes with its throughput preserved.
// NOTE(review): the test name mentions a TTL, but nothing here asserts an
// expiration, and HearthbeatNodeInRegistry sets the key without one — confirm
// whether TTL behavior was intended.
func Test_HearthbeatNodeInRegistry_RegistersNodeWithTTL(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	defer cleanupTestNode(registry, node)
	err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
	assert.NoError(t, err)
	nodes, err := registry.GetAvailableNodes()
	assert.NoError(t, err)
	assert.Len(t, nodes, 1)
	assert.Equal(t, node.ID, nodes[0].ID)
	assert.Equal(t, node.ThroughputMBs, nodes[0].ThroughputMBs)
}
// Verifies that unregistering removes both the node-info entry and the
// active-backups counter from the registry.
func Test_UnregisterNodeFromRegistry_RemovesNodeAndCounter(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()

	assert.NoError(t, registry.HearthbeatNodeInRegistry(time.Now().UTC(), node))
	assert.NoError(t, registry.IncrementBackupsInProgress(node.ID.String()))
	assert.NoError(t, registry.UnregisterNodeFromRegistry(node))

	nodes, err := registry.GetAvailableNodes()
	assert.NoError(t, err)
	assert.Empty(t, nodes)

	stats, err := registry.GetBackupNodesStats()
	assert.NoError(t, err)
	assert.Empty(t, stats)
}
// Verifies that every registered node is returned by GetAvailableNodes.
func Test_GetAvailableNodes_ReturnsAllRegisteredNodes(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()

	registered := []BackupNode{createTestBackupNode(), createTestBackupNode(), createTestBackupNode()}
	for _, node := range registered {
		defer cleanupTestNode(registry, node)
		assert.NoError(t, registry.HearthbeatNodeInRegistry(time.Now().UTC(), node))
	}

	nodes, err := registry.GetAvailableNodes()
	assert.NoError(t, err)
	assert.Len(t, nodes, len(registered))

	seen := make(map[uuid.UUID]bool, len(nodes))
	for _, node := range nodes {
		seen[node.ID] = true
	}
	for _, node := range registered {
		assert.True(t, seen[node.ID])
	}
}
// Verifies the empty-registry case yields a non-nil empty slice, not nil.
func Test_GetAvailableNodes_WhenNoNodesExist_ReturnsEmptySlice(t *testing.T) {
	cache_utils.ClearAllCache()

	nodes, err := createTestRegistry().GetAvailableNodes()

	assert.NoError(t, err)
	assert.NotNil(t, nodes)
	assert.Empty(t, nodes)
}
// Verifies that each increment raises the node's active-backups count by one.
func Test_IncrementBackupsInProgress_IncrementsCounter(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	defer cleanupTestNode(registry, node)
	assert.NoError(t, registry.HearthbeatNodeInRegistry(time.Now().UTC(), node))

	for expected := 1; expected <= 2; expected++ {
		assert.NoError(t, registry.IncrementBackupsInProgress(node.ID.String()))

		stats, err := registry.GetBackupNodesStats()
		assert.NoError(t, err)
		assert.Len(t, stats, 1)
		assert.Equal(t, node.ID, stats[0].ID)
		assert.Equal(t, expected, stats[0].ActiveBackups)
	}
}
// Verifies that decrements walk the counter back down one step at a time.
func Test_DecrementBackupsInProgress_DecrementsCounter(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	defer cleanupTestNode(registry, node)
	assert.NoError(t, registry.HearthbeatNodeInRegistry(time.Now().UTC(), node))

	for i := 0; i < 3; i++ {
		assert.NoError(t, registry.IncrementBackupsInProgress(node.ID.String()))
	}
	stats, err := registry.GetBackupNodesStats()
	assert.NoError(t, err)
	assert.Equal(t, 3, stats[0].ActiveBackups)

	for expected := 2; expected >= 1; expected-- {
		assert.NoError(t, registry.DecrementBackupsInProgress(node.ID.String()))
		stats, err = registry.GetBackupNodesStats()
		assert.NoError(t, err)
		assert.Equal(t, expected, stats[0].ActiveBackups)
	}
}
// Verifies the floor-at-zero behavior: decrementing a counter that was never
// incremented leaves it at 0 rather than -1.
func Test_DecrementBackupsInProgress_WhenNegative_ResetsToZero(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	defer cleanupTestNode(registry, node)

	assert.NoError(t, registry.HearthbeatNodeInRegistry(time.Now().UTC(), node))
	// Decrement with no prior increment drives the raw counter negative.
	assert.NoError(t, registry.DecrementBackupsInProgress(node.ID.String()))

	stats, err := registry.GetBackupNodesStats()
	assert.NoError(t, err)
	assert.Len(t, stats, 1)
	assert.Equal(t, 0, stats[0].ActiveBackups)
}
// Verifies that stats are reported per node with the correct counts.
func Test_GetBackupNodesStats_ReturnsStatsForAllNodes(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()

	nodes := []BackupNode{createTestBackupNode(), createTestBackupNode(), createTestBackupNode()}
	for i, node := range nodes {
		defer cleanupTestNode(registry, node)
		assert.NoError(t, registry.HearthbeatNodeInRegistry(time.Now().UTC(), node))
		// Node i carries i+1 active backups.
		for j := 0; j <= i; j++ {
			assert.NoError(t, registry.IncrementBackupsInProgress(node.ID.String()))
		}
	}

	stats, err := registry.GetBackupNodesStats()
	assert.NoError(t, err)
	assert.Len(t, stats, 3)

	counts := make(map[uuid.UUID]int, len(stats))
	for _, stat := range stats {
		counts[stat.ID] = stat.ActiveBackups
	}
	for i, node := range nodes {
		assert.Equal(t, i+1, counts[node.ID])
	}
}
// Verifies the no-counters case yields a non-nil empty slice, not nil.
func Test_GetBackupNodesStats_WhenNoStats_ReturnsEmptySlice(t *testing.T) {
	cache_utils.ClearAllCache()

	stats, err := createTestRegistry().GetBackupNodesStats()

	assert.NoError(t, err)
	assert.NotNil(t, stats)
	assert.Empty(t, stats)
}
// Verifies that per-node fields (ThroughputMBs) survive the JSON round-trip
// through the registry for several nodes at once.
func Test_MultipleNodes_RegisteredAndQueriedCorrectly(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()

	throughputs := []int{50, 100, 150}
	registered := make([]BackupNode, 0, len(throughputs))
	for _, mbs := range throughputs {
		node := createTestBackupNode()
		node.ThroughputMBs = mbs
		defer cleanupTestNode(registry, node)
		assert.NoError(t, registry.HearthbeatNodeInRegistry(time.Now().UTC(), node))
		registered = append(registered, node)
	}

	nodes, err := registry.GetAvailableNodes()
	assert.NoError(t, err)
	assert.Len(t, nodes, len(registered))

	byID := make(map[uuid.UUID]BackupNode, len(nodes))
	for _, node := range nodes {
		byID[node.ID] = node
	}
	for i, node := range registered {
		assert.Equal(t, throughputs[i], byID[node.ID].ThroughputMBs)
	}
}
// Verifies that increments and decrements touch only the targeted node's
// counter, leaving other nodes' counters unchanged.
func Test_BackupCounters_TrackedSeparatelyPerNode(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node1 := createTestBackupNode()
	node2 := createTestBackupNode()
	defer cleanupTestNode(registry, node1)
	defer cleanupTestNode(registry, node2)
	assert.NoError(t, registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1))
	assert.NoError(t, registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2))

	// node1 gets two active backups, node2 gets one.
	assert.NoError(t, registry.IncrementBackupsInProgress(node1.ID.String()))
	assert.NoError(t, registry.IncrementBackupsInProgress(node1.ID.String()))
	assert.NoError(t, registry.IncrementBackupsInProgress(node2.ID.String()))

	countsByID := func() map[uuid.UUID]int {
		stats, err := registry.GetBackupNodesStats()
		assert.NoError(t, err)
		counts := make(map[uuid.UUID]int, len(stats))
		for _, stat := range stats {
			counts[stat.ID] = stat.ActiveBackups
		}
		return counts
	}

	counts := countsByID()
	assert.Len(t, counts, 2)
	assert.Equal(t, 2, counts[node1.ID])
	assert.Equal(t, 1, counts[node2.ID])

	// Decrementing node1 must not affect node2.
	assert.NoError(t, registry.DecrementBackupsInProgress(node1.ID.String()))
	counts = countsByID()
	assert.Equal(t, 1, counts[node1.ID])
	assert.Equal(t, 1, counts[node2.ID])
}
// Verifies that GetAvailableNodes tolerates corrupt registry entries: an
// entry holding non-JSON data is skipped (with a warning) while valid nodes
// are still returned.
func Test_GetAvailableNodes_SkipsInvalidJsonData(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	defer cleanupTestNode(registry, node)
	err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
	assert.NoError(t, err)
	ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
	defer cancel()
	// Plant a node-info key whose value is not valid JSON.
	invalidKey := nodeInfoKeyPrefix + uuid.New().String() + nodeInfoKeySuffix
	registry.client.Do(
		ctx,
		registry.client.B().Set().Key(invalidKey).Value("invalid json data").Build(),
	)
	defer func() {
		cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), registry.timeout)
		defer cleanupCancel()
		registry.client.Do(cleanupCtx, registry.client.B().Del().Key(invalidKey).Build())
	}()
	nodes, err := registry.GetAvailableNodes()
	assert.NoError(t, err)
	// Only the valid node survives deserialization.
	assert.Len(t, nodes, 1)
	assert.Equal(t, node.ID, nodes[0].ID)
}
// Verifies that an empty key list short-circuits to a non-nil empty map.
func Test_PipelineGetKeys_HandlesEmptyKeysList(t *testing.T) {
	cache_utils.ClearAllCache()

	keyDataMap, err := createTestRegistry().pipelineGetKeys([]string{})

	assert.NoError(t, err)
	assert.NotNil(t, keyDataMap)
	assert.Empty(t, keyDataMap)
}
// Verifies that a heartbeat overwrites the stored LastHeartbeat timestamp.
func Test_HearthbeatNodeInRegistry_UpdatesLastHeartbeat(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	defer cleanupTestNode(registry, node)

	before := node.LastHeartbeat
	// Ensure the new heartbeat timestamp is strictly later than the fixture's.
	time.Sleep(10 * time.Millisecond)
	assert.NoError(t, registry.HearthbeatNodeInRegistry(time.Now().UTC(), node))

	nodes, err := registry.GetAvailableNodes()
	assert.NoError(t, err)
	assert.Len(t, nodes, 1)
	assert.True(t, nodes[0].LastHeartbeat.After(before))
}
// createTestRegistry builds a registry wired to the shared test cache.
// Fields are set by name (the positional literal was fragile: reordering the
// struct's fields would silently swap the two pub/sub managers).
func createTestRegistry() *BackupNodesRegistry {
	return &BackupNodesRegistry{
		client:            cache_utils.GetValkeyClient(),
		logger:            logger.GetLogger(),
		timeout:           cache_utils.DefaultCacheTimeout,
		pubsubBackups:     cache_utils.NewPubSubManager(),
		pubsubCompletions: cache_utils.NewPubSubManager(),
	}
}
// createTestBackupNode builds a node fixture with a random ID, a fixed
// throughput of 100 MB/s, and a current heartbeat timestamp.
func createTestBackupNode() BackupNode {
	var node BackupNode
	node.ID = uuid.New()
	node.ThroughputMBs = 100
	node.LastHeartbeat = time.Now().UTC()
	return node
}
// cleanupTestNode removes the node's registry entries after a test.
// The returned error is intentionally ignored: cleanup is best-effort.
func cleanupTestNode(registry *BackupNodesRegistry, node BackupNode) {
	registry.UnregisterNodeFromRegistry(node)
}
// Verifies that assigning a backup publishes without error even when no
// subscriber is listening on the submit channel.
// (Renamed from Test_AssignBackupTonode_… to fix the "Tonode" typo and match
// the AssignBackupToNode method name used everywhere else.)
func Test_AssignBackupToNode_PublishesJsonMessageToChannel(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	backupID := uuid.New()

	err := registry.AssignBackupToNode(node.ID.String(), backupID, true)

	assert.NoError(t, err)
}
// Verifies the assignment happy path: a subscriber registered for a node ID
// receives the backup ID assigned to that same node within the timeout.
func Test_SubscribeNodeForBackupsAssignment_ReceivesSubmittedBackupsForMatchingNode(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	backupID := uuid.New()
	defer registry.UnsubscribeNodeForBackupsAssignments()
	receivedBackupID := make(chan uuid.UUID, 1)
	handler := func(id uuid.UUID, isCallNotifier bool) {
		receivedBackupID <- id
	}
	err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
	assert.NoError(t, err)
	// Give the subscription goroutine time to attach before publishing.
	time.Sleep(100 * time.Millisecond)
	err = registry.AssignBackupToNode(node.ID.String(), backupID, true)
	assert.NoError(t, err)
	select {
	case received := <-receivedBackupID:
		assert.Equal(t, backupID, received)
	case <-time.After(2 * time.Second):
		t.Fatal("Timeout waiting for backup message")
	}
}
// Verifies routing: a backup assigned to a different node must never reach
// this node's handler. Absence is asserted via a 500ms quiet window.
func Test_SubscribeNodeForBackupsAssignment_FiltersOutBackupsForDifferentNode(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node1 := createTestBackupNode()
	node2 := createTestBackupNode()
	backupID := uuid.New()
	defer registry.UnsubscribeNodeForBackupsAssignments()
	receivedBackupID := make(chan uuid.UUID, 1)
	handler := func(id uuid.UUID, isCallNotifier bool) {
		receivedBackupID <- id
	}
	// Subscribe as node1 but publish to node2.
	err := registry.SubscribeNodeForBackupsAssignment(node1.ID.String(), handler)
	assert.NoError(t, err)
	time.Sleep(100 * time.Millisecond)
	err = registry.AssignBackupToNode(node2.ID.String(), backupID, false)
	assert.NoError(t, err)
	select {
	case <-receivedBackupID:
		t.Fatal("Should not receive backup for different node")
	case <-time.After(500 * time.Millisecond):
	}
}
// Verifies that assignment messages round-trip through JSON with the backup
// IDs intact when multiple backups are submitted to the same node.
func Test_SubscribeNodeForBackupsAssignment_ParsesJsonAndBackupIdCorrectly(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	backupID1 := uuid.New()
	backupID2 := uuid.New()
	defer registry.UnsubscribeNodeForBackupsAssignments()
	receivedBackups := make(chan uuid.UUID, 2)
	handler := func(id uuid.UUID, isCallNotifier bool) {
		receivedBackups <- id
	}
	err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
	assert.NoError(t, err)
	// Give the subscription goroutine time to attach before publishing.
	time.Sleep(100 * time.Millisecond)
	err = registry.AssignBackupToNode(node.ID.String(), backupID1, true)
	assert.NoError(t, err)
	err = registry.AssignBackupToNode(node.ID.String(), backupID2, false)
	assert.NoError(t, err)
	// Receive with a timeout so a lost message fails the test fast instead
	// of blocking forever (the original receives were unbounded).
	var receivedIDs []uuid.UUID
	for i := 0; i < 2; i++ {
		select {
		case received := <-receivedBackups:
			receivedIDs = append(receivedIDs, received)
		case <-time.After(2 * time.Second):
			t.Fatal("Timeout waiting for backup messages")
		}
	}
	assert.Contains(t, receivedIDs, backupID1)
	assert.Contains(t, receivedIDs, backupID2)
}
// Verifies that a malformed payload on the submit channel is dropped rather
// than delivered to the handler.
func Test_SubscribeNodeForBackupsAssignment_HandlesInvalidJson(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	defer registry.UnsubscribeNodeForBackupsAssignments()
	receivedBackupID := make(chan uuid.UUID, 1)
	handler := func(id uuid.UUID, isCallNotifier bool) {
		receivedBackupID <- id
	}
	err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
	assert.NoError(t, err)
	time.Sleep(100 * time.Millisecond)
	ctx := context.Background()
	// Publish to the shared channel constant (was a hardcoded literal that
	// could silently diverge from the channel the subscriber listens on).
	err = registry.pubsubBackups.Publish(ctx, backupSubmitChannel, "invalid json")
	assert.NoError(t, err)
	select {
	case <-receivedBackupID:
		t.Fatal("Should not receive backup for invalid JSON")
	case <-time.After(500 * time.Millisecond):
	}
}
// Verifies that after UnsubscribeNodeForBackupsAssignments the handler stops
// receiving: one backup is delivered while subscribed, then a second one
// published after unsubscribing must not arrive within the quiet window.
func Test_UnsubscribeNodeForBackupsAssignments_StopsReceivingMessages(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	backupID1 := uuid.New()
	backupID2 := uuid.New()
	receivedBackupID := make(chan uuid.UUID, 2)
	handler := func(id uuid.UUID, isCallNotifier bool) {
		receivedBackupID <- id
	}
	err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
	assert.NoError(t, err)
	time.Sleep(100 * time.Millisecond)
	err = registry.AssignBackupToNode(node.ID.String(), backupID1, true)
	assert.NoError(t, err)
	// Unbounded receive: relies on the overall `go test` timeout if the
	// first delivery never happens.
	received := <-receivedBackupID
	assert.Equal(t, backupID1, received)
	err = registry.UnsubscribeNodeForBackupsAssignments()
	assert.NoError(t, err)
	time.Sleep(100 * time.Millisecond)
	err = registry.AssignBackupToNode(node.ID.String(), backupID2, false)
	assert.NoError(t, err)
	select {
	case <-receivedBackupID:
		t.Fatal("Should not receive backup after unsubscribe")
	case <-time.After(500 * time.Millisecond):
	}
}
// Verifies that a second subscription on the same registry is rejected.
func Test_SubscribeNodeForBackupsAssignment_WhenAlreadySubscribed_ReturnsError(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	defer registry.UnsubscribeNodeForBackupsAssignments()

	noopHandler := func(id uuid.UUID, isCallNotifier bool) {}
	assert.NoError(t, registry.SubscribeNodeForBackupsAssignment(node.ID.String(), noopHandler))

	err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), noopHandler)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "already subscribed")
}
// Verifies node-targeted routing end to end: three subscribers each listening
// for a different node ID receive exactly the backup assigned to their node
// and nothing else. A fourth registry acts purely as the publisher.
func Test_MultipleNodes_EachReceivesOnlyTheirBackups(t *testing.T) {
	cache_utils.ClearAllCache()
	registry1 := createTestRegistry()
	registry2 := createTestRegistry()
	registry3 := createTestRegistry()
	node1 := createTestBackupNode()
	node2 := createTestBackupNode()
	node3 := createTestBackupNode()
	backupID1 := uuid.New()
	backupID2 := uuid.New()
	backupID3 := uuid.New()
	defer registry1.UnsubscribeNodeForBackupsAssignments()
	defer registry2.UnsubscribeNodeForBackupsAssignments()
	defer registry3.UnsubscribeNodeForBackupsAssignments()
	receivedBackups1 := make(chan uuid.UUID, 3)
	receivedBackups2 := make(chan uuid.UUID, 3)
	receivedBackups3 := make(chan uuid.UUID, 3)
	handler1 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups1 <- id }
	handler2 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups2 <- id }
	handler3 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups3 <- id }
	err := registry1.SubscribeNodeForBackupsAssignment(node1.ID.String(), handler1)
	assert.NoError(t, err)
	err = registry2.SubscribeNodeForBackupsAssignment(node2.ID.String(), handler2)
	assert.NoError(t, err)
	err = registry3.SubscribeNodeForBackupsAssignment(node3.ID.String(), handler3)
	assert.NoError(t, err)
	// Give all three subscriptions time to attach before publishing.
	time.Sleep(100 * time.Millisecond)
	submitRegistry := createTestRegistry()
	err = submitRegistry.AssignBackupToNode(node1.ID.String(), backupID1, true)
	assert.NoError(t, err)
	err = submitRegistry.AssignBackupToNode(node2.ID.String(), backupID2, false)
	assert.NoError(t, err)
	err = submitRegistry.AssignBackupToNode(node3.ID.String(), backupID3, true)
	assert.NoError(t, err)
	// Each subscriber must receive exactly its own backup...
	select {
	case received := <-receivedBackups1:
		assert.Equal(t, backupID1, received)
	case <-time.After(2 * time.Second):
		t.Fatal("Node 1 timeout waiting for backup message")
	}
	select {
	case received := <-receivedBackups2:
		assert.Equal(t, backupID2, received)
	case <-time.After(2 * time.Second):
		t.Fatal("Node 2 timeout waiting for backup message")
	}
	select {
	case received := <-receivedBackups3:
		assert.Equal(t, backupID3, received)
	case <-time.After(2 * time.Second):
		t.Fatal("Node 3 timeout waiting for backup message")
	}
	// ...and nothing destined for the other nodes (300ms quiet window each).
	select {
	case <-receivedBackups1:
		t.Fatal("Node 1 should not receive additional backups")
	case <-time.After(300 * time.Millisecond):
	}
	select {
	case <-receivedBackups2:
		t.Fatal("Node 2 should not receive additional backups")
	case <-time.After(300 * time.Millisecond):
	}
	select {
	case <-receivedBackups3:
		t.Fatal("Node 3 should not receive additional backups")
	case <-time.After(300 * time.Millisecond):
	}
}
// Verifies that publishing a completion succeeds even with no subscriber.
func Test_PublishBackupCompletion_PublishesMessageToChannel(t *testing.T) {
	cache_utils.ClearAllCache()

	registry := createTestRegistry()
	node := createTestBackupNode()
	backupID := uuid.New()

	assert.NoError(t, registry.PublishBackupCompletion(node.ID.String(), backupID))
}
// Verifies that a completion subscriber receives both the node ID and the
// backup ID published on the completion channel.
func Test_SubscribeForBackupsCompletions_ReceivesCompletedBackups(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	backupID := uuid.New()
	defer registry.UnsubscribeForBackupsCompletions()
	receivedBackupID := make(chan uuid.UUID, 1)
	receivedNodeID := make(chan string, 1)
	// The handler sends nodeID first, then backupID; the selects below read
	// in the same order, and both channels are buffered so neither send blocks.
	handler := func(nodeID string, backupID uuid.UUID) {
		receivedNodeID <- nodeID
		receivedBackupID <- backupID
	}
	err := registry.SubscribeForBackupsCompletions(handler)
	assert.NoError(t, err)
	// Give the subscription time to attach before publishing.
	time.Sleep(100 * time.Millisecond)
	err = registry.PublishBackupCompletion(node.ID.String(), backupID)
	assert.NoError(t, err)
	select {
	case receivedNode := <-receivedNodeID:
		assert.Equal(t, node.ID.String(), receivedNode)
	case <-time.After(2 * time.Second):
		t.Fatal("Timeout waiting for node ID")
	}
	select {
	case received := <-receivedBackupID:
		assert.Equal(t, backupID, received)
	case <-time.After(2 * time.Second):
		t.Fatal("Timeout waiting for backup completion message")
	}
}
// Verifies that completion messages round-trip through JSON with both backup
// IDs intact when two completions are published back to back.
func Test_SubscribeForBackupsCompletions_ParsesJsonCorrectly(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	backupID1 := uuid.New()
	backupID2 := uuid.New()
	defer registry.UnsubscribeForBackupsCompletions()
	receivedBackups := make(chan uuid.UUID, 2)
	handler := func(nodeID string, backupID uuid.UUID) {
		receivedBackups <- backupID
	}
	err := registry.SubscribeForBackupsCompletions(handler)
	assert.NoError(t, err)
	// Give the subscription time to attach before publishing.
	time.Sleep(100 * time.Millisecond)
	err = registry.PublishBackupCompletion(node.ID.String(), backupID1)
	assert.NoError(t, err)
	err = registry.PublishBackupCompletion(node.ID.String(), backupID2)
	assert.NoError(t, err)
	// Receive with a timeout so a lost message fails the test fast instead
	// of blocking forever (the original receives were unbounded).
	var receivedIDs []uuid.UUID
	for i := 0; i < 2; i++ {
		select {
		case received := <-receivedBackups:
			receivedIDs = append(receivedIDs, received)
		case <-time.After(2 * time.Second):
			t.Fatal("Timeout waiting for completion messages")
		}
	}
	assert.Contains(t, receivedIDs, backupID1)
	assert.Contains(t, receivedIDs, backupID2)
}
// Verifies that a malformed payload on the completion channel is dropped
// rather than delivered to the handler.
func Test_SubscribeForBackupsCompletions_HandlesInvalidJson(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	defer registry.UnsubscribeForBackupsCompletions()
	receivedBackupID := make(chan uuid.UUID, 1)
	handler := func(nodeID string, backupID uuid.UUID) {
		receivedBackupID <- backupID
	}
	err := registry.SubscribeForBackupsCompletions(handler)
	assert.NoError(t, err)
	time.Sleep(100 * time.Millisecond)
	ctx := context.Background()
	// Publish to the shared channel constant (was a hardcoded literal that
	// could silently diverge from the channel the subscriber listens on).
	err = registry.pubsubCompletions.Publish(ctx, backupCompletionChannel, "invalid json")
	assert.NoError(t, err)
	select {
	case <-receivedBackupID:
		t.Fatal("Should not receive backup for invalid JSON")
	case <-time.After(500 * time.Millisecond):
	}
}
// Verifies that after UnsubscribeForBackupsCompletions the handler stops
// receiving: one completion is delivered while subscribed, then a second one
// published after unsubscribing must not arrive within the quiet window.
func Test_UnsubscribeForBackupsCompletions_StopsReceivingMessages(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	node := createTestBackupNode()
	backupID1 := uuid.New()
	backupID2 := uuid.New()
	receivedBackupID := make(chan uuid.UUID, 2)
	handler := func(nodeID string, backupID uuid.UUID) {
		receivedBackupID <- backupID
	}
	err := registry.SubscribeForBackupsCompletions(handler)
	assert.NoError(t, err)
	time.Sleep(100 * time.Millisecond)
	err = registry.PublishBackupCompletion(node.ID.String(), backupID1)
	assert.NoError(t, err)
	// Unbounded receive: relies on the overall `go test` timeout if the
	// first delivery never happens.
	received := <-receivedBackupID
	assert.Equal(t, backupID1, received)
	err = registry.UnsubscribeForBackupsCompletions()
	assert.NoError(t, err)
	time.Sleep(100 * time.Millisecond)
	err = registry.PublishBackupCompletion(node.ID.String(), backupID2)
	assert.NoError(t, err)
	select {
	case <-receivedBackupID:
		t.Fatal("Should not receive backup after unsubscribe")
	case <-time.After(500 * time.Millisecond):
	}
}
// Verifies that a second completion subscription on the same registry fails.
func Test_SubscribeForBackupsCompletions_WhenAlreadySubscribed_ReturnsError(t *testing.T) {
	cache_utils.ClearAllCache()
	registry := createTestRegistry()
	defer registry.UnsubscribeForBackupsCompletions()

	noopHandler := func(nodeID string, backupID uuid.UUID) {}
	assert.NoError(t, registry.SubscribeForBackupsCompletions(noopHandler))

	err := registry.SubscribeForBackupsCompletions(noopHandler)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "already subscribed")
}
// Test_MultipleSubscribers_EachReceivesCompletionMessages verifies fan-out:
// every subscribed registry receives every published completion message,
// regardless of which registry published it.
func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) {
	cache_utils.ClearAllCache()

	registry1 := createTestRegistry()
	registry2 := createTestRegistry()
	registry3 := createTestRegistry()
	defer registry1.UnsubscribeForBackupsCompletions()
	defer registry2.UnsubscribeForBackupsCompletions()
	defer registry3.UnsubscribeForBackupsCompletions()

	node1 := createTestBackupNode()
	node2 := createTestBackupNode()
	node3 := createTestBackupNode()
	expectedIDs := []uuid.UUID{uuid.New(), uuid.New(), uuid.New()}

	// Subscribe each registry; method values let us loop without naming
	// the registry type.
	subscribeFns := []func(func(string, uuid.UUID)) error{
		registry1.SubscribeForBackupsCompletions,
		registry2.SubscribeForBackupsCompletions,
		registry3.SubscribeForBackupsCompletions,
	}
	completionChans := make([]chan uuid.UUID, len(subscribeFns))
	for i, subscribe := range subscribeFns {
		ch := make(chan uuid.UUID, len(expectedIDs))
		completionChans[i] = ch
		err := subscribe(func(nodeID string, backupID uuid.UUID) { ch <- backupID })
		assert.NoError(t, err)
	}
	time.Sleep(100 * time.Millisecond)

	// Publish one completion per node from a separate registry.
	publishRegistry := createTestRegistry()
	nodeIDs := []string{node1.ID.String(), node2.ID.String(), node3.ID.String()}
	for i, nodeID := range nodeIDs {
		err := publishRegistry.PublishBackupCompletion(nodeID, expectedIDs[i])
		assert.NoError(t, err)
	}

	// Each subscriber must receive all three backup IDs (order unspecified).
	for subscriberIdx, ch := range completionChans {
		received := make([]uuid.UUID, 0, len(expectedIDs))
		for range expectedIDs {
			select {
			case backupID := <-ch:
				received = append(received, backupID)
			case <-time.After(2 * time.Second):
				t.Fatalf("Subscriber %d timeout waiting for completion message", subscriberIdx+1)
			}
		}
		for _, want := range expectedIDs {
			assert.Contains(t, received, want)
		}
	}
}

View File

@@ -0,0 +1,600 @@
package backuping
import (
"context"
"databasus-backend/internal/config"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/storages"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/period"
"fmt"
"log/slog"
"time"
"github.com/google/uuid"
)
// BackupsScheduler orchestrates backups across the cluster: it triggers
// scheduled backups, assigns each one to the least busy node, cleans up
// expired backups, and fails backups whose node died or whose run was
// interrupted by an application restart.
type BackupsScheduler struct {
	backupRepository *backups_core.BackupRepository
	backupConfigService *backups_config.BackupConfigService
	storageService *storages.StorageService
	backupCancelManager *backups_cancellation.BackupCancelManager
	nodesRegistry *BackupNodesRegistry
	// lastBackupTime records the last completed scheduler tick;
	// IsSchedulerRunning treats a stale value as a stalled scheduler.
	lastBackupTime time.Time
	logger *slog.Logger
	// backupToNodeRelations tracks which in-progress backups were assigned
	// to which node, so dead-node cleanup knows which backups to fail.
	// NOTE(review): written by the scheduler loop and by the completion
	// callback — confirm both run on the same goroutine or add locking.
	backupToNodeRelations map[uuid.UUID]BackupToNodeRelation
	backuperNode *BackuperNode
}
// Run starts the backups scheduler loop. It fails any backups left in
// progress from a previous run, subscribes to backup-completion events
// from worker nodes, and then, once per minute until ctx is cancelled:
// cleans expired backups, fails backups held by dead nodes, and triggers
// pending scheduled backups.
//
// Run blocks until ctx is done. It panics if the initial recovery or
// subscription steps fail, since the scheduler cannot operate without them.
func (s *BackupsScheduler) Run(ctx context.Context) {
	s.lastBackupTime = time.Now().UTC()

	if config.GetEnv().IsManyNodesMode {
		// Wait for other nodes to start, but stay responsive to shutdown
		// (a plain time.Sleep would ignore ctx cancellation here).
		select {
		case <-ctx.Done():
			return
		case <-time.After(1 * time.Minute):
		}
	}

	if err := s.failBackupsInProgress(); err != nil {
		s.logger.Error("Failed to fail backups in progress", "error", err)
		panic(err)
	}

	if err := s.nodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted); err != nil {
		s.logger.Error("Failed to subscribe to backup completions", "error", err)
		panic(err)
	}
	defer func() {
		if err := s.nodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
			s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
		}
	}()

	if ctx.Err() != nil {
		return
	}

	ticker := time.NewTicker(1 * time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			if err := s.cleanOldBackups(); err != nil {
				s.logger.Error("Failed to clean old backups", "error", err)
			}
			if err := s.checkDeadNodesAndFailBackups(); err != nil {
				s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
			}
			if err := s.runPendingBackups(); err != nil {
				s.logger.Error("Failed to run pending backups", "error", err)
			}
			// Record the completed tick so IsSchedulerRunning can detect
			// a stalled scheduler.
			s.lastBackupTime = time.Now().UTC()
		}
	}
}
// IsSchedulerRunning reports whether the scheduler has completed a tick
// within the last five minutes; a staler timestamp means the loop stalled.
func (s *BackupsScheduler) IsSchedulerRunning() bool {
	const staleAfter = 5 * time.Minute
	return time.Now().UTC().Sub(s.lastBackupTime) < staleAfter
}
// failBackupsInProgress marks every backup still in the "in progress"
// state as failed. It runs once on scheduler startup: any backup found in
// progress at that point was interrupted by an application restart and
// can never complete.
//
// For each such backup it cancels the (possibly still registered)
// cancellation context, sends a failure notification, and persists the
// failed status with a zeroed size. Cancellation and config-lookup errors
// are logged and skipped so one broken backup does not block recovery of
// the rest; a repository save error aborts and is returned.
func (s *BackupsScheduler) failBackupsInProgress() error {
	backupsInProgress, err := s.backupRepository.FindByStatus(backups_core.BackupStatusInProgress)
	if err != nil {
		return err
	}

	// Structured logging instead of the stray debug fmt.Println.
	s.logger.Info("Failing interrupted backups", "count", len(backupsInProgress))

	for _, backup := range backupsInProgress {
		if err := s.backupCancelManager.CancelBackup(backup.ID); err != nil {
			s.logger.Error(
				"Failed to cancel backup via context manager",
				"backupId",
				backup.ID,
				"error",
				err,
			)
		}

		backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(backup.DatabaseID)
		if err != nil {
			s.logger.Error("Failed to get backup config by database ID", "error", err)
			continue
		}

		failMessage := "Backup failed due to application restart"
		backup.FailMessage = &failMessage
		backup.Status = backups_core.BackupStatusFailed
		backup.BackupSizeMb = 0

		s.backuperNode.SendBackupNotification(
			backupConfig,
			backup,
			backups_config.NotificationBackupFailed,
			&failMessage,
		)

		if err := s.backupRepository.Save(backup); err != nil {
			return err
		}
	}

	return nil
}
// StartBackup creates an in-progress backup record for databaseID and
// assigns it to the least busy available node.
//
// isCallNotifier is forwarded to the node with the assignment; presumably
// it tells the worker whether to emit notifications for this run — TODO
// confirm against AssignBackupToNode's contract.
//
// All failures are logged and abort the start. If handing the backup to a
// node fails after the node's in-progress counter was incremented, the
// counter is decremented again as compensation.
// NOTE(review): in that failure path the saved backup row stays in
// "in progress" status until restart recovery or dead-node detection
// fails it — confirm this is intended.
func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool) {
	backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(databaseID)
	if err != nil {
		s.logger.Error("Failed to get backup config by database ID", "error", err)
		return
	}
	// A backup cannot run without a destination storage.
	if backupConfig.StorageID == nil {
		s.logger.Error("Backup config storage ID is nil", "databaseId", databaseID)
		return
	}
	leastBusyNodeID, err := s.calculateLeastBusyNode()
	if err != nil {
		s.logger.Error(
			"Failed to calculate least busy node",
			"databaseId",
			backupConfig.DatabaseID,
			"error",
			err,
		)
		return
	}
	// Persist the backup first so it is visible as "in progress" even if
	// the node assignment below fails.
	backup := &backups_core.Backup{
		DatabaseID: backupConfig.DatabaseID,
		StorageID: *backupConfig.StorageID,
		Status: backups_core.BackupStatusInProgress,
		BackupSizeMb: 0,
		CreatedAt: time.Now().UTC(),
	}
	if err := s.backupRepository.Save(backup); err != nil {
		s.logger.Error(
			"Failed to save backup",
			"databaseId",
			backupConfig.DatabaseID,
			"error",
			err,
		)
		return
	}
	// Reserve capacity on the chosen node before submitting the work.
	if err := s.nodesRegistry.IncrementBackupsInProgress(leastBusyNodeID.String()); err != nil {
		s.logger.Error(
			"Failed to increment backups in progress",
			"nodeId",
			leastBusyNodeID,
			"backupId",
			backup.ID,
			"error",
			err,
		)
		return
	}
	if err := s.nodesRegistry.AssignBackupToNode(leastBusyNodeID.String(), backup.ID, isCallNotifier); err != nil {
		s.logger.Error(
			"Failed to submit backup",
			"nodeId",
			leastBusyNodeID,
			"backupId",
			backup.ID,
			"error",
			err,
		)
		// Compensate the reservation made above so the node's counter
		// does not drift on a failed assignment.
		if decrementErr := s.nodesRegistry.DecrementBackupsInProgress(leastBusyNodeID.String()); decrementErr != nil {
			s.logger.Error(
				"Failed to decrement backups in progress after submit failure",
				"nodeId",
				leastBusyNodeID,
				"error",
				decrementErr,
			)
		}
		return
	}
	// Remember the assignment so dead-node cleanup can fail this backup
	// if the node stops heartbeating.
	if relation, exists := s.backupToNodeRelations[*leastBusyNodeID]; exists {
		relation.BackupsIDs = append(relation.BackupsIDs, backup.ID)
		s.backupToNodeRelations[*leastBusyNodeID] = relation
	} else {
		s.backupToNodeRelations[*leastBusyNodeID] = BackupToNodeRelation{
			NodeID: *leastBusyNodeID,
			BackupsIDs: []uuid.UUID{backup.ID},
		}
	}
	s.logger.Info(
		"Successfully triggered scheduled backup",
		"databaseId",
		backupConfig.DatabaseID,
		"backupId",
		backup.ID,
		"nodeId",
		leastBusyNodeID,
	)
}
// GetRemainedBackupTryCount returns how many retry attempts remain for
// the database that lastBackup belongs to. It yields 0 when there is no
// last backup, when the last backup did not fail, when retries are
// disabled in the backup config, or when the configured number of failed
// attempts has already been used up.
func (s *BackupsScheduler) GetRemainedBackupTryCount(lastBackup *backups_core.Backup) int {
	if lastBackup == nil || lastBackup.Status != backups_core.BackupStatusFailed {
		return 0
	}

	cfg, err := s.backupConfigService.GetBackupConfigByDbId(lastBackup.DatabaseID)
	if err != nil {
		s.logger.Error("Failed to get backup config by database ID", "error", err)
		return 0
	}
	if !cfg.IsRetryIfFailed {
		return 0
	}

	maxTries := cfg.MaxFailedTriesCount
	recentBackups, err := s.backupRepository.FindByDatabaseIDWithLimit(
		lastBackup.DatabaseID,
		maxTries,
	)
	if err != nil {
		s.logger.Error("Failed to find last backups by database ID", "error", err)
		return 0
	}

	// Count failures among the most recent maxTries backups; anything
	// left over is the remaining retry budget.
	failedCount := 0
	for _, backup := range recentBackups {
		if backup.Status == backups_core.BackupStatusFailed {
			failedCount++
		}
	}
	return maxTries - failedCount
}
// cleanOldBackups removes backups whose retention period has elapsed.
// For every database with backups enabled and a finite store period, it
// finds backups created before the retention cutoff, deletes each backup
// file from its storage, and then removes the database record.
// Per-item failures are logged and skipped so cleanup keeps going.
func (s *BackupsScheduler) cleanOldBackups() error {
	configs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
	if err != nil {
		return err
	}

	for _, cfg := range configs {
		if cfg.StorePeriod == period.PeriodForever {
			continue
		}

		cutoff := time.Now().UTC().Add(-cfg.StorePeriod.ToDuration())
		expiredBackups, err := s.backupRepository.FindBackupsBeforeDate(
			cfg.DatabaseID,
			cutoff,
		)
		if err != nil {
			s.logger.Error(
				"Failed to find old backups for database",
				"databaseId",
				cfg.DatabaseID,
				"error",
				err,
			)
			continue
		}

		for _, backup := range expiredBackups {
			storage, err := s.storageService.GetStorageByID(backup.StorageID)
			if err != nil {
				s.logger.Error(
					"Failed to get storage by ID",
					"storageId",
					backup.StorageID,
					"error",
					err,
				)
				continue
			}

			// Best effort: a failed file deletion is logged but does not
			// prevent removing the database record.
			if err := storage.DeleteFile(encryption.GetFieldEncryptor(), backup.ID); err != nil {
				s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
			}

			if err := s.backupRepository.DeleteByID(backup.ID); err != nil {
				s.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
				continue
			}

			s.logger.Info(
				"Deleted old backup",
				"backupId",
				backup.ID,
				"databaseId",
				cfg.DatabaseID,
			)
		}
	}

	return nil
}
// runPendingBackups walks every backup config with backups enabled and
// starts a backup for each database that is due: either the schedule
// interval says a backup should run now, or the last backup failed and
// retry attempts remain.
//
// isCallNotifier is passed as remainedBackupTryCount == 1 — presumably so
// the failure notification fires only on the final retry attempt; TODO
// confirm against StartBackup/worker behavior.
// Per-database errors are logged and skipped so one broken database does
// not stall scheduling for the rest.
func (s *BackupsScheduler) runPendingBackups() error {
	enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
	if err != nil {
		return err
	}
	for _, backupConfig := range enabledBackupConfigs {
		// Configs without an interval are only triggered manually.
		if backupConfig.BackupInterval == nil {
			continue
		}
		lastBackup, err := s.backupRepository.FindLastByDatabaseID(backupConfig.DatabaseID)
		if err != nil {
			s.logger.Error(
				"Failed to get last backup for database",
				"databaseId",
				backupConfig.DatabaseID,
				"error",
				err,
			)
			continue
		}
		var lastBackupTime *time.Time
		if lastBackup != nil {
			lastBackupTime = &lastBackup.CreatedAt
		}
		remainedBackupTryCount := s.GetRemainedBackupTryCount(lastBackup)
		if backupConfig.BackupInterval.ShouldTriggerBackup(time.Now().UTC(), lastBackupTime) ||
			remainedBackupTryCount > 0 {
			s.logger.Info(
				"Triggering scheduled backup",
				"databaseId",
				backupConfig.DatabaseID,
				"intervalType",
				backupConfig.BackupInterval.Interval,
			)
			s.StartBackup(backupConfig.DatabaseID, remainedBackupTryCount == 1)
		}
	}
	return nil
}
// calculateLeastBusyNode picks the node best suited to take one more
// backup. Nodes whose heartbeat is older than two minutes are skipped.
// Remaining nodes are scored by activeBackups/throughput (lower is
// better); nodes reporting zero throughput get a heavy penalty so they
// are chosen only as a last resort.
func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
	nodes, err := s.nodesRegistry.GetAvailableNodes()
	if err != nil {
		return nil, fmt.Errorf("failed to get available nodes: %w", err)
	}
	if len(nodes) == 0 {
		return nil, fmt.Errorf("no nodes available")
	}

	stats, err := s.nodesRegistry.GetBackupNodesStats()
	if err != nil {
		return nil, fmt.Errorf("failed to get backup nodes stats: %w", err)
	}
	activeByNode := make(map[uuid.UUID]int, len(stats))
	for _, stat := range stats {
		activeByNode[stat.ID] = stat.ActiveBackups
	}

	const heartbeatTimeout = 2 * time.Minute
	now := time.Now().UTC()

	var chosen *BackupNode
	var bestCost float64
	for i := range nodes {
		candidate := &nodes[i]
		if now.Sub(candidate.LastHeartbeat) > heartbeatTimeout {
			continue
		}

		active := activeByNode[candidate.ID]
		cost := float64(active) * 1000
		if candidate.ThroughputMBs > 0 {
			cost = float64(active) / float64(candidate.ThroughputMBs)
		}

		if chosen == nil || cost < bestCost {
			chosen = candidate
			bestCost = cost
		}
	}

	if chosen == nil {
		return nil, fmt.Errorf("no suitable nodes available")
	}
	return &chosen.ID, nil
}
// onBackupCompleted handles a completion message published by a worker
// node: it removes the backup from the in-memory node-to-backups map and
// decrements that node's in-progress counter in the registry.
//
// Unknown nodes or backup IDs are logged and ignored, since completions
// may arrive for assignments this scheduler instance does not track.
// NOTE(review): this callback mutates backupToNodeRelations, which the
// scheduler loop also reads and writes — confirm the subscription invokes
// handlers on the scheduler goroutine, or add locking.
func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUID) {
	nodeID, err := uuid.Parse(nodeIDStr)
	if err != nil {
		s.logger.Error(
			"Failed to parse node ID from completion message",
			"nodeId",
			nodeIDStr,
			"error",
			err,
		)
		return
	}
	relation, exists := s.backupToNodeRelations[nodeID]
	if !exists {
		s.logger.Warn(
			"Received completion for unknown node",
			"nodeId",
			nodeID,
			"backupId",
			backupID,
		)
		return
	}
	// Rebuild the node's backup list without the completed backup.
	newBackupIDs := make([]uuid.UUID, 0)
	found := false
	for _, id := range relation.BackupsIDs {
		if id == backupID {
			found = true
			continue
		}
		newBackupIDs = append(newBackupIDs, id)
	}
	if !found {
		s.logger.Warn(
			"Backup not found in node's backup list",
			"nodeId",
			nodeID,
			"backupId",
			backupID,
		)
		return
	}
	// Drop the relation entirely once the node has no backups left.
	if len(newBackupIDs) == 0 {
		delete(s.backupToNodeRelations, nodeID)
	} else {
		relation.BackupsIDs = newBackupIDs
		s.backupToNodeRelations[nodeID] = relation
	}
	if err := s.nodesRegistry.DecrementBackupsInProgress(nodeIDStr); err != nil {
		s.logger.Error(
			"Failed to decrement backups in progress",
			"nodeId",
			nodeID,
			"backupId",
			backupID,
			"error",
			err,
		)
	}
}
// checkDeadNodesAndFailBackups fails every backup assigned to a node
// whose heartbeat is older than two minutes. Each such backup is marked
// failed with a "node unavailability" message, the node's in-progress
// counter is decremented per backup, and the node's relation entry is
// removed from the in-memory map. Per-backup errors are logged and
// skipped so one broken record does not block the rest.
func (s *BackupsScheduler) checkDeadNodesAndFailBackups() error {
	nodes, err := s.nodesRegistry.GetAvailableNodes()
	if err != nil {
		return fmt.Errorf("failed to get available nodes: %w", err)
	}
	// Build the set of nodes with a recent heartbeat; everything else in
	// backupToNodeRelations is considered dead.
	aliveNodeIDs := make(map[uuid.UUID]bool)
	now := time.Now().UTC()
	for _, node := range nodes {
		if now.Sub(node.LastHeartbeat) <= 2*time.Minute {
			aliveNodeIDs[node.ID] = true
		}
	}
	// Deleting entries while ranging over a Go map is safe.
	for nodeID, relation := range s.backupToNodeRelations {
		if aliveNodeIDs[nodeID] {
			continue
		}
		s.logger.Warn(
			"Node is dead, failing its backups",
			"nodeId",
			nodeID,
			"backupCount",
			len(relation.BackupsIDs),
		)
		for _, backupID := range relation.BackupsIDs {
			backup, err := s.backupRepository.FindByID(backupID)
			if err != nil {
				s.logger.Error(
					"Failed to find backup for dead node",
					"nodeId",
					nodeID,
					"backupId",
					backupID,
					"error",
					err,
				)
				continue
			}
			failMessage := "Backup failed due to node unavailability"
			backup.FailMessage = &failMessage
			backup.Status = backups_core.BackupStatusFailed
			backup.BackupSizeMb = 0
			if err := s.backupRepository.Save(backup); err != nil {
				s.logger.Error(
					"Failed to save failed backup for dead node",
					"nodeId",
					nodeID,
					"backupId",
					backupID,
					"error",
					err,
				)
				continue
			}
			if err := s.nodesRegistry.DecrementBackupsInProgress(nodeID.String()); err != nil {
				s.logger.Error(
					"Failed to decrement backups in progress for dead node",
					"nodeId",
					nodeID,
					"backupId",
					backupID,
					"error",
					err,
				)
			}
			s.logger.Info(
				"Failed backup due to dead node",
				"nodeId",
				nodeID,
				"backupId",
				backupID,
			)
		}
		delete(s.backupToNodeRelations, nodeID)
	}
	return nil
}

View File

@@ -1,6 +1,7 @@
package backups
package backuping
import (
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/intervals"
@@ -9,14 +10,21 @@ import (
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/period"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
func Test_RunPendingBackups_WhenLastBackupWasYesterday_CreatesNewBackup(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
// setup data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
@@ -57,16 +65,16 @@ func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
assert.NoError(t, err)
// add old backup
backupRepository.Save(&Backup{
backupRepository.Save(&backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusCompleted,
Status: backups_core.BackupStatusCompleted,
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
})
GetBackupBackgroundService().runPendingBackups()
GetBackupsScheduler().runPendingBackups()
// Wait for backup to complete (runs in goroutine)
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
@@ -80,7 +88,12 @@ func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
time.Sleep(200 * time.Millisecond)
}
func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
func Test_RunPendingBackups_WhenLastBackupWasRecentlyCompleted_SkipsBackup(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
// setup data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
@@ -121,16 +134,16 @@ func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
assert.NoError(t, err)
// add recent backup (1 hour ago)
backupRepository.Save(&Backup{
backupRepository.Save(&backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusCompleted,
Status: backups_core.BackupStatusCompleted,
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
})
GetBackupBackgroundService().runPendingBackups()
GetBackupsScheduler().runPendingBackups()
time.Sleep(100 * time.Millisecond)
@@ -143,7 +156,12 @@ func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
time.Sleep(200 * time.Millisecond)
}
func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T) {
func Test_RunPendingBackups_WhenLastBackupFailedAndRetriesDisabled_SkipsBackup(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
// setup data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
@@ -187,17 +205,17 @@ func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T)
// add failed backup
failMessage := "backup failed"
backupRepository.Save(&Backup{
backupRepository.Save(&backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusFailed,
Status: backups_core.BackupStatusFailed,
FailMessage: &failMessage,
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
})
GetBackupBackgroundService().runPendingBackups()
GetBackupsScheduler().runPendingBackups()
time.Sleep(100 * time.Millisecond)
@@ -210,7 +228,12 @@ func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T)
time.Sleep(200 * time.Millisecond)
}
func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
func Test_RunPendingBackups_WhenLastBackupFailedAndRetriesEnabled_CreatesNewBackup(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
// setup data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
@@ -254,17 +277,17 @@ func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
// add failed backup
failMessage := "backup failed"
backupRepository.Save(&Backup{
backupRepository.Save(&backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusFailed,
Status: backups_core.BackupStatusFailed,
FailMessage: &failMessage,
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
})
GetBackupBackgroundService().runPendingBackups()
GetBackupsScheduler().runPendingBackups()
// Wait for backup to complete (runs in goroutine)
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
@@ -278,7 +301,12 @@ func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
time.Sleep(200 * time.Millisecond)
}
func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *testing.T) {
func Test_RunPendingBackups_WhenFailedBackupsExceedMaxRetries_SkipsBackup(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
// setup data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
@@ -322,19 +350,19 @@ func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *tes
failMessage := "backup failed"
for i := 0; i < 3; i++ {
backupRepository.Save(&Backup{
for range 3 {
backupRepository.Save(&backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusFailed,
Status: backups_core.BackupStatusFailed,
FailMessage: &failMessage,
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
})
}
GetBackupBackgroundService().runPendingBackups()
GetBackupsScheduler().runPendingBackups()
time.Sleep(100 * time.Millisecond)
@@ -347,7 +375,12 @@ func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *tes
time.Sleep(200 * time.Millisecond)
}
func Test_MakeBackgroundBackupWhenBakupsDisabled_BackupSkipped(t *testing.T) {
func Test_RunPendingBackups_WhenBackupsDisabled_SkipsBackup(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
@@ -385,16 +418,16 @@ func Test_MakeBackgroundBackupWhenBakupsDisabled_BackupSkipped(t *testing.T) {
assert.NoError(t, err)
// add old backup that would trigger new backup if enabled
backupRepository.Save(&Backup{
backupRepository.Save(&backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusCompleted,
Status: backups_core.BackupStatusCompleted,
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
})
GetBackupBackgroundService().runPendingBackups()
GetBackupsScheduler().runPendingBackups()
time.Sleep(100 * time.Millisecond)
@@ -405,3 +438,272 @@ func Test_MakeBackgroundBackupWhenBakupsDisabled_BackupSkipped(t *testing.T) {
// Wait for any cleanup operations to complete before defer cleanup runs
time.Sleep(200 * time.Millisecond)
}
// Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegistry
// covers the dead-node recovery path end to end: a backup assigned to a
// node whose heartbeat goes stale must be marked failed with a
// "node unavailability" message, the node's active-backups counter must
// be decremented back to zero, and the node entry itself must remain in
// the registry.
func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegistry(t *testing.T) {
	cache_utils.ClearAllCache()
	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
	defer func() {
		backups, _ := backupRepository.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepository.DeleteByID(backup.ID)
		}
		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		notifiers.RemoveTestNotifier(notifier)
		storages.RemoveTestStorage(storage.ID)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()
	// Enable daily backups for the test database so the scheduler will
	// accept a StartBackup call for it.
	backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
	assert.NoError(t, err)
	timeOfDay := "04:00"
	backupConfig.BackupInterval = &intervals.Interval{
		Interval: intervals.IntervalDaily,
		TimeOfDay: &timeOfDay,
	}
	backupConfig.IsBackupsEnabled = true
	backupConfig.StorePeriod = period.PeriodWeek
	backupConfig.Storage = storage
	backupConfig.StorageID = &storage.ID
	_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
	assert.NoError(t, err)
	// Register mock node without subscribing to backups (simulates node crash after registration)
	mockNodeID := uuid.New()
	err = CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
	assert.NoError(t, err)
	// Scheduler assigns backup to mock node
	GetBackupsScheduler().StartBackup(database.ID, false)
	time.Sleep(100 * time.Millisecond)
	backups, err := backupRepository.FindByDatabaseID(database.ID)
	assert.NoError(t, err)
	assert.Len(t, backups, 1)
	assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status)
	// Verify Valkey counter was incremented when backup was assigned
	stats, err := nodesRegistry.GetBackupNodesStats()
	assert.NoError(t, err)
	foundStat := false
	for _, stat := range stats {
		if stat.ID == mockNodeID {
			assert.Equal(t, 1, stat.ActiveBackups)
			foundStat = true
			break
		}
	}
	assert.True(t, foundStat, "Node stats should be present")
	// Simulate node death by setting heartbeat older than 2-minute threshold
	oldHeartbeat := time.Now().UTC().Add(-3 * time.Minute)
	err = UpdateNodeHeartbeatDirectly(mockNodeID, 100, oldHeartbeat)
	assert.NoError(t, err)
	// Trigger dead node detection
	err = GetBackupsScheduler().checkDeadNodesAndFailBackups()
	assert.NoError(t, err)
	// Verify backup was failed with appropriate error message
	backups, err = backupRepository.FindByDatabaseID(database.ID)
	assert.NoError(t, err)
	assert.Len(t, backups, 1)
	assert.Equal(t, backups_core.BackupStatusFailed, backups[0].Status)
	assert.NotNil(t, backups[0].FailMessage)
	assert.Contains(t, *backups[0].FailMessage, "node unavailability")
	// Verify Valkey counter was decremented after backup failed
	stats, err = nodesRegistry.GetBackupNodesStats()
	assert.NoError(t, err)
	for _, stat := range stats {
		if stat.ID == mockNodeID {
			assert.Equal(t, 0, stat.ActiveBackups)
		}
	}
	// Node info should still exist in registry (not removed by checkDeadNodesAndFailBackups)
	node, err := GetNodeFromRegistry(mockNodeID)
	assert.NoError(t, err)
	assert.NotNil(t, node)
	assert.Equal(t, mockNodeID, node.ID)
	time.Sleep(200 * time.Millisecond)
}
// Test_CalculateLeastBusyNode_SelectsNodeWithBestScore verifies node
// selection: with equal throughput the node with the fewest active
// backups wins; with unequal throughput the node with the lowest
// activeBackups/throughput ratio wins.
func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
	t.Run("Nodes with same throughput", func(t *testing.T) {
		cache_utils.ClearAllCache()
		now := time.Now().UTC()

		nodeIDs := []uuid.UUID{uuid.New(), uuid.New(), uuid.New()}
		activeBackups := []int{5, 2, 8}
		for i, id := range nodeIDs {
			err := CreateMockNodeInRegistry(id, 100, now)
			assert.NoError(t, err)
			for range activeBackups[i] {
				assert.NoError(t, nodesRegistry.IncrementBackupsInProgress(id.String()))
			}
		}

		// The second node has the fewest active backups.
		selected, err := GetBackupsScheduler().calculateLeastBusyNode()
		assert.NoError(t, err)
		assert.NotNil(t, selected)
		assert.Equal(t, nodeIDs[1], *selected)
	})
	t.Run("Nodes with different throughput", func(t *testing.T) {
		cache_utils.ClearAllCache()
		now := time.Now().UTC()

		fastNodeID := uuid.New()
		slowNodeID := uuid.New()
		assert.NoError(t, CreateMockNodeInRegistry(fastNodeID, 100, now))
		assert.NoError(t, CreateMockNodeInRegistry(slowNodeID, 50, now))

		// fast node: 10 active / 100 MB/s = 0.1; slow node: 1 / 50 = 0.02.
		for range 10 {
			assert.NoError(t, nodesRegistry.IncrementBackupsInProgress(fastNodeID.String()))
		}
		assert.NoError(t, nodesRegistry.IncrementBackupsInProgress(slowNodeID.String()))

		selected, err := GetBackupsScheduler().calculateLeastBusyNode()
		assert.NoError(t, err)
		assert.NotNil(t, selected)
		assert.Equal(t, slowNodeID, *selected)
	})
}
// Test_FailBackupsInProgress_WhenSchedulerStarts_CancelsBackupsAndUpdatesStatus
// verifies restart recovery: in-progress backups are marked failed with
// the "application restart" message and a zeroed size, while completed
// backups are left untouched.
func Test_FailBackupsInProgress_WhenSchedulerStarts_CancelsBackupsAndUpdatesStatus(t *testing.T) {
	cache_utils.ClearAllCache()
	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
	defer func() {
		backups, _ := backupRepository.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepository.DeleteByID(backup.ID)
		}
		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		notifiers.RemoveTestNotifier(notifier)
		storages.RemoveTestStorage(storage.ID)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()
	// Enable backups so the config lookup inside failBackupsInProgress
	// succeeds for this database.
	backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
	assert.NoError(t, err)
	timeOfDay := "04:00"
	backupConfig.BackupInterval = &intervals.Interval{
		Interval: intervals.IntervalDaily,
		TimeOfDay: &timeOfDay,
	}
	backupConfig.IsBackupsEnabled = true
	backupConfig.StorePeriod = period.PeriodWeek
	backupConfig.Storage = storage
	backupConfig.StorageID = &storage.ID
	_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
	assert.NoError(t, err)
	// Create two in-progress backups that should be failed on scheduler restart
	backup1 := &backups_core.Backup{
		DatabaseID: database.ID,
		StorageID: storage.ID,
		Status: backups_core.BackupStatusInProgress,
		BackupSizeMb: 10.5,
		CreatedAt: time.Now().UTC().Add(-30 * time.Minute),
	}
	err = backupRepository.Save(backup1)
	assert.NoError(t, err)
	backup2 := &backups_core.Backup{
		DatabaseID: database.ID,
		StorageID: storage.ID,
		Status: backups_core.BackupStatusInProgress,
		BackupSizeMb: 5.2,
		CreatedAt: time.Now().UTC().Add(-15 * time.Minute),
	}
	err = backupRepository.Save(backup2)
	assert.NoError(t, err)
	// Create a completed backup to verify it's not affected by failBackupsInProgress
	completedBackup := &backups_core.Backup{
		DatabaseID: database.ID,
		StorageID: storage.ID,
		Status: backups_core.BackupStatusCompleted,
		BackupSizeMb: 20.0,
		CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
	}
	err = backupRepository.Save(completedBackup)
	assert.NoError(t, err)
	// Trigger the scheduler's failBackupsInProgress logic
	// This should cancel in-progress backups and mark them as failed
	err = GetBackupsScheduler().failBackupsInProgress()
	assert.NoError(t, err)
	// Verify all backups exist and were processed correctly
	backups, err := backupRepository.FindByDatabaseID(database.ID)
	assert.NoError(t, err)
	assert.Len(t, backups, 3)
	var failedCount int
	var completedCount int
	for _, backup := range backups {
		switch backup.Status {
		case backups_core.BackupStatusFailed:
			failedCount++
			// Verify fail message indicates application restart
			assert.NotNil(t, backup.FailMessage)
			assert.Equal(t, "Backup failed due to application restart", *backup.FailMessage)
			// Verify backup size was reset to 0
			assert.Equal(t, float64(0), backup.BackupSizeMb)
		case backups_core.BackupStatusCompleted:
			completedCount++
		}
	}
	// Verify correct number of backups in each state
	assert.Equal(t, 2, failedCount)
	assert.Equal(t, 1, completedCount)
	time.Sleep(200 * time.Millisecond)
}

View File

@@ -0,0 +1,206 @@
package backuping
import (
"context"
"fmt"
"testing"
"time"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/usecases"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_services "databasus-backend/internal/features/workspaces/services"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// CreateTestRouter builds the gin router used by backuping tests, wired with
// the workspace, membership, database and backup-config controllers.
func CreateTestRouter() *gin.Engine {
	return workspaces_testing.CreateTestRouter(
		workspaces_controllers.GetWorkspaceController(),
		workspaces_controllers.GetMembershipController(),
		databases.GetDatabaseController(),
		backups_config.GetBackupConfigController(),
	)
}
// CreateTestBackuperNode constructs a BackuperNode wired with the real
// package-level services, a freshly generated node ID and a zero time value.
//
// NOTE(review): this is a positional struct literal — the argument order must
// match the BackuperNode field declaration order exactly; keep it in sync
// whenever fields are added or reordered.
func CreateTestBackuperNode() *BackuperNode {
	return &BackuperNode{
		databases.GetDatabaseService(),
		workspaces_services.GetWorkspaceService(),
		encryption.GetFieldEncryptor(),
		backupRepository,
		backups_config.GetBackupConfigService(),
		storages.GetStorageService(),
		notifiers.GetNotifierService(),
		backupCancelManager,
		nodesRegistry,
		logger.GetLogger(),
		usecases.GetCreateBackupUsecase(),
		uuid.New(),  // unique node identity for this test node
		time.Time{}, // zero timestamp — presumably "no activity yet"; TODO confirm field semantics
	}
}
// WaitForBackupCompletion polls the backup repository until a backup beyond
// expectedInitialCount exists for databaseID and its newest entry has reached
// a terminal status (completed, failed or canceled), or until timeout elapses.
// On timeout it only logs; it does not fail the test.
func WaitForBackupCompletion(
	t *testing.T,
	databaseID uuid.UUID,
	expectedInitialCount int,
	timeout time.Duration,
) {
	deadline := time.Now().UTC().Add(timeout)

	for time.Now().UTC().Before(deadline) {
		found, err := backupRepository.FindByDatabaseID(databaseID)
		if err != nil {
			t.Logf("WaitForBackupCompletion: error finding backups: %v", err)
			time.Sleep(50 * time.Millisecond)
			continue
		}

		t.Logf(
			"WaitForBackupCompletion: found %d backups (expected > %d)",
			len(found),
			expectedInitialCount,
		)

		if len(found) > expectedInitialCount {
			// Newest backup is first; check whether it reached a terminal state.
			latest := found[0]
			t.Logf("WaitForBackupCompletion: newest backup status: %s", latest.Status)

			switch latest.Status {
			case backups_core.BackupStatusCompleted,
				backups_core.BackupStatusFailed,
				backups_core.BackupStatusCanceled:
				t.Logf(
					"WaitForBackupCompletion: backup finished with status %s",
					latest.Status,
				)
				return
			}
		}

		time.Sleep(50 * time.Millisecond)
	}

	t.Logf("WaitForBackupCompletion: timeout waiting for backup to complete")
}
// StartBackuperNodeForTest starts a BackuperNode in a goroutine for testing.
// The node registers itself in the registry and subscribes to backup assignments.
// Returns a context cancel function that should be deferred to stop the node;
// the returned func also waits (up to 2s) for the node goroutine to exit.
func StartBackuperNodeForTest(t *testing.T, backuperNode *BackuperNode) context.CancelFunc {
	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan struct{})

	go func() {
		backuperNode.Run(ctx)
		close(done)
	}()

	// Poll registry for node presence instead of fixed sleep
	deadline := time.Now().UTC().Add(5 * time.Second)
	for time.Now().UTC().Before(deadline) {
		nodes, err := nodesRegistry.GetAvailableNodes()
		if err == nil {
			for _, node := range nodes {
				if node.ID == backuperNode.nodeID {
					t.Logf("BackuperNode registered in registry: %s", backuperNode.nodeID)
					return func() {
						cancel()
						select {
						case <-done:
							t.Log("BackuperNode stopped gracefully")
						case <-time.After(2 * time.Second):
							t.Log("BackuperNode stop timeout")
						}
					}
				}
			}
		}
		time.Sleep(50 * time.Millisecond)
	}

	// Fix: stop the node before failing. t.Fatalf only terminates the test
	// goroutine (via runtime.Goexit), so without cancel() the node goroutine
	// would keep running past this test and could pollute later tests.
	cancel()
	select {
	case <-done:
	case <-time.After(2 * time.Second):
	}

	t.Fatalf("BackuperNode failed to register in registry within timeout")
	return nil // unreachable: t.Fatalf does not return
}
// StopBackuperNodeForTest cancels the node's context and then polls the
// registry (up to 2s) until the node has disappeared from the available set.
// On timeout it only logs; it does not fail the test.
func StopBackuperNodeForTest(t *testing.T, cancel context.CancelFunc, backuperNode *BackuperNode) {
	cancel()

	// Wait for the node to drop out of the registry after cancellation.
	deadline := time.Now().UTC().Add(2 * time.Second)
	for time.Now().UTC().Before(deadline) {
		nodes, err := nodesRegistry.GetAvailableNodes()
		if err != nil {
			time.Sleep(50 * time.Millisecond)
			continue
		}

		stillRegistered := false
		for _, node := range nodes {
			if node.ID == backuperNode.nodeID {
				stillRegistered = true
				break
			}
		}

		if !stillRegistered {
			t.Logf("BackuperNode unregistered from registry: %s", backuperNode.nodeID)
			return
		}

		time.Sleep(50 * time.Millisecond)
	}

	t.Logf("BackuperNode stop completed for %s", backuperNode.nodeID)
}
// CreateMockNodeInRegistry registers a synthetic node by issuing a heartbeat
// with the given identity, throughput and timestamp.
func CreateMockNodeInRegistry(nodeID uuid.UUID, throughputMBs int, lastHeartbeat time.Time) error {
	node := BackupNode{
		ID:            nodeID,
		ThroughputMBs: throughputMBs,
		LastHeartbeat: lastHeartbeat,
	}

	return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, node)
}
// UpdateNodeHeartbeatDirectly writes a heartbeat for the given node into the
// registry with the supplied timestamp and throughput.
//
// NOTE(review): this body is identical to CreateMockNodeInRegistry — consider
// delegating to it to remove the duplication.
func UpdateNodeHeartbeatDirectly(
	nodeID uuid.UUID,
	throughputMBs int,
	lastHeartbeat time.Time,
) error {
	node := BackupNode{
		ID:            nodeID,
		ThroughputMBs: throughputMBs,
		LastHeartbeat: lastHeartbeat,
	}

	return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, node)
}
// GetNodeFromRegistry looks up the available node with the given ID and
// returns a pointer to a copy of it, or an error if no such node exists.
func GetNodeFromRegistry(nodeID uuid.UUID) (*BackupNode, error) {
	nodes, err := nodesRegistry.GetAvailableNodes()
	if err != nil {
		return nil, err
	}

	for i := range nodes {
		if nodes[i].ID == nodeID {
			// Copy out so the caller does not alias the registry's slice.
			match := nodes[i]
			return &match, nil
		}
	}

	return nil, fmt.Errorf("node not found")
}

View File

@@ -1,9 +1,8 @@
package backups
package backups_cancellation
import (
"context"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/logger"
"log/slog"
"sync"
@@ -12,22 +11,14 @@ import (
const backupCancelChannel = "backup:cancel"
type BackupContextManager struct {
type BackupCancelManager struct {
mu sync.RWMutex
cancelFuncs map[uuid.UUID]context.CancelFunc
pubsub *cache_utils.PubSubManager
logger *slog.Logger
}
func NewBackupContextManager() *BackupContextManager {
return &BackupContextManager{
cancelFuncs: make(map[uuid.UUID]context.CancelFunc),
pubsub: cache_utils.NewPubSubManager(),
logger: logger.GetLogger(),
}
}
func (m *BackupContextManager) StartSubscription() {
func (m *BackupCancelManager) StartSubscription() {
ctx := context.Background()
handler := func(message string) {
@@ -56,14 +47,14 @@ func (m *BackupContextManager) StartSubscription() {
}
}
func (m *BackupContextManager) RegisterBackup(backupID uuid.UUID, cancelFunc context.CancelFunc) {
func (m *BackupCancelManager) RegisterBackup(backupID uuid.UUID, cancelFunc context.CancelFunc) {
m.mu.Lock()
defer m.mu.Unlock()
m.cancelFuncs[backupID] = cancelFunc
m.logger.Debug("Registered backup", "backupID", backupID)
}
func (m *BackupContextManager) CancelBackup(backupID uuid.UUID) error {
func (m *BackupCancelManager) CancelBackup(backupID uuid.UUID) error {
ctx := context.Background()
err := m.pubsub.Publish(ctx, backupCancelChannel, backupID.String())
@@ -76,7 +67,7 @@ func (m *BackupContextManager) CancelBackup(backupID uuid.UUID) error {
return nil
}
func (m *BackupContextManager) UnregisterBackup(backupID uuid.UUID) {
func (m *BackupCancelManager) UnregisterBackup(backupID uuid.UUID) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.cancelFuncs, backupID)

View File

@@ -1,4 +1,4 @@
package backups
package backups_cancellation
import (
"context"
@@ -11,7 +11,7 @@ import (
)
func Test_RegisterBackup_BackupRegisteredSuccessfully(t *testing.T) {
manager := NewBackupContextManager()
manager := backupCancelManager
backupID := uuid.New()
_, cancel := context.WithCancel(context.Background())
@@ -26,7 +26,7 @@ func Test_RegisterBackup_BackupRegisteredSuccessfully(t *testing.T) {
}
func Test_UnregisterBackup_BackupUnregisteredSuccessfully(t *testing.T) {
manager := NewBackupContextManager()
manager := backupCancelManager
backupID := uuid.New()
_, cancel := context.WithCancel(context.Background())
@@ -42,7 +42,7 @@ func Test_UnregisterBackup_BackupUnregisteredSuccessfully(t *testing.T) {
}
func Test_CancelBackup_OnSameInstance_BackupCancelledViaPubSub(t *testing.T) {
manager := NewBackupContextManager()
manager := backupCancelManager
backupID := uuid.New()
ctx, cancel := context.WithCancel(context.Background())
@@ -75,8 +75,8 @@ func Test_CancelBackup_OnSameInstance_BackupCancelledViaPubSub(t *testing.T) {
}
func Test_CancelBackup_FromDifferentInstance_BackupCancelledOnRunningInstance(t *testing.T) {
manager1 := NewBackupContextManager()
manager2 := NewBackupContextManager()
manager1 := backupCancelManager
manager2 := backupCancelManager
backupID := uuid.New()
ctx, cancel := context.WithCancel(context.Background())
@@ -111,7 +111,7 @@ func Test_CancelBackup_FromDifferentInstance_BackupCancelledOnRunningInstance(t
}
func Test_CancelBackup_WhenBackupDoesNotExist_NoErrorReturned(t *testing.T) {
manager := NewBackupContextManager()
manager := backupCancelManager
manager.StartSubscription()
time.Sleep(100 * time.Millisecond)
@@ -122,7 +122,7 @@ func Test_CancelBackup_WhenBackupDoesNotExist_NoErrorReturned(t *testing.T) {
}
func Test_CancelBackup_WithMultipleBackups_AllBackupsCancelled(t *testing.T) {
manager := NewBackupContextManager()
manager := backupCancelManager
numBackups := 5
backupIDs := make([]uuid.UUID, numBackups)
@@ -165,7 +165,7 @@ func Test_CancelBackup_WithMultipleBackups_AllBackupsCancelled(t *testing.T) {
}
func Test_CancelBackup_AfterUnregister_BackupNotCancelled(t *testing.T) {
manager := NewBackupContextManager()
manager := backupCancelManager
backupID := uuid.New()
_, cancel := context.WithCancel(context.Background())

View File

@@ -0,0 +1,25 @@
package backups_cancellation
import (
"context"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/logger"
"sync"
"github.com/google/uuid"
)
// backupCancelManager is the process-wide singleton that tracks cancel
// functions for in-flight backups and distributes cancel requests via pub/sub.
//
// NOTE(review): positional struct literal — the argument order must match the
// BackupCancelManager field order (mutex, cancel-func map, pubsub, logger);
// keep it in sync if the struct changes.
var backupCancelManager = &BackupCancelManager{
	sync.RWMutex{},
	make(map[uuid.UUID]context.CancelFunc),
	cache_utils.NewPubSubManager(),
	logger.GetLogger(),
}

// GetBackupCancelManager returns the package-level BackupCancelManager singleton.
func GetBackupCancelManager() *BackupCancelManager {
	return backupCancelManager
}

// SetupDependencies starts the pub/sub subscription so cancel requests
// published by other processes are delivered to this one.
func SetupDependencies() {
	backupCancelManager.StartSubscription()
}

View File

@@ -1,6 +1,7 @@
package backups
import (
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/databases"
users_middleware "databasus-backend/internal/features/users/middleware"
"fmt"
@@ -170,7 +171,7 @@ func (c *BackupController) CancelBackup(ctx *gin.Context) {
// @Description Generate a token for downloading a backup file (valid for 5 minutes)
// @Tags backups
// @Param id path string true "Backup ID"
// @Success 200 {object} GenerateDownloadTokenResponse
// @Success 200 {object} backups_download.GenerateDownloadTokenResponse
// @Failure 400
// @Failure 401
// @Router /backups/{id}/download-token [post]
@@ -276,7 +277,7 @@ type MakeBackupRequest struct {
}
func (c *BackupController) generateBackupFilename(
backup *Backup,
backup *backups_core.Backup,
database *databases.Database,
) string {
// Format timestamp as YYYY-MM-DD_HH-mm-ss

View File

@@ -18,7 +18,8 @@ import (
"databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/download_token"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/postgresql"
@@ -478,7 +479,7 @@ func Test_GenerateDownloadToken_PermissionsEnforced(t *testing.T) {
)
if tt.expectSuccess {
var response GenerateDownloadTokenResponse
var response backups_download.GenerateDownloadTokenResponse
err := json.Unmarshal(testResp.Body, &response)
assert.NoError(t, err)
assert.NotEmpty(t, response.Token)
@@ -499,7 +500,7 @@ func Test_DownloadBackup_WithValidToken_Success(t *testing.T) {
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
// Generate download token
var tokenResponse GenerateDownloadTokenResponse
var tokenResponse backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
@@ -620,7 +621,7 @@ func Test_DownloadBackup_TokenUsedOnce_CannotReuseToken(t *testing.T) {
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
// Generate download token
var tokenResponse GenerateDownloadTokenResponse
var tokenResponse backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
@@ -683,7 +684,7 @@ func Test_DownloadBackup_WithDifferentBackupToken_Unauthorized(t *testing.T) {
backup2 := createTestBackup(database2, owner)
// Generate token for backup1
var tokenResponse GenerateDownloadTokenResponse
var tokenResponse backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
@@ -714,7 +715,7 @@ func Test_DownloadBackup_AuditLogWritten(t *testing.T) {
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
// Generate download token
var tokenResponse GenerateDownloadTokenResponse
var tokenResponse backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
@@ -806,7 +807,7 @@ func Test_DownloadBackup_ProperFilenameForPostgreSQL(t *testing.T) {
backup := createTestBackup(database, owner)
// Generate download token
var tokenResponse GenerateDownloadTokenResponse
var tokenResponse backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
@@ -897,22 +898,22 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
_, err = configService.SaveBackupConfig(config)
assert.NoError(t, err)
backup := &Backup{
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusInProgress,
Status: backups_core.BackupStatusInProgress,
BackupSizeMb: 0,
BackupDurationMs: 0,
CreatedAt: time.Now().UTC(),
}
repo := &BackupRepository{}
repo := &backups_core.BackupRepository{}
err = repo.Save(backup)
assert.NoError(t, err)
// Register a cancellable context for the backup
GetBackupService().backupContextManager.RegisterBackup(backup.ID, func() {})
GetBackupService().backupCancelManager.RegisterBackup(backup.ID, func() {})
resp := test_utils.MakePostRequest(
t,
@@ -1038,7 +1039,7 @@ func createTestDatabaseWithBackups(
workspace *workspaces_models.Workspace,
owner *users_dto.SignInResponseDTO,
router *gin.Engine,
) (*databases.Database, *Backup) {
) (*databases.Database, *backups_core.Backup) {
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
@@ -1064,7 +1065,7 @@ func createTestDatabaseWithBackups(
func createTestBackup(
database *databases.Database,
owner *users_dto.SignInResponseDTO,
) *Backup {
) *backups_core.Backup {
userService := users_services.GetUserService()
user, err := userService.GetUserFromToken(owner.Token)
if err != nil {
@@ -1076,17 +1077,17 @@ func createTestBackup(
panic("No storage found for workspace")
}
backup := &Backup{
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storages[0].ID,
Status: BackupStatusCompleted,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10.5,
BackupDurationMs: 1000,
CreatedAt: time.Now().UTC(),
}
repo := &BackupRepository{}
repo := &backups_core.BackupRepository{}
if err := repo.Save(backup); err != nil {
panic(err)
}
@@ -1116,7 +1117,7 @@ func createExpiredDownloadToken(backupID, userID uuid.UUID) string {
}
// Manually update the token to be expired
repo := &download_token.DownloadTokenRepository{}
repo := &backups_download.DownloadTokenRepository{}
downloadToken, err := repo.FindByToken(token)
if err != nil || downloadToken == nil {
panic(fmt.Sprintf("Failed to find generated token: %v", err))

View File

@@ -1,4 +1,4 @@
package backups
package backups_core
type BackupStatus string

View File

@@ -1,4 +1,4 @@
package backups
package backups_core
import (
"context"

View File

@@ -1,4 +1,4 @@
package backups
package backups_core
import (
backups_config "databasus-backend/internal/features/backups/config"

View File

@@ -1,4 +1,4 @@
package backups
package backups_core
import (
"databasus-backend/internal/storage"

View File

@@ -1,10 +1,11 @@
package backups
import (
"time"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/download_token"
"databasus-backend/internal/features/backups/backups/backuping"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
"databasus-backend/internal/features/backups/backups/usecases"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -16,50 +17,31 @@ import (
"databasus-backend/internal/util/logger"
)
var backupRepository = &BackupRepository{}
var backupRepository = &backups_core.BackupRepository{}
var backupContextManager = NewBackupContextManager()
var backupCancelManager = backups_cancellation.GetBackupCancelManager()
var backupService = &BackupService{
databases.GetDatabaseService(),
storages.GetStorageService(),
backupRepository,
notifiers.GetNotifierService(),
notifiers.GetNotifierService(),
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
usecases.GetCreateBackupUsecase(),
logger.GetLogger(),
[]BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
backupContextManager,
download_token.GetDownloadTokenService(),
}
var backupBackgroundService = &BackupBackgroundService{
backupService,
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
time.Now().UTC(),
logger.GetLogger(),
databaseService: databases.GetDatabaseService(),
storageService: storages.GetStorageService(),
backupRepository: backupRepository,
notifierService: notifiers.GetNotifierService(),
notificationSender: notifiers.GetNotifierService(),
backupConfigService: backups_config.GetBackupConfigService(),
secretKeyService: encryption_secrets.GetSecretKeyService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
createBackupUseCase: usecases.GetCreateBackupUsecase(),
logger: logger.GetLogger(),
backupRemoveListeners: []backups_core.BackupRemoveListener{},
workspaceService: workspaces_services.GetWorkspaceService(),
auditLogService: audit_logs.GetAuditLogService(),
backupCancelManager: backupCancelManager,
downloadTokenService: backups_download.GetDownloadTokenService(),
backupSchedulerService: backuping.GetBackupsScheduler(),
}
var backupController = &BackupController{
backupService,
}
func SetupDependencies() {
backups_config.
GetBackupConfigService().
SetDatabaseStorageChangeListener(backupService)
databases.GetDatabaseService().AddDbRemoveListener(backupService)
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
backupContextManager.StartSubscription()
backupService: backupService,
}
func GetBackupService() *BackupService {
@@ -70,10 +52,11 @@ func GetBackupController() *BackupController {
return backupController
}
func GetBackupBackgroundService() *BackupBackgroundService {
return backupBackgroundService
}
func SetupDependencies() {
backups_config.
GetBackupConfigService().
SetDatabaseStorageChangeListener(backupService)
func GetDownloadTokenBackgroundService() *download_token.DownloadTokenBackgroundService {
return download_token.GetDownloadTokenBackgroundService()
databases.GetDatabaseService().AddDbRemoveListener(backupService)
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
}

View File

@@ -0,0 +1,34 @@
package backups_download
import (
"context"
"log/slog"
"time"
)
// DownloadTokenBackgroundService periodically removes expired backup
// download tokens while its Run loop is active.
type DownloadTokenBackgroundService struct {
	downloadTokenService *DownloadTokenService // performs the actual token cleanup
	logger               *slog.Logger          // structured logger for cleanup failures
}
// Run cleans expired download tokens once per minute until ctx is canceled.
// Cleanup failures are logged and do not stop the loop.
func (s *DownloadTokenBackgroundService) Run(ctx context.Context) {
	s.logger.Info("Starting download token cleanup background service")

	// Bail out immediately if the caller already canceled the context.
	if ctx.Err() != nil {
		return
	}

	const cleanupInterval = 1 * time.Minute
	ticker := time.NewTicker(cleanupInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
				s.logger.Error("Failed to clean expired download tokens", "error", err)
			}
		}
	}
}

View File

@@ -1,4 +1,4 @@
package download_token
package backups_download
import (
"databasus-backend/internal/util/logger"

View File

@@ -0,0 +1,9 @@
package backups_download
import "github.com/google/uuid"
// GenerateDownloadTokenResponse is the API payload returned when a download
// token is generated for a backup.
type GenerateDownloadTokenResponse struct {
	Token    string    `json:"token"`    // opaque token presented to the download endpoint
	Filename string    `json:"filename"` // suggested filename for the downloaded backup
	BackupID uuid.UUID `json:"backupId"` // backup this token grants access to
}

View File

@@ -1,4 +1,4 @@
package download_token
package backups_download
import (
"time"

View File

@@ -1,4 +1,4 @@
package download_token
package backups_download
import (
"crypto/rand"

View File

@@ -1,4 +1,4 @@
package download_token
package backups_download
import (
"errors"

View File

@@ -1,32 +0,0 @@
package download_token
import (
"databasus-backend/internal/config"
"log/slog"
"time"
)
type DownloadTokenBackgroundService struct {
downloadTokenService *DownloadTokenService
logger *slog.Logger
}
func (s *DownloadTokenBackgroundService) Run() {
s.logger.Info("Starting download token cleanup background service")
if config.IsShouldShutdown() {
return
}
for {
if config.IsShouldShutdown() {
return
}
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
s.logger.Error("Failed to clean expired download tokens", "error", err)
}
time.Sleep(1 * time.Minute)
}
}

View File

@@ -1,10 +1,9 @@
package backups
import (
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/encryption"
"io"
"github.com/google/uuid"
)
type GetBackupsRequest struct {
@@ -14,23 +13,17 @@ type GetBackupsRequest struct {
}
type GetBackupsResponse struct {
Backups []*Backup `json:"backups"`
Total int64 `json:"total"`
Limit int `json:"limit"`
Offset int `json:"offset"`
Backups []*backups_core.Backup `json:"backups"`
Total int64 `json:"total"`
Limit int `json:"limit"`
Offset int `json:"offset"`
}
type GenerateDownloadTokenResponse struct {
Token string `json:"token"`
Filename string `json:"filename"`
BackupID uuid.UUID `json:"backupId"`
}
type decryptionReaderCloser struct {
type DecryptionReaderCloser struct {
*encryption.DecryptionReader
baseReader io.ReadCloser
BaseReader io.ReadCloser
}
func (r *decryptionReaderCloser) Close() error {
return r.baseReader.Close()
func (r *DecryptionReaderCloser) Close() error {
return r.BaseReader.Close()
}

View File

@@ -1,18 +1,17 @@
package backups
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"log/slog"
"slices"
"strings"
"time"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/download_token"
"databasus-backend/internal/features/backups/backups/backuping"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
"databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -29,26 +28,27 @@ import (
type BackupService struct {
databaseService *databases.DatabaseService
storageService *storages.StorageService
backupRepository *BackupRepository
backupRepository *backups_core.BackupRepository
notifierService *notifiers.NotifierService
notificationSender NotificationSender
notificationSender backups_core.NotificationSender
backupConfigService *backups_config.BackupConfigService
secretKeyService *encryption_secrets.SecretKeyService
fieldEncryptor util_encryption.FieldEncryptor
createBackupUseCase CreateBackupUsecase
createBackupUseCase backups_core.CreateBackupUsecase
logger *slog.Logger
backupRemoveListeners []BackupRemoveListener
backupRemoveListeners []backups_core.BackupRemoveListener
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
backupContextManager *BackupContextManager
downloadTokenService *download_token.DownloadTokenService
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
backupCancelManager *backups_cancellation.BackupCancelManager
downloadTokenService *backups_download.DownloadTokenService
backupSchedulerService *backuping.BackupsScheduler
}
func (s *BackupService) AddBackupRemoveListener(listener BackupRemoveListener) {
func (s *BackupService) AddBackupRemoveListener(listener backups_core.BackupRemoveListener) {
s.backupRemoveListeners = append(s.backupRemoveListeners, listener)
}
@@ -91,7 +91,7 @@ func (s *BackupService) MakeBackupWithAuth(
return errors.New("insufficient permissions to create backup for this database")
}
go s.MakeBackup(databaseID, true)
s.backupSchedulerService.StartBackup(databaseID, true)
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Backup manually initiated for database: %s", database.Name),
@@ -175,7 +175,7 @@ func (s *BackupService) DeleteBackup(
return errors.New("insufficient permissions to delete backup for this database")
}
if backup.Status == BackupStatusInProgress {
if backup.Status == backups_core.BackupStatusInProgress {
return errors.New("backup is in progress")
}
@@ -192,260 +192,7 @@ func (s *BackupService) DeleteBackup(
return s.deleteBackup(backup)
}
func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
database, err := s.databaseService.GetDatabaseByID(databaseID)
if err != nil {
s.logger.Error("Failed to get database by ID", "error", err)
return
}
lastBackup, err := s.backupRepository.FindLastByDatabaseID(databaseID)
if err != nil {
s.logger.Error("Failed to find last backup by database ID", "error", err)
return
}
if lastBackup != nil && lastBackup.Status == BackupStatusInProgress {
s.logger.Error("Backup is in progress")
return
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(databaseID)
if err != nil {
s.logger.Error("Failed to get backup config by database ID", "error", err)
return
}
if backupConfig.StorageID == nil {
s.logger.Error("Backup config storage ID is not defined")
return
}
storage, err := s.storageService.GetStorageByID(*backupConfig.StorageID)
if err != nil {
s.logger.Error("Failed to get storage by ID", "error", err)
return
}
backup := &Backup{
DatabaseID: databaseID,
StorageID: storage.ID,
Status: BackupStatusInProgress,
BackupSizeMb: 0,
CreatedAt: time.Now().UTC(),
}
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("Failed to save backup", "error", err)
return
}
start := time.Now().UTC()
backupProgressListener := func(
completedMBs float64,
) {
backup.BackupSizeMb = completedMBs
backup.BackupDurationMs = time.Since(start).Milliseconds()
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("Failed to update backup progress", "error", err)
}
}
ctx, cancel := context.WithCancel(context.Background())
s.backupContextManager.RegisterBackup(backup.ID, cancel)
defer s.backupContextManager.UnregisterBackup(backup.ID)
backupMetadata, err := s.createBackupUseCase.Execute(
ctx,
backup.ID,
backupConfig,
database,
storage,
backupProgressListener,
)
if err != nil {
errMsg := err.Error()
// Check if backup was cancelled (not due to shutdown)
isCancelled := strings.Contains(errMsg, "backup cancelled") ||
strings.Contains(errMsg, "context canceled") ||
errors.Is(err, context.Canceled)
isShutdown := strings.Contains(errMsg, "shutdown")
if isCancelled && !isShutdown {
backup.Status = BackupStatusCanceled
backup.BackupDurationMs = time.Since(start).Milliseconds()
backup.BackupSizeMb = 0
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("Failed to save cancelled backup", "error", err)
}
// Delete partial backup from storage
storage, storageErr := s.storageService.GetStorageByID(backup.StorageID)
if storageErr == nil {
if deleteErr := storage.DeleteFile(s.fieldEncryptor, backup.ID); deleteErr != nil {
s.logger.Error(
"Failed to delete partial backup file",
"backupId",
backup.ID,
"error",
deleteErr,
)
}
}
return
}
backup.FailMessage = &errMsg
backup.Status = BackupStatusFailed
backup.BackupDurationMs = time.Since(start).Milliseconds()
backup.BackupSizeMb = 0
if updateErr := s.databaseService.SetBackupError(databaseID, errMsg); updateErr != nil {
s.logger.Error(
"Failed to update database last backup time",
"databaseId",
databaseID,
"error",
updateErr,
)
}
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("Failed to save backup", "error", err)
}
s.SendBackupNotification(
backupConfig,
backup,
backups_config.NotificationBackupFailed,
&errMsg,
)
return
}
backup.Status = BackupStatusCompleted
backup.BackupDurationMs = time.Since(start).Milliseconds()
// Update backup with encryption metadata if provided
if backupMetadata != nil {
backup.EncryptionSalt = backupMetadata.EncryptionSalt
backup.EncryptionIV = backupMetadata.EncryptionIV
backup.Encryption = backupMetadata.Encryption
}
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("Failed to save backup", "error", err)
return
}
// Update database last backup time
now := time.Now().UTC()
if updateErr := s.databaseService.SetLastBackupTime(databaseID, now); updateErr != nil {
s.logger.Error(
"Failed to update database last backup time",
"databaseId",
databaseID,
"error",
updateErr,
)
}
if backup.Status != BackupStatusCompleted && !isLastTry {
return
}
s.SendBackupNotification(
backupConfig,
backup,
backups_config.NotificationBackupSuccess,
nil,
)
}
func (s *BackupService) SendBackupNotification(
backupConfig *backups_config.BackupConfig,
backup *Backup,
notificationType backups_config.BackupNotificationType,
errorMessage *string,
) {
database, err := s.databaseService.GetDatabaseByID(backupConfig.DatabaseID)
if err != nil {
return
}
workspace, err := s.workspaceService.GetWorkspaceByID(*database.WorkspaceID)
if err != nil {
return
}
for _, notifier := range database.Notifiers {
if !slices.Contains(
backupConfig.SendNotificationsOn,
notificationType,
) {
continue
}
title := ""
switch notificationType {
case backups_config.NotificationBackupFailed:
title = fmt.Sprintf(
"❌ Backup failed for database \"%s\" (workspace \"%s\")",
database.Name,
workspace.Name,
)
case backups_config.NotificationBackupSuccess:
title = fmt.Sprintf(
"✅ Backup completed for database \"%s\" (workspace \"%s\")",
database.Name,
workspace.Name,
)
}
message := ""
if errorMessage != nil {
message = *errorMessage
} else {
// Format size conditionally
var sizeStr string
if backup.BackupSizeMb < 1024 {
sizeStr = fmt.Sprintf("%.2f MB", backup.BackupSizeMb)
} else {
sizeGB := backup.BackupSizeMb / 1024
sizeStr = fmt.Sprintf("%.2f GB", sizeGB)
}
// Format duration as "0m 0s 0ms"
totalMs := backup.BackupDurationMs
minutes := totalMs / (1000 * 60)
seconds := (totalMs % (1000 * 60)) / 1000
durationStr := fmt.Sprintf("%dm %ds", minutes, seconds)
message = fmt.Sprintf(
"Backup completed successfully in %s.\nCompressed backup size: %s",
durationStr,
sizeStr,
)
}
s.notificationSender.SendNotification(
&notifier,
title,
message,
)
}
}
func (s *BackupService) GetBackup(backupID uuid.UUID) (*Backup, error) {
func (s *BackupService) GetBackup(backupID uuid.UUID) (*backups_core.Backup, error) {
return s.backupRepository.FindByID(backupID)
}
@@ -475,11 +222,11 @@ func (s *BackupService) CancelBackup(
return errors.New("insufficient permissions to cancel backup for this database")
}
if backup.Status != BackupStatusInProgress {
if backup.Status != backups_core.BackupStatusInProgress {
return errors.New("backup is not in progress")
}
if err := s.backupContextManager.CancelBackup(backupID); err != nil {
if err := s.backupCancelManager.CancelBackup(backupID); err != nil {
return err
}
@@ -499,7 +246,7 @@ func (s *BackupService) CancelBackup(
func (s *BackupService) GetBackupFile(
user *users_models.User,
backupID uuid.UUID,
) (io.ReadCloser, *Backup, *databases.Database, error) {
) (io.ReadCloser, *backups_core.Backup, *databases.Database, error) {
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
return nil, nil, nil, err
@@ -545,7 +292,7 @@ func (s *BackupService) GetBackupFile(
return reader, backup, database, nil
}
func (s *BackupService) deleteBackup(backup *Backup) error {
func (s *BackupService) deleteBackup(backup *backups_core.Backup) error {
for _, listener := range s.backupRemoveListeners {
if err := listener.OnBeforeBackupRemove(backup); err != nil {
return err
@@ -571,7 +318,7 @@ func (s *BackupService) deleteBackup(backup *Backup) error {
func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
dbBackupsInProgress, err := s.backupRepository.FindByDatabaseIdAndStatus(
databaseID,
BackupStatusInProgress,
backups_core.BackupStatusInProgress,
)
if err != nil {
return err
@@ -680,16 +427,16 @@ func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, erro
s.logger.Info("Returning encrypted backup with decryption", "backupId", backupID)
return &decryptionReaderCloser{
decryptionReader,
fileReader,
return &DecryptionReaderCloser{
DecryptionReader: decryptionReader,
BaseReader: fileReader,
}, nil
}
func (s *BackupService) GenerateDownloadToken(
user *users_models.User,
backupID uuid.UUID,
) (*GenerateDownloadTokenResponse, error) {
) (*backups_download.GenerateDownloadTokenResponse, error) {
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
return nil, err
@@ -725,20 +472,22 @@ func (s *BackupService) GenerateDownloadToken(
database.WorkspaceID,
)
return &GenerateDownloadTokenResponse{
return &backups_download.GenerateDownloadTokenResponse{
Token: token,
Filename: filename,
BackupID: backupID,
}, nil
}
func (s *BackupService) ValidateDownloadToken(token string) (*download_token.DownloadToken, error) {
func (s *BackupService) ValidateDownloadToken(
token string,
) (*backups_download.DownloadToken, error) {
return s.downloadTokenService.ValidateAndConsume(token)
}
func (s *BackupService) GetBackupFileWithoutAuth(
backupID uuid.UUID,
) (io.ReadCloser, *Backup, *databases.Database, error) {
) (io.ReadCloser, *backups_core.Backup, *databases.Database, error) {
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
return nil, nil, nil, err
@@ -759,7 +508,7 @@ func (s *BackupService) GetBackupFileWithoutAuth(
func (s *BackupService) WriteAuditLogForDownload(
userID uuid.UUID,
backup *Backup,
backup *backups_core.Backup,
database *databases.Database,
) {
s.auditLogService.WriteAuditLog(
@@ -774,7 +523,7 @@ func (s *BackupService) WriteAuditLogForDownload(
}
func (s *BackupService) generateBackupFilename(
backup *Backup,
backup *backups_core.Backup,
database *databases.Database,
) string {
timestamp := backup.CreatedAt.Format("2006-01-02_15-04-05")

View File

@@ -4,6 +4,7 @@ import (
"testing"
"time"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
@@ -58,9 +59,9 @@ func WaitForBackupCompletion(
newestBackup := backups[0]
t.Logf("WaitForBackupCompletion: newest backup status: %s", newestBackup.Status)
if newestBackup.Status == BackupStatusCompleted ||
newestBackup.Status == BackupStatusFailed ||
newestBackup.Status == BackupStatusCanceled {
if newestBackup.Status == backups_core.BackupStatusCompleted ||
newestBackup.Status == backups_core.BackupStatusFailed ||
newestBackup.Status == backups_core.BackupStatusCanceled {
t.Logf(
"WaitForBackupCompletion: backup finished with status %s",
newestBackup.Status,

View File

@@ -1,7 +1,7 @@
package healthcheck_attempt
import (
"databasus-backend/internal/config"
"context"
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
"log/slog"
"time"
@@ -13,18 +13,19 @@ type HealthcheckAttemptBackgroundService struct {
logger *slog.Logger
}
func (s *HealthcheckAttemptBackgroundService) Run() {
func (s *HealthcheckAttemptBackgroundService) Run(ctx context.Context) {
// first healthcheck immediately
s.checkDatabases()
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for range ticker.C {
if config.IsShouldShutdown() {
break
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.checkDatabases()
}
s.checkDatabases()
}
}

View File

@@ -1,6 +1,7 @@
package restores
import (
"context"
"databasus-backend/internal/features/restores/enums"
"log/slog"
)
@@ -10,7 +11,7 @@ type RestoreBackgroundService struct {
logger *slog.Logger
}
func (s *RestoreBackgroundService) Run() {
func (s *RestoreBackgroundService) Run(ctx context.Context) {
if err := s.failRestoresInProgress(); err != nil {
s.logger.Error("Failed to fail restores in progress", "error", err)
panic(err)

View File

@@ -19,6 +19,7 @@ import (
"databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/mysql"
@@ -274,7 +275,7 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
var backup *backups.Backup
var backup *backups_core.Backup
var request RestoreBackupRequest
if tc.dbType == databases.DatabaseTypePostgres {
@@ -321,7 +322,7 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
}
// Set huge backup size (10 TB) that would fail disk validation if checked
repo := &backups.BackupRepository{}
repo := &backups_core.BackupRepository{}
backup.BackupSizeMb = 10485760.0
err := repo.Save(backup)
assert.NoError(t, err)
@@ -368,7 +369,7 @@ func createTestDatabaseWithBackupForRestore(
workspace *workspaces_models.Workspace,
owner *users_dto.SignInResponseDTO,
router *gin.Engine,
) (*databases.Database, *backups.Backup) {
) (*databases.Database, *backups_core.Backup) {
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
@@ -504,7 +505,7 @@ func createTestStorage(workspaceID uuid.UUID) *storages.Storage {
func createTestBackup(
database *databases.Database,
owner *users_dto.SignInResponseDTO,
) *backups.Backup {
) *backups_core.Backup {
fieldEncryptor := util_encryption.GetFieldEncryptor()
userService := users_services.GetUserService()
user, err := userService.GetUserFromToken(owner.Token)
@@ -517,17 +518,17 @@ func createTestBackup(
panic("No storage found for workspace")
}
backup := &backups.Backup{
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storages[0].ID,
Status: backups.BackupStatusCompleted,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10.5,
BackupDurationMs: 1000,
CreatedAt: time.Now().UTC(),
}
repo := &backups.BackupRepository{}
repo := &backups_core.BackupRepository{}
if err := repo.Save(backup); err != nil {
panic(err)
}

View File

@@ -1,7 +1,7 @@
package models
import (
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/restores/enums"
"time"
@@ -13,7 +13,7 @@ type Restore struct {
Status enums.RestoreStatus `json:"status" gorm:"column:status;type:text;not null"`
BackupID uuid.UUID `json:"backupId" gorm:"column:backup_id;type:uuid;not null"`
Backup *backups.Backup
Backup *backups_core.Backup
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`

View File

@@ -3,6 +3,7 @@ package restores
import (
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/disk"
@@ -36,7 +37,7 @@ type RestoreService struct {
diskService *disk.DiskService
}
func (s *RestoreService) OnBeforeBackupRemove(backup *backups.Backup) error {
func (s *RestoreService) OnBeforeBackupRemove(backup *backups_core.Backup) error {
restores, err := s.restoreRepository.FindByBackupID(backup.ID)
if err != nil {
return err
@@ -153,10 +154,10 @@ func (s *RestoreService) RestoreBackupWithAuth(
}
func (s *RestoreService) RestoreBackup(
backup *backups.Backup,
backup *backups_core.Backup,
requestDTO RestoreBackupRequest,
) error {
if backup.Status != backups.BackupStatusCompleted {
if backup.Status != backups_core.BackupStatusCompleted {
return errors.New("backup is not completed")
}
@@ -370,7 +371,7 @@ func (s *RestoreService) validateVersionCompatibility(
}
func (s *RestoreService) validateDiskSpace(
backup *backups.Backup,
backup *backups_core.Backup,
requestDTO RestoreBackupRequest,
) error {
// Only validate disk space for PostgreSQL when file-based restore is needed:

View File

@@ -18,7 +18,7 @@ import (
"github.com/klauspost/compress/zstd"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -40,7 +40,7 @@ func (uc *RestoreMariadbBackupUsecase) Execute(
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
) error {
if originalDB.Type != databases.DatabaseTypeMariadb {
@@ -99,7 +99,7 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
mariadbBin string,
args []string,
password string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
mdbConfig *mariadbtypes.MariadbDatabase,
) error {
@@ -163,7 +163,7 @@ func (uc *RestoreMariadbBackupUsecase) executeMariadbRestore(
args []string,
myCnfFile string,
backupReader io.ReadCloser,
backup *backups.Backup,
backup *backups_core.Backup,
) error {
fullArgs := append([]string{"--defaults-file=" + myCnfFile}, args...)
@@ -226,7 +226,7 @@ func (uc *RestoreMariadbBackupUsecase) executeMariadbRestore(
func (uc *RestoreMariadbBackupUsecase) setupDecryption(
reader io.Reader,
backup *backups.Backup,
backup *backups_core.Backup,
) (io.Reader, error) {
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
return nil, fmt.Errorf("backup is encrypted but missing encryption metadata")

View File

@@ -14,7 +14,7 @@ import (
"time"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -40,7 +40,7 @@ func (uc *RestoreMongodbBackupUsecase) Execute(
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
) error {
if originalDB.Type != databases.DatabaseTypeMongodb {
@@ -124,7 +124,7 @@ func (uc *RestoreMongodbBackupUsecase) buildMongorestoreArgs(
func (uc *RestoreMongodbBackupUsecase) restoreFromStorage(
mongorestoreBin string,
args []string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
) error {
ctx, cancel := context.WithTimeout(context.Background(), restoreTimeout)
@@ -166,7 +166,7 @@ func (uc *RestoreMongodbBackupUsecase) executeMongoRestore(
mongorestoreBin string,
args []string,
backupReader io.ReadCloser,
backup *backups.Backup,
backup *backups_core.Backup,
) error {
cmd := exec.CommandContext(ctx, mongorestoreBin, args...)
@@ -231,7 +231,7 @@ func (uc *RestoreMongodbBackupUsecase) executeMongoRestore(
func (uc *RestoreMongodbBackupUsecase) setupDecryption(
reader io.Reader,
backup *backups.Backup,
backup *backups_core.Backup,
) (io.Reader, error) {
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
return nil, errors.New("encrypted backup missing salt or IV")

View File

@@ -18,7 +18,7 @@ import (
"github.com/klauspost/compress/zstd"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -40,7 +40,7 @@ func (uc *RestoreMysqlBackupUsecase) Execute(
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
) error {
if originalDB.Type != databases.DatabaseTypeMysql {
@@ -98,7 +98,7 @@ func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
mysqlBin string,
args []string,
password string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
myConfig *mysqltypes.MysqlDatabase,
) error {
@@ -154,7 +154,7 @@ func (uc *RestoreMysqlBackupUsecase) executeMysqlRestore(
args []string,
myCnfFile string,
backupReader io.ReadCloser,
backup *backups.Backup,
backup *backups_core.Backup,
) error {
fullArgs := append([]string{"--defaults-file=" + myCnfFile}, args...)
@@ -217,7 +217,7 @@ func (uc *RestoreMysqlBackupUsecase) executeMysqlRestore(
func (uc *RestoreMysqlBackupUsecase) setupDecryption(
reader io.Reader,
backup *backups.Backup,
backup *backups_core.Backup,
) (io.Reader, error) {
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
return nil, fmt.Errorf("backup is encrypted but missing encryption metadata")

View File

@@ -15,7 +15,7 @@ import (
"time"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -39,7 +39,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
isExcludeExtensions bool,
) error {
@@ -86,7 +86,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
func (uc *RestorePostgresqlBackupUsecase) restoreCustomType(
originalDB *databases.Database,
pgBin string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
pg *pgtypes.PostgresqlDatabase,
isExcludeExtensions bool,
@@ -113,7 +113,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreCustomType(
func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
originalDB *databases.Database,
pgBin string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
pg *pgtypes.PostgresqlDatabase,
) error {
@@ -321,7 +321,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
func (uc *RestorePostgresqlBackupUsecase) restoreViaFile(
originalDB *databases.Database,
pgBin string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
pg *pgtypes.PostgresqlDatabase,
isExcludeExtensions bool,
@@ -371,7 +371,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
pgBin string,
args []string,
password string,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
pgConfig *pgtypes.PostgresqlDatabase,
isExcludeExtensions bool,
@@ -469,7 +469,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
// downloadBackupToTempFile downloads backup data from storage to a temporary file
func (uc *RestorePostgresqlBackupUsecase) downloadBackupToTempFile(
ctx context.Context,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
) (string, func(), error) {
// Create temporary directory for backup data

View File

@@ -3,7 +3,7 @@ package usecases
import (
"errors"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/restores/models"
@@ -26,7 +26,7 @@ func (uc *RestoreBackupUsecase) Execute(
restore models.Restore,
originalDB *databases.Database,
restoringToDB *databases.Database,
backup *backups.Backup,
backup *backups_core.Backup,
storage *storages.Storage,
isExcludeExtensions bool,
) error {

View File

@@ -1,13 +1,14 @@
package system_healthcheck
import (
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/features/backups/backups/backuping"
"databasus-backend/internal/features/disk"
)
var healthcheckService = &HealthcheckService{
disk.GetDiskService(),
backups.GetBackupBackgroundService(),
backuping.GetBackupsScheduler(),
backuping.GetBackuperNode(),
}
var healthcheckController = &HealthcheckController{
healthcheckService,

View File

@@ -1,7 +1,8 @@
package system_healthcheck
import (
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups/backuping"
"databasus-backend/internal/features/disk"
"databasus-backend/internal/storage"
"errors"
@@ -9,7 +10,8 @@ import (
type HealthcheckService struct {
diskService *disk.DiskService
backupBackgroundService *backups.BackupBackgroundService
backupBackgroundService *backuping.BackupsScheduler
backuperNode *backuping.BackuperNode
}
func (s *HealthcheckService) IsHealthy() error {
@@ -29,8 +31,16 @@ func (s *HealthcheckService) IsHealthy() error {
return errors.New("cannot connect to the database")
}
if !s.backupBackgroundService.IsBackupsWorkerRunning() {
return errors.New("backups are not running for more than 5 minutes")
if config.GetEnv().IsPrimaryNode {
if !s.backupBackgroundService.IsSchedulerRunning() {
return errors.New("backups are not running for more than 5 minutes")
}
}
if config.GetEnv().IsBackupNode {
if !s.backuperNode.IsBackuperRunning() {
return errors.New("backuper node is not running for more than 5 minutes")
}
}
return nil

View File

@@ -17,7 +17,7 @@ import (
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
mariadbtypes "databasus-backend/internal/features/databases/databases/mariadb"
@@ -189,7 +189,7 @@ func testMariadbBackupRestoreForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_mariadb"
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -286,7 +286,7 @@ func testMariadbBackupRestoreWithEncryptionForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)
newDBName := "restoreddb_mariadb_encrypted"
@@ -394,7 +394,7 @@ func testMariadbBackupRestoreWithReadOnlyUserForVersion(
createBackupViaAPI(t, router, updatedDatabase.ID, user.Token)
backup := waitForBackupCompletion(t, router, updatedDatabase.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_mariadb_readonly"
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))

View File

@@ -19,7 +19,7 @@ import (
"go.mongodb.org/mongo-driver/mongo/options"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
mongodbtypes "databasus-backend/internal/features/databases/databases/mongodb"
@@ -161,7 +161,7 @@ func testMongodbBackupRestoreForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_mongodb_" + uuid.New().String()[:8]
@@ -239,7 +239,7 @@ func testMongodbBackupRestoreWithEncryptionForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)
newDBName := "restoreddb_mongodb_enc_" + uuid.New().String()[:8]
@@ -328,7 +328,7 @@ func testMongodbBackupRestoreWithReadOnlyUserForVersion(
createBackupViaAPI(t, router, updatedDatabase.ID, user.Token)
backup := waitForBackupCompletion(t, router, updatedDatabase.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_mongodb_ro_" + uuid.New().String()[:8]

View File

@@ -17,7 +17,7 @@ import (
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
mysqltypes "databasus-backend/internal/features/databases/databases/mysql"
@@ -164,7 +164,7 @@ func testMysqlBackupRestoreForVersion(t *testing.T, mysqlVersion tools.MysqlVers
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_mysql"
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -261,7 +261,7 @@ func testMysqlBackupRestoreWithEncryptionForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)
newDBName := "restoreddb_mysql_encrypted"
@@ -369,7 +369,7 @@ func testMysqlBackupRestoreWithReadOnlyUserForVersion(
createBackupViaAPI(t, router, updatedDatabase.ID, user.Token)
backup := waitForBackupCompletion(t, router, updatedDatabase.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_mysql_readonly"
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))

View File

@@ -18,6 +18,7 @@ import (
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
pgtypes "databasus-backend/internal/features/databases/databases/postgresql"
@@ -190,7 +191,7 @@ func Test_BackupAndRestoreSupabase_PublicSchemaOnly_RestoreIsSuccessful(t *testi
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
_, err = supabaseDB.Exec(fmt.Sprintf(`DELETE FROM public.%s`, tableName))
assert.NoError(t, err)
@@ -410,7 +411,7 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string, cp
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := fmt.Sprintf("restoreddb_%s_cpu%d_%s", pgVersion, cpuCount, uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -527,7 +528,7 @@ func testSchemaSelectionAllSchemasForVersion(t *testing.T, pgVersion string, por
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := fmt.Sprintf("restored_all_schemas_%s_%s", pgVersion, uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -655,7 +656,7 @@ func testBackupRestoreWithExcludeExtensionsForVersion(t *testing.T, pgVersion st
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := fmt.Sprintf("restored_exclude_ext_%s_%s", pgVersion, uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -789,7 +790,7 @@ func testBackupRestoreWithoutExcludeExtensionsForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := fmt.Sprintf("restored_with_ext_%s_%s", pgVersion, uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -928,7 +929,7 @@ func testBackupRestoreWithReadOnlyUserForVersion(t *testing.T, pgVersion string,
createBackupViaAPI(t, router, updatedDatabase.ID, user.Token)
backup := waitForBackupCompletion(t, router, updatedDatabase.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := fmt.Sprintf("restoreddb_readonly_%s", uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -1048,7 +1049,7 @@ func testSchemaSelectionOnlySpecifiedSchemasForVersion(
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := fmt.Sprintf("restored_specific_schemas_%s_%s", pgVersion, uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
@@ -1161,7 +1162,7 @@ func testBackupRestoreWithEncryptionForVersion(t *testing.T, pgVersion string, p
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)
newDBName := fmt.Sprintf("restoreddb_encrypted_%s", uuid.New().String()[:8])
@@ -1242,7 +1243,7 @@ func waitForBackupCompletion(
databaseID uuid.UUID,
token string,
timeout time.Duration,
) *backups.Backup {
) *backups_core.Backup {
startTime := time.Now()
pollInterval := 500 * time.Millisecond
@@ -1263,10 +1264,10 @@ func waitForBackupCompletion(
if len(response.Backups) > 0 {
backup := response.Backups[0]
if backup.Status == backups.BackupStatusCompleted {
if backup.Status == backups_core.BackupStatusCompleted {
return backup
}
if backup.Status == backups.BackupStatusFailed {
if backup.Status == backups_core.BackupStatusFailed {
failMsg := "unknown error"
if backup.FailMessage != nil {
failMsg = *backup.FailMessage

View File

@@ -0,0 +1,22 @@
package tests
import (
"os"
"testing"
"databasus-backend/internal/features/backups/backups/backuping"
cache_utils "databasus-backend/internal/util/cache"
)
// TestMain starts a shared backuper node for the whole test package:
// it clears the cache, boots the test backuper node, runs all tests,
// then stops the node before exiting with the tests' exit code.
func TestMain(m *testing.M) {
	// Fail fast if the cache cannot be cleared: stale cache entries would
	// make the backup tests flaky. (The original ignored this error.)
	if err := cache_utils.ClearAllCache(); err != nil {
		panic(err)
	}

	backuperNode := backuping.CreateTestBackuperNode()
	// NOTE(review): a zero-value &testing.T{} is passed because no real
	// *testing.T exists in TestMain; the helpers presumably only use it
	// for logging/failing — confirm against their implementation.
	cancel := backuping.StartBackuperNodeForTest(&testing.T{}, backuperNode)

	exitCode := m.Run()

	// Stop the node before exiting so background goroutines shut down cleanly.
	backuping.StopBackuperNodeForTest(&testing.T{}, cancel, backuperNode)
	// os.Exit skips deferred calls, so all cleanup must happen above this line.
	os.Exit(exitCode)
}

View File

@@ -41,6 +41,10 @@ func getCache() valkey.Client {
return valkeyClient
}
// GetValkeyClient returns the package's shared Valkey client,
// delegating to getCache() (which lazily initializes and caches it).
// Callers must not close the returned client; it is shared process-wide.
func GetValkeyClient() valkey.Client {
	return getCache()
}
func TestCacheConnection() {
// Get Valkey client from cache package
client := getCache()