diff --git a/AGENTS.md b/AGENTS.md index ce8f642..bbe0448 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -237,6 +237,66 @@ func (c *ProjectController) extractProjectID(ctx *gin.Context) uuid.UUID { --- +### Boolean Naming + +**Always prefix boolean variables with an auxiliary verb such as `is`, `has`, `was`, `should`, or `can`.** + +This makes the code more readable and clearly indicates that the variable represents a true/false state. + +#### Good Examples: + +```go +type User struct { + IsActive bool + IsVerified bool + HasAccess bool + WasNotified bool +} + +type BackupConfig struct { + IsEnabled bool + ShouldCompress bool + CanRetry bool +} + +// Variables +isInProgress := true +wasCompleted := false +hasPermission := checkPermissions() +``` + +#### Bad Examples: + +```go +type User struct { + Active bool // Should be: IsActive + Verified bool // Should be: IsVerified + Access bool // Should be: HasAccess +} + +type BackupConfig struct { + Enabled bool // Should be: IsEnabled + Compress bool // Should be: ShouldCompress + Retry bool // Should be: CanRetry +} + +// Variables +inProgress := true // Should be: isInProgress +completed := false // Should be: wasCompleted +permission := true // Should be: hasPermission +``` + +#### Common Boolean Prefixes: + +- **is** - current state (IsActive, IsValid, IsEnabled) +- **has** - possession or presence (HasAccess, HasPermission, HasError) +- **was** - past state (WasCompleted, WasNotified, WasDeleted) +- **should** - intention or recommendation (ShouldRetry, ShouldCompress) +- **can** - capability or permission (CanRetry, CanDelete, CanEdit) +- **will** - future state (WillExpire, WillRetry) + +--- + ### Comments #### Guidelines @@ -489,6 +549,134 @@ func GetOrderRepository() *repositories.OrderRepository { } ``` +#### SetupDependencies() Pattern + +**All `SetupDependencies()` functions must use sync.Once to ensure idempotent execution.** + +This pattern allows `SetupDependencies()` to be safely called multiple times (especially in tests) while ensuring the actual setup logic executes only once. + +**Implementation Pattern:** + +```go +package feature + +import ( + "sync" + "sync/atomic" + "databasus-backend/internal/util/logger" +) + +var ( + setupOnce sync.Once + isSetup atomic.Bool +) + +func SetupDependencies() { + wasAlreadySetup := isSetup.Load() + + setupOnce.Do(func() { + // Initialize dependencies here + someService.SetDependency(otherService) + anotherService.AddListener(listener) + + isSetup.Store(true) + }) + + if wasAlreadySetup { + logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call") + } +} +``` + +**Why This Pattern:** + +- **Tests can call it multiple times**: Test setup often calls `SetupDependencies()` repeatedly without issues +- **Thread-safe**: Works correctly no matter how closely concurrent calls are spaced +- **Idempotent**: Subsequent calls are safe and only log a warning +- **No panics**: Does not break tests or production code on multiple calls + +**Key Points:** + +1. Check `isSetup.Load()` **before** calling `Do()` to detect previous executions +2. Set `isSetup.Store(true)` **inside** the `Do()` closure after setup completes +3. Log a warning if already set up (helps identify unnecessary duplicate calls) +4. All setup logic must be inside the `Do()` closure
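+ +A minimal sketch of the intended behavior, written as a hypothetical test against the `feature` package above (the test name is illustrative, not from this codebase): + +```go +package feature + +import "testing" + +func TestSetupDependencies_CalledTwice_IsIdempotent(t *testing.T) { + SetupDependencies() // first call performs the real setup + SetupDependencies() // second call is a no-op: sync.Once skips the closure and only a warning is logged +} +``` + +--- + +### Background Services + +**All background service `Run()` methods must panic if called multiple times to prevent corrupted states.** + +Background services run infinite loops and must never be started twice on the same instance.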
Multiple calls indicate a serious bug that would cause duplicate goroutines, resource leaks, and data corruption. + +**Implementation Pattern:** + +```go +package feature + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "time" +) + +type BackgroundService struct { + // ... existing fields ... + runOnce sync.Once + hasRun atomic.Bool +} + +func (s *BackgroundService) Run(ctx context.Context) { + wasAlreadyRun := s.hasRun.Load() + + s.runOnce.Do(func() { + s.hasRun.Store(true) + + // Existing infinite loop logic + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + s.doWork() + } + } + }) + + if wasAlreadyRun { + panic(fmt.Sprintf("%T.Run() called multiple times", s)) + } +} +``` + +**Why Panic Instead of Warning:** + +- **Prevents corruption**: Multiple `Run()` calls would create duplicate goroutines that consume resources +- **Fails fast**: Catches critical bugs immediately in tests and production +- **Clear indication**: Panic clearly indicates a serious programming error +- **Applies everywhere**: Same protection in tests and production + +**When This Applies:** + +- All background services with infinite loops +- Registry services (BackupNodesRegistry, RestoreNodesRegistry) +- Scheduler services (BackupsScheduler, RestoresScheduler) +- Worker nodes (BackuperNode, RestorerNode) +- Cleanup services (AuditLogBackgroundService, DownloadTokenBackgroundService) + +**Key Points:** + +1. Check `hasRun.Load()` **before** calling `Do()` to detect previous executions +2. Set `hasRun.Store(true)` **inside** the `Do()` closure before starting work +3. **Always panic** if already run (never just log a warning) +4. All run logic must be inside the `Do()` closure +5. This pattern is **thread-safe** for sequential and concurrent calls; note that a concurrent second call blocks inside `Do()` until the first `Run()` returns, because `sync.Once.Do` waits for the active invocation to finish
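+ +A sketch of the expected behavior, again as a hypothetical test against the `BackgroundService` pattern above (all names are illustrative): + +```go +package feature + +import ( + "context" + "testing" +) + +func TestRun_CalledTwice_Panics(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // pre-cancel so the first Run() exits its loop immediately + + svc := &BackgroundService{} + svc.Run(ctx) // first call: enters the loop and returns on the cancelled context + + defer func() { + if recover() == nil { + t.Fatal("expected second Run() to panic") + } + }() + + svc.Run(ctx) // second call: must panic +} +``` + --- ### Migrations diff --git a/backend/internal/features/audit_logs/background_service.go b/backend/internal/features/audit_logs/background_service.go index a5cdb6a..5fc4812 100644 --- a/backend/internal/features/audit_logs/background_service.go +++ b/backend/internal/features/audit_logs/background_service.go @@ -2,34 +2,50 @@ package audit_logs import ( "context" + "fmt" "log/slog" + "sync" + "sync/atomic" "time" ) type AuditLogBackgroundService struct { auditLogService *AuditLogService logger *slog.Logger + + runOnce sync.Once + hasRun atomic.Bool } func (s *AuditLogBackgroundService) Run(ctx context.Context) { - s.logger.Info("Starting audit log cleanup background service") + wasAlreadyRun := s.hasRun.Load() - if ctx.Err() != nil { - return - } + s.runOnce.Do(func() { + s.hasRun.Store(true) - ticker := time.NewTicker(1 * time.Hour) - defer ticker.Stop() + s.logger.Info("Starting audit log cleanup background service") - for { - select { - case <-ctx.Done(): + if ctx.Err() != nil { return - case <-ticker.C: - if err := s.cleanOldAuditLogs(); err != nil { - s.logger.Error("Failed to clean old audit logs", "error", err) + } + + ticker := time.NewTicker(1 * time.Hour) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := s.cleanOldAuditLogs(); err != nil { + s.logger.Error("Failed to clean old audit logs", "error", err) + } } } + }) + + if wasAlreadyRun { + panic(fmt.Sprintf("%T.Run() called multiple times", s)) } } diff --git a/backend/internal/features/audit_logs/di.go b/backend/internal/features/audit_logs/di.go index 2f0ae83..135945d 100644 --- a/backend/internal/features/audit_logs/di.go +++ 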
b/backend/internal/features/audit_logs/di.go @@ -1,6 +1,9 @@ package audit_logs import ( + "sync" + "sync/atomic" + users_services "databasus-backend/internal/features/users/services" "databasus-backend/internal/util/logger" ) @@ -14,8 +17,10 @@ var auditLogController = &AuditLogController{ auditLogService, } var auditLogBackgroundService = &AuditLogBackgroundService{ - auditLogService, - logger.GetLogger(), + auditLogService: auditLogService, + logger: logger.GetLogger(), + runOnce: sync.Once{}, + hasRun: atomic.Bool{}, } func GetAuditLogService() *AuditLogService { @@ -30,8 +35,23 @@ func GetAuditLogBackgroundService() *AuditLogBackgroundService { return auditLogBackgroundService } +var ( + setupOnce sync.Once + isSetup atomic.Bool +) + func SetupDependencies() { - users_services.GetUserService().SetAuditLogWriter(auditLogService) - users_services.GetSettingsService().SetAuditLogWriter(auditLogService) - users_services.GetManagementService().SetAuditLogWriter(auditLogService) + wasAlreadySetup := isSetup.Load() + + setupOnce.Do(func() { + users_services.GetUserService().SetAuditLogWriter(auditLogService) + users_services.GetSettingsService().SetAuditLogWriter(auditLogService) + users_services.GetManagementService().SetAuditLogWriter(auditLogService) + + isSetup.Store(true) + }) + + if wasAlreadySetup { + logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call") + } } diff --git a/backend/internal/features/backups/backups/backuping/backuper.go b/backend/internal/features/backups/backups/backuping/backuper.go index 2c40a4d..5a17e16 100644 --- a/backend/internal/features/backups/backups/backuping/backuper.go +++ b/backend/internal/features/backups/backups/backuping/backuper.go @@ -2,6 +2,17 @@ package backuping import ( "context" + "errors" + "fmt" + "log/slog" + "slices" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" + "databasus-backend/internal/config" backups_core "databasus-backend/internal/features/backups/backups/core" backups_config "databasus-backend/internal/features/backups/config" @@ -10,14 +21,6 @@ import ( tasks_cancellation "databasus-backend/internal/features/tasks/cancellation" workspaces_services "databasus-backend/internal/features/workspaces/services" util_encryption "databasus-backend/internal/util/encryption" - "errors" - "fmt" - "log/slog" - "slices" - "strings" - "time" - - "github.com/google/uuid" ) const ( @@ -40,66 +43,79 @@ type BackuperNode struct { nodeID uuid.UUID lastHeartbeat time.Time + + runOnce sync.Once + hasRun atomic.Bool } func (n *BackuperNode) Run(ctx context.Context) { - n.lastHeartbeat = time.Now().UTC() + wasAlreadyRun := n.hasRun.Load() - throughputMBs := config.GetEnv().NodeNetworkThroughputMBs + n.runOnce.Do(func() { + n.hasRun.Store(true) - backupNode := BackupNode{ - ID: n.nodeID, - ThroughputMBs: throughputMBs, - LastHeartbeat: time.Now().UTC(), - } + n.lastHeartbeat = time.Now().UTC() - if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil { - n.logger.Error("Failed to register node in registry", "error", err) - panic(err) - } + throughputMBs := config.GetEnv().NodeNetworkThroughputMBs - backupHandler := func(backupID uuid.UUID, isCallNotifier bool) { - n.MakeBackup(backupID, isCallNotifier) - if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil { - n.logger.Error( - "Failed to publish backup completion", - "error", - err, - "backupID", - backupID, - ) + backupNode := BackupNode{ + ID: n.nodeID, + 
ThroughputMBs: throughputMBs, + LastHeartbeat: time.Now().UTC(), } - } - err := n.backupNodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID, backupHandler) - if err != nil { - n.logger.Error("Failed to subscribe to backup assignments", "error", err) - panic(err) - } - defer func() { - if err := n.backupNodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil { - n.logger.Error("Failed to unsubscribe from backup assignments", "error", err) + if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil { + n.logger.Error("Failed to register node in registry", "error", err) + panic(err) } - }() - ticker := time.NewTicker(heartbeatTickerInterval) - defer ticker.Stop() - - n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs) - - for { - select { - case <-ctx.Done(): - n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID) - - if err := n.backupNodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil { - n.logger.Error("Failed to unregister node from registry", "error", err) + backupHandler := func(backupID uuid.UUID, isCallNotifier bool) { + n.MakeBackup(backupID, isCallNotifier) + if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil { + n.logger.Error( + "Failed to publish backup completion", + "error", + err, + "backupID", + backupID, + ) } - - return - case <-ticker.C: - n.sendHeartbeat(&backupNode) } + + err := n.backupNodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID, backupHandler) + if err != nil { + n.logger.Error("Failed to subscribe to backup assignments", "error", err) + panic(err) + } + defer func() { + if err := n.backupNodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil { + n.logger.Error("Failed to unsubscribe from backup assignments", "error", err) + } + }() + + ticker := time.NewTicker(heartbeatTickerInterval) + defer ticker.Stop() + + n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs) + + for { + select { + case <-ctx.Done(): + n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID) + + if err := n.backupNodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil { + n.logger.Error("Failed to unregister node from registry", "error", err) + } + + return + case <-ticker.C: + n.sendHeartbeat(&backupNode) + } + } + }) + + if wasAlreadyRun { + panic(fmt.Sprintf("%T.Run() called multiple times", n)) } } diff --git a/backend/internal/features/backups/backups/backuping/di.go b/backend/internal/features/backups/backups/backuping/di.go index 12fa010..86b3893 100644 --- a/backend/internal/features/backups/backups/backuping/di.go +++ b/backend/internal/features/backups/backups/backuping/di.go @@ -1,6 +1,12 @@ package backuping import ( + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" + backups_core "databasus-backend/internal/features/backups/backups/core" "databasus-backend/internal/features/backups/backups/usecases" backups_config "databasus-backend/internal/features/backups/config" @@ -12,9 +18,6 @@ import ( cache_utils "databasus-backend/internal/util/cache" "databasus-backend/internal/util/encryption" "databasus-backend/internal/util/logger" - "time" - - "github.com/google/uuid" ) var backupRepository = &backups_core.BackupRepository{} @@ -22,11 +25,13 @@ var backupRepository = &backups_core.BackupRepository{} var taskCancelManager = tasks_cancellation.GetTaskCancelManager() var backupNodesRegistry = &BackupNodesRegistry{ - 
cache_utils.GetValkeyClient(), - logger.GetLogger(), - cache_utils.DefaultCacheTimeout, - cache_utils.NewPubSubManager(), - cache_utils.NewPubSubManager(), + client: cache_utils.GetValkeyClient(), + logger: logger.GetLogger(), + timeout: cache_utils.DefaultCacheTimeout, + pubsubBackups: cache_utils.NewPubSubManager(), + pubsubCompletions: cache_utils.NewPubSubManager(), + runOnce: sync.Once{}, + hasRun: atomic.Bool{}, } func getNodeID() uuid.UUID { @@ -34,19 +39,21 @@ func getNodeID() uuid.UUID { } var backuperNode = &BackuperNode{ - databases.GetDatabaseService(), - encryption.GetFieldEncryptor(), - workspaces_services.GetWorkspaceService(), - backupRepository, - backups_config.GetBackupConfigService(), - storages.GetStorageService(), - notifiers.GetNotifierService(), - taskCancelManager, - backupNodesRegistry, - logger.GetLogger(), - usecases.GetCreateBackupUsecase(), - getNodeID(), - time.Time{}, + databaseService: databases.GetDatabaseService(), + fieldEncryptor: encryption.GetFieldEncryptor(), + workspaceService: workspaces_services.GetWorkspaceService(), + backupRepository: backupRepository, + backupConfigService: backups_config.GetBackupConfigService(), + storageService: storages.GetStorageService(), + notificationSender: notifiers.GetNotifierService(), + backupCancelManager: taskCancelManager, + backupNodesRegistry: backupNodesRegistry, + logger: logger.GetLogger(), + createBackupUseCase: usecases.GetCreateBackupUsecase(), + nodeID: getNodeID(), + lastHeartbeat: time.Time{}, + runOnce: sync.Once{}, + hasRun: atomic.Bool{}, } var backupsScheduler = &BackupsScheduler{ @@ -59,6 +66,8 @@ var backupsScheduler = &BackupsScheduler{ logger: logger.GetLogger(), backupToNodeRelations: make(map[uuid.UUID]BackupToNodeRelation), backuperNode: backuperNode, + runOnce: sync.Once{}, + hasRun: atomic.Bool{}, } func GetBackupsScheduler() *BackupsScheduler { diff --git a/backend/internal/features/backups/backups/backuping/registry.go b/backend/internal/features/backups/backups/backuping/registry.go index 53365a5..69afe38 100644 --- a/backend/internal/features/backups/backups/backuping/registry.go +++ b/backend/internal/features/backups/backups/backuping/registry.go @@ -6,6 +6,8 @@ import ( "fmt" "log/slog" "strings" + "sync" + "sync/atomic" "time" cache_utils "databasus-backend/internal/util/cache" @@ -47,24 +49,37 @@ type BackupNodesRegistry struct { timeout time.Duration pubsubBackups *cache_utils.PubSubManager pubsubCompletions *cache_utils.PubSubManager + + runOnce sync.Once + hasRun atomic.Bool } func (r *BackupNodesRegistry) Run(ctx context.Context) { - if err := r.cleanupDeadNodes(); err != nil { - r.logger.Error("Failed to cleanup dead nodes on startup", "error", err) - } + wasAlreadyRun := r.hasRun.Load() - ticker := time.NewTicker(cleanupTickerInterval) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - if err := r.cleanupDeadNodes(); err != nil { - r.logger.Error("Failed to cleanup dead nodes", "error", err) + r.runOnce.Do(func() { + r.hasRun.Store(true) + + if err := r.cleanupDeadNodes(); err != nil { + r.logger.Error("Failed to cleanup dead nodes on startup", "error", err) + } + + ticker := time.NewTicker(cleanupTickerInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := r.cleanupDeadNodes(); err != nil { + r.logger.Error("Failed to cleanup dead nodes", "error", err) + } } } + }) + + if wasAlreadyRun { + panic(fmt.Sprintf("%T.Run() called multiple times", r)) } } diff --git 
a/backend/internal/features/backups/backups/backuping/registry_test.go b/backend/internal/features/backups/backups/backuping/registry_test.go index a2f382b..cefd995 100644 --- a/backend/internal/features/backups/backups/backuping/registry_test.go +++ b/backend/internal/features/backups/backups/backuping/registry_test.go @@ -4,6 +4,8 @@ import ( "context" "encoding/json" "fmt" + "sync" + "sync/atomic" "testing" "time" @@ -594,11 +596,13 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) { func createTestRegistry() *BackupNodesRegistry { return &BackupNodesRegistry{ - cache_utils.GetValkeyClient(), - logger.GetLogger(), - cache_utils.DefaultCacheTimeout, - cache_utils.NewPubSubManager(), - cache_utils.NewPubSubManager(), + client: cache_utils.GetValkeyClient(), + logger: logger.GetLogger(), + timeout: cache_utils.DefaultCacheTimeout, + pubsubBackups: cache_utils.NewPubSubManager(), + pubsubCompletions: cache_utils.NewPubSubManager(), + runOnce: sync.Once{}, + hasRun: atomic.Bool{}, } } diff --git a/backend/internal/features/backups/backups/backuping/scheduler.go b/backend/internal/features/backups/backups/backuping/scheduler.go index a910ba4..c67959a 100644 --- a/backend/internal/features/backups/backups/backuping/scheduler.go +++ b/backend/internal/features/backups/backups/backuping/scheduler.go @@ -2,6 +2,14 @@ package backuping import ( "context" + "fmt" + "log/slog" + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" + "databasus-backend/internal/config" backups_core "databasus-backend/internal/features/backups/backups/core" backups_config "databasus-backend/internal/features/backups/config" @@ -9,11 +17,6 @@ import ( task_cancellation "databasus-backend/internal/features/tasks/cancellation" "databasus-backend/internal/util/encryption" "databasus-backend/internal/util/period" - "fmt" - "log/slog" - "time" - - "github.com/google/uuid" ) const ( @@ -34,59 +37,72 @@ type BackupsScheduler struct { backupToNodeRelations map[uuid.UUID]BackupToNodeRelation backuperNode *BackuperNode + + runOnce sync.Once + hasRun atomic.Bool } func (s *BackupsScheduler) Run(ctx context.Context) { - s.lastBackupTime = time.Now().UTC() + wasAlreadyRun := s.hasRun.Load() - if config.GetEnv().IsManyNodesMode { - // wait other nodes to start - time.Sleep(schedulerStartupDelay) - } + s.runOnce.Do(func() { + s.hasRun.Store(true) - if err := s.failBackupsInProgress(); err != nil { - s.logger.Error("Failed to fail backups in progress", "error", err) - panic(err) - } + s.lastBackupTime = time.Now().UTC() - err := s.backupNodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted) - if err != nil { - s.logger.Error("Failed to subscribe to backup completions", "error", err) - panic(err) - } - - defer func() { - if err := s.backupNodesRegistry.UnsubscribeForBackupsCompletions(); err != nil { - s.logger.Error("Failed to unsubscribe from backup completions", "error", err) + if config.GetEnv().IsManyNodesMode { + // wait other nodes to start + time.Sleep(schedulerStartupDelay) } - }() - if ctx.Err() != nil { - return - } + if err := s.failBackupsInProgress(); err != nil { + s.logger.Error("Failed to fail backups in progress", "error", err) + panic(err) + } - ticker := time.NewTicker(schedulerTickerInterval) - defer ticker.Stop() + err := s.backupNodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted) + if err != nil { + s.logger.Error("Failed to subscribe to backup completions", "error", err) + panic(err) + } - for { - select { - case <-ctx.Done(): + defer func() { + if err := 
s.backupNodesRegistry.UnsubscribeForBackupsCompletions(); err != nil { + s.logger.Error("Failed to unsubscribe from backup completions", "error", err) + } + }() + + if ctx.Err() != nil { return - case <-ticker.C: - if err := s.cleanOldBackups(); err != nil { - s.logger.Error("Failed to clean old backups", "error", err) - } - - if err := s.checkDeadNodesAndFailBackups(); err != nil { - s.logger.Error("Failed to check dead nodes and fail backups", "error", err) - } - - if err := s.runPendingBackups(); err != nil { - s.logger.Error("Failed to run pending backups", "error", err) - } - - s.lastBackupTime = time.Now().UTC() } + + ticker := time.NewTicker(schedulerTickerInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := s.cleanOldBackups(); err != nil { + s.logger.Error("Failed to clean old backups", "error", err) + } + + if err := s.checkDeadNodesAndFailBackups(); err != nil { + s.logger.Error("Failed to check dead nodes and fail backups", "error", err) + } + + if err := s.runPendingBackups(); err != nil { + s.logger.Error("Failed to run pending backups", "error", err) + } + + s.lastBackupTime = time.Now().UTC() + } + } + }) + + if wasAlreadyRun { + panic(fmt.Sprintf("%T.Run() called multiple times", s)) } } diff --git a/backend/internal/features/backups/backups/backuping/scheduler_test.go b/backend/internal/features/backups/backups/backuping/scheduler_test.go index ac3fc8a..aef1442 100644 --- a/backend/internal/features/backups/backups/backuping/scheduler_test.go +++ b/backend/internal/features/backups/backups/backuping/scheduler_test.go @@ -835,7 +835,8 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T cache_utils.ClearAllCache() // Start scheduler so it can handle task completions - schedulerCancel := StartSchedulerForTest(t) + scheduler := CreateTestScheduler() + schedulerCancel := StartSchedulerForTest(t, scheduler) defer schedulerCancel() backuperNode := CreateTestBackuperNode() @@ -891,7 +892,7 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T t.Logf("Initial active tasks: %d", initialActiveTasks) // Start backup - GetBackupsScheduler().StartBackup(database.ID, false) + scheduler.StartBackup(database.ID, false) // Wait for backup to complete WaitForBackupCompletion(t, database.ID, 0, 10*time.Second) @@ -930,7 +931,8 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) { cache_utils.ClearAllCache() // Start scheduler so it can handle task completions - schedulerCancel := StartSchedulerForTest(t) + scheduler := CreateTestScheduler() + schedulerCancel := StartSchedulerForTest(t, scheduler) defer schedulerCancel() backuperNode := CreateTestBackuperNode() @@ -993,7 +995,7 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) { t.Logf("Initial active tasks: %d", initialActiveTasks) // Start backup - GetBackupsScheduler().StartBackup(database.ID, false) + scheduler.StartBackup(database.ID, false) // Wait for backup to fail WaitForBackupCompletion(t, database.ID, 0, 10*time.Second) diff --git a/backend/internal/features/backups/backups/backuping/testing.go b/backend/internal/features/backups/backups/backuping/testing.go index 747471a..fa06770 100644 --- a/backend/internal/features/backups/backups/backuping/testing.go +++ b/backend/internal/features/backups/backups/backuping/testing.go @@ -3,6 +3,8 @@ package backuping import ( "context" "fmt" + "sync" + "sync/atomic" "testing" "time" @@ -35,19 +37,37 @@ func 
CreateTestRouter() *gin.Engine { func CreateTestBackuperNode() *BackuperNode { return &BackuperNode{ - databases.GetDatabaseService(), - encryption.GetFieldEncryptor(), - workspaces_services.GetWorkspaceService(), + databaseService: databases.GetDatabaseService(), + fieldEncryptor: encryption.GetFieldEncryptor(), + workspaceService: workspaces_services.GetWorkspaceService(), + backupRepository: backupRepository, + backupConfigService: backups_config.GetBackupConfigService(), + storageService: storages.GetStorageService(), + notificationSender: notifiers.GetNotifierService(), + backupCancelManager: taskCancelManager, + backupNodesRegistry: backupNodesRegistry, + logger: logger.GetLogger(), + createBackupUseCase: usecases.GetCreateBackupUsecase(), + nodeID: uuid.New(), + lastHeartbeat: time.Time{}, + runOnce: sync.Once{}, + hasRun: atomic.Bool{}, + } +} + +func CreateTestScheduler() *BackupsScheduler { + return &BackupsScheduler{ backupRepository, backups_config.GetBackupConfigService(), storages.GetStorageService(), - notifiers.GetNotifierService(), taskCancelManager, backupNodesRegistry, + time.Now().UTC(), logger.GetLogger(), - usecases.GetCreateBackupUsecase(), - uuid.New(), - time.Time{}, + make(map[uuid.UUID]BackupToNodeRelation), + CreateTestBackuperNode(), + sync.Once{}, + atomic.Bool{}, } } @@ -141,13 +161,13 @@ func StartBackuperNodeForTest(t *testing.T, backuperNode *BackuperNode) context. // StartSchedulerForTest starts the BackupsScheduler in a goroutine for testing. // The scheduler subscribes to task completions and manages backup lifecycle. // Returns a context cancel function that should be deferred to stop the scheduler. -func StartSchedulerForTest(t *testing.T) context.CancelFunc { +func StartSchedulerForTest(t *testing.T, scheduler *BackupsScheduler) context.CancelFunc { ctx, cancel := context.WithCancel(context.Background()) done := make(chan struct{}) go func() { - GetBackupsScheduler().Run(ctx) + scheduler.Run(ctx) close(done) }() diff --git a/backend/internal/features/backups/backups/di.go b/backend/internal/features/backups/backups/di.go index ae053d3..cddb159 100644 --- a/backend/internal/features/backups/backups/di.go +++ b/backend/internal/features/backups/backups/di.go @@ -1,6 +1,9 @@ package backups import ( + "sync" + "sync/atomic" + audit_logs "databasus-backend/internal/features/audit_logs" "databasus-backend/internal/features/backups/backups/backuping" backups_core "databasus-backend/internal/features/backups/backups/core" @@ -52,11 +55,26 @@ func GetBackupController() *BackupController { return backupController } -func SetupDependencies() { - backups_config. - GetBackupConfigService(). - SetDatabaseStorageChangeListener(backupService) +var ( + setupOnce sync.Once + isSetup atomic.Bool +) - databases.GetDatabaseService().AddDbRemoveListener(backupService) - databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService()) +func SetupDependencies() { + wasAlreadySetup := isSetup.Load() + + setupOnce.Do(func() { + backups_config. + GetBackupConfigService(). 
+ SetDatabaseStorageChangeListener(backupService) + + databases.GetDatabaseService().AddDbRemoveListener(backupService) + databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService()) + + isSetup.Store(true) + }) + + if wasAlreadySetup { + logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call") + } } diff --git a/backend/internal/features/backups/backups/download/background.go b/backend/internal/features/backups/backups/download/background.go index b919870..54e664a 100644 --- a/backend/internal/features/backups/backups/download/background.go +++ b/backend/internal/features/backups/backups/download/background.go @@ -2,33 +2,49 @@ package backups_download import ( "context" + "fmt" "log/slog" + "sync" + "sync/atomic" "time" ) type DownloadTokenBackgroundService struct { downloadTokenService *DownloadTokenService logger *slog.Logger + + runOnce sync.Once + hasRun atomic.Bool } func (s *DownloadTokenBackgroundService) Run(ctx context.Context) { - s.logger.Info("Starting download token cleanup background service") + wasAlreadyRun := s.hasRun.Load() - if ctx.Err() != nil { - return - } + s.runOnce.Do(func() { + s.hasRun.Store(true) - ticker := time.NewTicker(1 * time.Minute) - defer ticker.Stop() + s.logger.Info("Starting download token cleanup background service") - for { - select { - case <-ctx.Done(): + if ctx.Err() != nil { return - case <-ticker.C: - if err := s.downloadTokenService.CleanExpiredTokens(); err != nil { - s.logger.Error("Failed to clean expired download tokens", "error", err) + } + + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := s.downloadTokenService.CleanExpiredTokens(); err != nil { + s.logger.Error("Failed to clean expired download tokens", "error", err) + } } } + }) + + if wasAlreadyRun { + panic(fmt.Sprintf("%T.Run() called multiple times", s)) } } diff --git a/backend/internal/features/backups/backups/download/di.go b/backend/internal/features/backups/backups/download/di.go index 55e5121..5181ef1 100644 --- a/backend/internal/features/backups/backups/download/di.go +++ b/backend/internal/features/backups/backups/download/di.go @@ -1,6 +1,9 @@ package backups_download import ( + "sync" + "sync/atomic" + "databasus-backend/internal/config" cache_utils "databasus-backend/internal/util/cache" "databasus-backend/internal/util/logger" @@ -30,8 +33,10 @@ func init() { } downloadTokenBackgroundService = &DownloadTokenBackgroundService{ - downloadTokenService, - logger.GetLogger(), + downloadTokenService: downloadTokenService, + logger: logger.GetLogger(), + runOnce: sync.Once{}, + hasRun: atomic.Bool{}, } } diff --git a/backend/internal/features/backups/config/di.go b/backend/internal/features/backups/config/di.go index fc8a4cf..fe215dc 100644 --- a/backend/internal/features/backups/config/di.go +++ b/backend/internal/features/backups/config/di.go @@ -1,10 +1,14 @@ package backups_config import ( + "sync" + "sync/atomic" + "databasus-backend/internal/features/databases" "databasus-backend/internal/features/notifiers" "databasus-backend/internal/features/storages" workspaces_services "databasus-backend/internal/features/workspaces/services" + "databasus-backend/internal/util/logger" ) var backupConfigRepository = &BackupConfigRepository{} @@ -28,6 +32,21 @@ func GetBackupConfigService() *BackupConfigService { return backupConfigService } +var ( + setupOnce sync.Once + isSetup atomic.Bool +) + func SetupDependencies() 
{ - storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService) + wasAlreadySetup := isSetup.Load() + + setupOnce.Do(func() { + storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService) + + isSetup.Store(true) + }) + + if wasAlreadySetup { + logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call") + } } diff --git a/backend/internal/features/databases/di.go b/backend/internal/features/databases/di.go index 522e502..870c3a6 100644 --- a/backend/internal/features/databases/di.go +++ b/backend/internal/features/databases/di.go @@ -1,6 +1,9 @@ package databases import ( + "sync" + "sync/atomic" + audit_logs "databasus-backend/internal/features/audit_logs" "databasus-backend/internal/features/notifiers" users_services "databasus-backend/internal/features/users/services" @@ -37,7 +40,22 @@ func GetDatabaseController() *DatabaseController { return databaseController } +var ( + setupOnce sync.Once + isSetup atomic.Bool +) + func SetupDependencies() { - workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService) - notifiers.GetNotifierService().SetNotifierDatabaseCounter(databaseService) + wasAlreadySetup := isSetup.Load() + + setupOnce.Do(func() { + workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService) + notifiers.GetNotifierService().SetNotifierDatabaseCounter(databaseService) + + isSetup.Store(true) + }) + + if wasAlreadySetup { + logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call") + } } diff --git a/backend/internal/features/healthcheck/attempt/background_service.go b/backend/internal/features/healthcheck/attempt/background_service.go index cbc68d1..dda0d10 100644 --- a/backend/internal/features/healthcheck/attempt/background_service.go +++ b/backend/internal/features/healthcheck/attempt/background_service.go @@ -2,30 +2,47 @@ package healthcheck_attempt import ( "context" - healthcheck_config "databasus-backend/internal/features/healthcheck/config" + "fmt" "log/slog" + "sync" + "sync/atomic" "time" + + healthcheck_config "databasus-backend/internal/features/healthcheck/config" ) type HealthcheckAttemptBackgroundService struct { healthcheckConfigService *healthcheck_config.HealthcheckConfigService checkDatabaseHealthUseCase *CheckDatabaseHealthUseCase logger *slog.Logger + + runOnce sync.Once + hasRun atomic.Bool } func (s *HealthcheckAttemptBackgroundService) Run(ctx context.Context) { - // first healthcheck immediately - s.checkDatabases() + wasAlreadyRun := s.hasRun.Load() - ticker := time.NewTicker(time.Minute) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - s.checkDatabases() + s.runOnce.Do(func() { + s.hasRun.Store(true) + + // first healthcheck immediately + s.checkDatabases() + + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + s.checkDatabases() + } } + }) + + if wasAlreadyRun { + panic(fmt.Sprintf("%T.Run() called multiple times", s)) } } diff --git a/backend/internal/features/healthcheck/attempt/di.go b/backend/internal/features/healthcheck/attempt/di.go index 7906905..aedfde1 100644 --- a/backend/internal/features/healthcheck/attempt/di.go +++ b/backend/internal/features/healthcheck/attempt/di.go @@ -1,6 +1,9 @@ package healthcheck_attempt import ( + "sync" + "sync/atomic" + "databasus-backend/internal/features/databases" healthcheck_config 
"databasus-backend/internal/features/healthcheck/config" "databasus-backend/internal/features/notifiers" @@ -22,9 +25,11 @@ var checkDatabaseHealthUseCase = &CheckDatabaseHealthUseCase{ } var healthcheckAttemptBackgroundService = &HealthcheckAttemptBackgroundService{ - healthcheck_config.GetHealthcheckConfigService(), - checkDatabaseHealthUseCase, - logger.GetLogger(), + healthcheckConfigService: healthcheck_config.GetHealthcheckConfigService(), + checkDatabaseHealthUseCase: checkDatabaseHealthUseCase, + logger: logger.GetLogger(), + runOnce: sync.Once{}, + hasRun: atomic.Bool{}, } var healthcheckAttemptController = &HealthcheckAttemptController{ healthcheckAttemptService, diff --git a/backend/internal/features/healthcheck/config/di.go b/backend/internal/features/healthcheck/config/di.go index b87b341..ce5744b 100644 --- a/backend/internal/features/healthcheck/config/di.go +++ b/backend/internal/features/healthcheck/config/di.go @@ -1,6 +1,9 @@ package healthcheck_config import ( + "sync" + "sync/atomic" + "databasus-backend/internal/features/audit_logs" "databasus-backend/internal/features/databases" workspaces_services "databasus-backend/internal/features/workspaces/services" @@ -27,8 +30,23 @@ func GetHealthcheckConfigController() *HealthcheckConfigController { return healthcheckConfigController } +var ( + setupOnce sync.Once + isSetup atomic.Bool +) + func SetupDependencies() { - databases. - GetDatabaseService(). - AddDbCreationListener(healthcheckConfigService) + wasAlreadySetup := isSetup.Load() + + setupOnce.Do(func() { + databases. + GetDatabaseService(). + AddDbCreationListener(healthcheckConfigService) + + isSetup.Store(true) + }) + + if wasAlreadySetup { + logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call") + } } diff --git a/backend/internal/features/notifiers/di.go b/backend/internal/features/notifiers/di.go index 410d943..71e63a4 100644 --- a/backend/internal/features/notifiers/di.go +++ b/backend/internal/features/notifiers/di.go @@ -1,6 +1,9 @@ package notifiers import ( + "sync" + "sync/atomic" + audit_logs "databasus-backend/internal/features/audit_logs" workspaces_services "databasus-backend/internal/features/workspaces/services" "databasus-backend/internal/util/encryption" @@ -32,6 +35,22 @@ func GetNotifierService() *NotifierService { func GetNotifierRepository() *NotifierRepository { return notifierRepository } + +var ( + setupOnce sync.Once + isSetup atomic.Bool +) + func SetupDependencies() { - workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService) + wasAlreadySetup := isSetup.Load() + + setupOnce.Do(func() { + workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService) + + isSetup.Store(true) + }) + + if wasAlreadySetup { + logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call") + } } diff --git a/backend/internal/features/restores/controller.go b/backend/internal/features/restores/controller.go index f74c3af..e3b2663 100644 --- a/backend/internal/features/restores/controller.go +++ b/backend/internal/features/restores/controller.go @@ -16,6 +16,7 @@ type RestoreController struct { func (c *RestoreController) RegisterRoutes(router *gin.RouterGroup) { router.GET("/restores/:backupId", c.GetRestores) router.POST("/restores/:backupId/restore", c.RestoreBackup) + router.POST("/restores/cancel/:restoreId", c.CancelRestore) } // GetRestores @@ -85,3 +86,33 @@ func (c *RestoreController) RestoreBackup(ctx *gin.Context) { 
ctx.JSON(http.StatusOK, gin.H{"message": "restore started successfully"}) } + +// CancelRestore +// @Summary Cancel an in-progress restore +// @Description Cancel a restore that is currently in progress +// @Tags restores +// @Param restoreId path string true "Restore ID" +// @Success 204 +// @Failure 400 +// @Failure 401 +// @Router /restores/cancel/{restoreId} [post] +func (c *RestoreController) CancelRestore(ctx *gin.Context) { + user, ok := users_middleware.GetUserFromContext(ctx) + if !ok { + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"}) + return + } + + restoreID, err := uuid.Parse(ctx.Param("restoreId")) + if err != nil { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid restore ID"}) + return + } + + if err := c.restoreService.CancelRestore(user, restoreID); err != nil { + ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + ctx.Status(http.StatusNoContent) +} diff --git a/backend/internal/features/restores/controller_test.go b/backend/internal/features/restores/controller_test.go index 64b99a0..3e8a3ff 100644 --- a/backend/internal/features/restores/controller_test.go +++ b/backend/internal/features/restores/controller_test.go @@ -18,20 +18,25 @@ import ( "databasus-backend/internal/config" audit_logs "databasus-backend/internal/features/audit_logs" + "databasus-backend/internal/features/backups/backups" backups_core "databasus-backend/internal/features/backups/backups/core" backups_config "databasus-backend/internal/features/backups/config" "databasus-backend/internal/features/databases" "databasus-backend/internal/features/databases/databases/mysql" "databasus-backend/internal/features/databases/databases/postgresql" + "databasus-backend/internal/features/notifiers" restores_core "databasus-backend/internal/features/restores/core" + "databasus-backend/internal/features/restores/restoring" "databasus-backend/internal/features/storages" local_storage "databasus-backend/internal/features/storages/models/local" + tasks_cancellation "databasus-backend/internal/features/tasks/cancellation" users_dto "databasus-backend/internal/features/users/dto" users_enums "databasus-backend/internal/features/users/enums" users_services "databasus-backend/internal/features/users/services" users_testing "databasus-backend/internal/features/users/testing" workspaces_models "databasus-backend/internal/features/workspaces/models" workspaces_testing "databasus-backend/internal/features/workspaces/testing" + cache_utils "databasus-backend/internal/util/cache" util_encryption "databasus-backend/internal/util/encryption" test_utils "databasus-backend/internal/util/testing" "databasus-backend/internal/util/tools" @@ -370,6 +375,142 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) { } } +func Test_CancelRestore_InProgressRestore_SuccessfullyCancelled(t *testing.T) { + cache_utils.ClearAllCache() + tasks_cancellation.SetupDependencies() + + user := users_testing.CreateTestUser(users_enums.UserRoleAdmin) + router := createTestRouter() + workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router) + storage := storages.CreateTestStorage(workspace.ID) + notifier := notifiers.CreateTestNotifier(workspace.ID) + database := databases.CreateTestDatabase(workspace.ID, storage, notifier) + + defer func() { + backupRepo := backups_core.BackupRepository{} + backups, _ := backupRepo.FindByDatabaseID(database.ID) + for _, backup := range backups { + backupRepo.DeleteByID(backup.ID) + } + + restoreRepo := 
restores_core.RestoreRepository{} + restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress) + for _, restore := range restores { + restoreRepo.DeleteByID(restore.ID) + } + restores, _ = restoreRepo.FindByStatus(restores_core.RestoreStatusCanceled) + for _, restore := range restores { + restoreRepo.DeleteByID(restore.ID) + } + + databases.RemoveTestDatabase(database) + time.Sleep(50 * time.Millisecond) + storages.RemoveTestStorage(storage.ID) + notifiers.RemoveTestNotifier(notifier) + workspaces_testing.RemoveTestWorkspace(workspace, router) + + cache_utils.ClearAllCache() + }() + + backups_config.EnableBackupsForTestDatabase(database.ID, storage) + backup := backups.CreateTestBackup(database.ID, storage.ID) + + mockUsecase := &restoring.MockBlockingRestoreUsecase{ + StartedChan: make(chan bool, 1), + } + restorerNode := restoring.CreateTestRestorerNodeWithUsecase(mockUsecase) + + cancelNode := restoring.StartRestorerNodeForTest(t, restorerNode) + defer cancelNode() + + time.Sleep(200 * time.Millisecond) + + restoreRequest := restores_core.RestoreBackupRequest{ + PostgresqlDatabase: &postgresql.PostgresqlDatabase{ + Version: tools.PostgresqlVersion16, + Host: "localhost", + Port: 5432, + Username: "postgres", + Password: "postgres", + }, + } + + var restoreResponse map[string]interface{} + test_utils.MakePostRequestAndUnmarshal( + t, + router, + fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()), + "Bearer "+user.Token, + restoreRequest, + http.StatusOK, + &restoreResponse, + ) + + select { + case <-mockUsecase.StartedChan: + t.Log("Restore started and is blocking") + case <-time.After(2 * time.Second): + t.Fatal("Restore did not start within timeout") + } + + restoreRepo := &restores_core.RestoreRepository{} + restores, err := restoreRepo.FindByBackupID(backup.ID) + assert.NoError(t, err) + assert.Greater(t, len(restores), 0, "At least one restore should exist") + + var restoreID uuid.UUID + for _, r := range restores { + if r.Status == restores_core.RestoreStatusInProgress { + restoreID = r.ID + break + } + } + assert.NotEqual(t, uuid.Nil, restoreID, "Should find an in-progress restore") + + resp := test_utils.MakePostRequest( + t, + router, + fmt.Sprintf("/api/v1/restores/cancel/%s", restoreID.String()), + "Bearer "+user.Token, + nil, + http.StatusNoContent, + ) + + assert.Equal(t, http.StatusNoContent, resp.StatusCode) + + deadline := time.Now().UTC().Add(3 * time.Second) + var restore *restores_core.Restore + for time.Now().UTC().Before(deadline) { + restore, err = restoreRepo.FindByID(restoreID) + assert.NoError(t, err) + if restore.Status == restores_core.RestoreStatusCanceled { + break + } + time.Sleep(100 * time.Millisecond) + } + + assert.Equal(t, restores_core.RestoreStatusCanceled, restore.Status) + + auditLogService := audit_logs.GetAuditLogService() + auditLogs, err := auditLogService.GetWorkspaceAuditLogs( + workspace.ID, + &audit_logs.GetAuditLogsRequest{Limit: 100, Offset: 0}, + ) + assert.NoError(t, err) + + foundCancelLog := false + for _, log := range auditLogs.AuditLogs { + if strings.Contains(log.Message, "Restore cancelled") && + strings.Contains(log.Message, database.Name) { + foundCancelLog = true + break + } + } + assert.True(t, foundCancelLog, "Cancel audit log should be created") + + time.Sleep(200 * time.Millisecond) +} + func createTestRouter() *gin.Engine { return CreateTestRouter() } diff --git a/backend/internal/features/restores/core/enums.go b/backend/internal/features/restores/core/enums.go index db1c472..dc64bca 100644 
--- a/backend/internal/features/restores/core/enums.go +++ b/backend/internal/features/restores/core/enums.go @@ -6,4 +6,5 @@ const ( RestoreStatusInProgress RestoreStatus = "IN_PROGRESS" RestoreStatusCompleted RestoreStatus = "COMPLETED" RestoreStatusFailed RestoreStatus = "FAILED" + RestoreStatusCanceled RestoreStatus = "CANCELED" ) diff --git a/backend/internal/features/restores/core/interfaces.go b/backend/internal/features/restores/core/interfaces.go index 0445727..a46ad77 100644 --- a/backend/internal/features/restores/core/interfaces.go +++ b/backend/internal/features/restores/core/interfaces.go @@ -1,6 +1,8 @@ package restores_core import ( + "context" + backups_core "databasus-backend/internal/features/backups/backups/core" backups_config "databasus-backend/internal/features/backups/config" "databasus-backend/internal/features/databases" @@ -9,6 +11,7 @@ import ( type RestoreBackupUsecase interface { Execute( + ctx context.Context, backupConfig *backups_config.BackupConfig, restore Restore, originalDB *databases.Database, diff --git a/backend/internal/features/restores/di.go b/backend/internal/features/restores/di.go index 2e1230b..8670583 100644 --- a/backend/internal/features/restores/di.go +++ b/backend/internal/features/restores/di.go @@ -1,6 +1,9 @@ package restores import ( + "sync" + "sync/atomic" + audit_logs "databasus-backend/internal/features/audit_logs" "databasus-backend/internal/features/backups/backups" backups_config "databasus-backend/internal/features/backups/config" @@ -9,6 +12,7 @@ import ( restores_core "databasus-backend/internal/features/restores/core" "databasus-backend/internal/features/restores/usecases" "databasus-backend/internal/features/storages" + tasks_cancellation "databasus-backend/internal/features/tasks/cancellation" workspaces_services "databasus-backend/internal/features/workspaces/services" "databasus-backend/internal/util/encryption" "databasus-backend/internal/util/logger" @@ -27,6 +31,7 @@ var restoreService = &RestoreService{ audit_logs.GetAuditLogService(), encryption.GetFieldEncryptor(), disk.GetDiskService(), + tasks_cancellation.GetTaskCancelManager(), } var restoreController = &RestoreController{ restoreService, @@ -36,6 +41,21 @@ func GetRestoreController() *RestoreController { return restoreController } +var ( + setupOnce sync.Once + isSetup atomic.Bool +) + func SetupDependencies() { - backups.GetBackupService().AddBackupRemoveListener(restoreService) + wasAlreadySetup := isSetup.Load() + + setupOnce.Do(func() { + backups.GetBackupService().AddBackupRemoveListener(restoreService) + + isSetup.Store(true) + }) + + if wasAlreadySetup { + logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call") + } } diff --git a/backend/internal/features/restores/restoring/di.go b/backend/internal/features/restores/restoring/di.go index ca76927..29b850d 100644 --- a/backend/internal/features/restores/restoring/di.go +++ b/backend/internal/features/restores/restoring/di.go @@ -1,6 +1,8 @@ package restoring import ( + "sync" + "sync/atomic" "time" "github.com/google/uuid" @@ -11,6 +13,7 @@ import ( restores_core "databasus-backend/internal/features/restores/core" "databasus-backend/internal/features/restores/usecases" "databasus-backend/internal/features/storages" + tasks_cancellation "databasus-backend/internal/features/tasks/cancellation" cache_utils "databasus-backend/internal/util/cache" "databasus-backend/internal/util/encryption" "databasus-backend/internal/util/logger" @@ -19,11 +22,13 @@ import ( var 
restoreRepository = &restores_core.RestoreRepository{} var restoreNodesRegistry = &RestoreNodesRegistry{ - cache_utils.GetValkeyClient(), - logger.GetLogger(), - cache_utils.DefaultCacheTimeout, - cache_utils.NewPubSubManager(), - cache_utils.NewPubSubManager(), + client: cache_utils.GetValkeyClient(), + logger: logger.GetLogger(), + timeout: cache_utils.DefaultCacheTimeout, + pubsubRestores: cache_utils.NewPubSubManager(), + pubsubCompletions: cache_utils.NewPubSubManager(), + runOnce: sync.Once{}, + hasRun: atomic.Bool{}, } var restoreDatabaseCache = cache_utils.NewCacheUtil[RestoreDatabaseCache]( @@ -31,19 +36,24 @@ var restoreDatabaseCache = cache_utils.NewCacheUtil[RestoreDatabaseCache]( "restore_db:", ) +var restoreCancelManager = tasks_cancellation.GetTaskCancelManager() + var restorerNode = &RestorerNode{ - uuid.New(), - databases.GetDatabaseService(), - backups.GetBackupService(), - encryption.GetFieldEncryptor(), - restoreRepository, - backups_config.GetBackupConfigService(), - storages.GetStorageService(), - restoreNodesRegistry, - logger.GetLogger(), - usecases.GetRestoreBackupUsecase(), - restoreDatabaseCache, - time.Time{}, + nodeID: uuid.New(), + databaseService: databases.GetDatabaseService(), + backupService: backups.GetBackupService(), + fieldEncryptor: encryption.GetFieldEncryptor(), + restoreRepository: restoreRepository, + backupConfigService: backups_config.GetBackupConfigService(), + storageService: storages.GetStorageService(), + restoreNodesRegistry: restoreNodesRegistry, + logger: logger.GetLogger(), + restoreBackupUsecase: usecases.GetRestoreBackupUsecase(), + cacheUtil: restoreDatabaseCache, + restoreCancelManager: restoreCancelManager, + lastHeartbeat: time.Time{}, + runOnce: sync.Once{}, + hasRun: atomic.Bool{}, } var restoresScheduler = &RestoresScheduler{ @@ -58,6 +68,8 @@ var restoresScheduler = &RestoresScheduler{ restorerNode: restorerNode, cacheUtil: restoreDatabaseCache, completionSubscriptionID: uuid.Nil, + runOnce: sync.Once{}, + hasRun: atomic.Bool{}, } func GetRestoresScheduler() *RestoresScheduler { diff --git a/backend/internal/features/restores/restoring/mocks.go b/backend/internal/features/restores/restoring/mocks.go index df02d75..7fe39c0 100644 --- a/backend/internal/features/restores/restoring/mocks.go +++ b/backend/internal/features/restores/restoring/mocks.go @@ -1,6 +1,7 @@ package restoring import ( + "context" "errors" backups_core "databasus-backend/internal/features/backups/backups/core" @@ -13,6 +14,7 @@ import ( type MockSuccessRestoreUsecase struct{} func (uc *MockSuccessRestoreUsecase) Execute( + ctx context.Context, backupConfig *backups_config.BackupConfig, restore restores_core.Restore, originalDB *databases.Database, @@ -27,6 +29,7 @@ func (uc *MockSuccessRestoreUsecase) Execute( type MockFailedRestoreUsecase struct{} func (uc *MockFailedRestoreUsecase) Execute( + ctx context.Context, backupConfig *backups_config.BackupConfig, restore restores_core.Restore, originalDB *databases.Database, @@ -44,6 +47,7 @@ type MockCaptureCredentialsRestoreUsecase struct { } func (uc *MockCaptureCredentialsRestoreUsecase) Execute( + ctx context.Context, backupConfig *backups_config.BackupConfig, restore restores_core.Restore, originalDB *databases.Database, @@ -59,3 +63,26 @@ func (uc *MockCaptureCredentialsRestoreUsecase) Execute( } return errors.New("mock restore failed") } + +type MockBlockingRestoreUsecase struct { + StartedChan chan bool +} + +func (uc *MockBlockingRestoreUsecase) Execute( + ctx context.Context, + backupConfig 
*backups_config.BackupConfig, + restore restores_core.Restore, + originalDB *databases.Database, + restoringToDB *databases.Database, + backup *backups_core.Backup, + storage *storages.Storage, + isExcludeExtensions bool, +) error { + if uc.StartedChan != nil { + uc.StartedChan <- true + } + + <-ctx.Done() + + return ctx.Err() +} diff --git a/backend/internal/features/restores/restoring/registry.go b/backend/internal/features/restores/restoring/registry.go index afbe9fc..01b72bd 100644 --- a/backend/internal/features/restores/restoring/registry.go +++ b/backend/internal/features/restores/restoring/registry.go @@ -6,6 +6,8 @@ import ( "fmt" "log/slog" "strings" + "sync" + "sync/atomic" "time" cache_utils "databasus-backend/internal/util/cache" @@ -47,24 +49,37 @@ type RestoreNodesRegistry struct { timeout time.Duration pubsubRestores *cache_utils.PubSubManager pubsubCompletions *cache_utils.PubSubManager + + runOnce sync.Once + hasRun atomic.Bool } func (r *RestoreNodesRegistry) Run(ctx context.Context) { - if err := r.cleanupDeadNodes(); err != nil { - r.logger.Error("Failed to cleanup dead nodes on startup", "error", err) - } + wasAlreadyRun := r.hasRun.Load() - ticker := time.NewTicker(cleanupTickerInterval) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - if err := r.cleanupDeadNodes(); err != nil { - r.logger.Error("Failed to cleanup dead nodes", "error", err) + r.runOnce.Do(func() { + r.hasRun.Store(true) + + if err := r.cleanupDeadNodes(); err != nil { + r.logger.Error("Failed to cleanup dead nodes on startup", "error", err) + } + + ticker := time.NewTicker(cleanupTickerInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := r.cleanupDeadNodes(); err != nil { + r.logger.Error("Failed to cleanup dead nodes", "error", err) + } } } + }) + + if wasAlreadyRun { + panic(fmt.Sprintf("%T.Run() called multiple times", r)) } } diff --git a/backend/internal/features/restores/restoring/registry_test.go b/backend/internal/features/restores/restoring/registry_test.go index fd6169c..2b7e87f 100644 --- a/backend/internal/features/restores/restoring/registry_test.go +++ b/backend/internal/features/restores/restoring/registry_test.go @@ -4,6 +4,8 @@ import ( "context" "encoding/json" "fmt" + "sync" + "sync/atomic" "testing" "time" @@ -594,11 +596,13 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) { func createTestRegistry() *RestoreNodesRegistry { return &RestoreNodesRegistry{ - cache_utils.GetValkeyClient(), - logger.GetLogger(), - cache_utils.DefaultCacheTimeout, - cache_utils.NewPubSubManager(), - cache_utils.NewPubSubManager(), + client: cache_utils.GetValkeyClient(), + logger: logger.GetLogger(), + timeout: cache_utils.DefaultCacheTimeout, + pubsubRestores: cache_utils.NewPubSubManager(), + pubsubCompletions: cache_utils.NewPubSubManager(), + runOnce: sync.Once{}, + hasRun: atomic.Bool{}, } } diff --git a/backend/internal/features/restores/restoring/restorer.go b/backend/internal/features/restores/restoring/restorer.go index ba9a059..2947aa7 100644 --- a/backend/internal/features/restores/restoring/restorer.go +++ b/backend/internal/features/restores/restoring/restorer.go @@ -2,8 +2,12 @@ package restoring import ( "context" + "errors" "fmt" "log/slog" + "strings" + "sync" + "sync/atomic" "time" "github.com/google/uuid" @@ -14,6 +18,7 @@ import ( "databasus-backend/internal/features/databases" restores_core "databasus-backend/internal/features/restores/core" 
"databasus-backend/internal/features/storages" + tasks_cancellation "databasus-backend/internal/features/tasks/cancellation" cache_utils "databasus-backend/internal/util/cache" util_encryption "databasus-backend/internal/util/encryption" ) @@ -36,70 +41,84 @@ type RestorerNode struct { logger *slog.Logger restoreBackupUsecase restores_core.RestoreBackupUsecase cacheUtil *cache_utils.CacheUtil[RestoreDatabaseCache] + restoreCancelManager *tasks_cancellation.TaskCancelManager lastHeartbeat time.Time + + runOnce sync.Once + hasRun atomic.Bool } func (n *RestorerNode) Run(ctx context.Context) { - n.lastHeartbeat = time.Now().UTC() + wasAlreadyRun := n.hasRun.Load() - throughputMBs := config.GetEnv().NodeNetworkThroughputMBs + n.runOnce.Do(func() { + n.hasRun.Store(true) - restoreNode := RestoreNode{ - ID: n.nodeID, - ThroughputMBs: throughputMBs, - } + n.lastHeartbeat = time.Now().UTC() - if err := n.restoreNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), restoreNode); err != nil { - n.logger.Error("Failed to register node in registry", "error", err) - panic(err) - } + throughputMBs := config.GetEnv().NodeNetworkThroughputMBs - restoreHandler := func(restoreID uuid.UUID, isCallNotifier bool) { - n.MakeRestore(restoreID) - if err := n.restoreNodesRegistry.PublishRestoreCompletion(n.nodeID, restoreID); err != nil { - n.logger.Error( - "Failed to publish restore completion", - "error", - err, - "restoreID", - restoreID, - ) + restoreNode := RestoreNode{ + ID: n.nodeID, + ThroughputMBs: throughputMBs, } - } - err := n.restoreNodesRegistry.SubscribeNodeForRestoresAssignment( - n.nodeID, - restoreHandler, - ) - if err != nil { - n.logger.Error("Failed to subscribe to restore assignments", "error", err) - panic(err) - } - defer func() { - if err := n.restoreNodesRegistry.UnsubscribeNodeForRestoresAssignments(); err != nil { - n.logger.Error("Failed to unsubscribe from restore assignments", "error", err) + if err := n.restoreNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), restoreNode); err != nil { + n.logger.Error("Failed to register node in registry", "error", err) + panic(err) } - }() - ticker := time.NewTicker(heartbeatTickerInterval) - defer ticker.Stop() - - n.logger.Info("Restore node started", "nodeID", n.nodeID, "throughput", throughputMBs) - - for { - select { - case <-ctx.Done(): - n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID) - - if err := n.restoreNodesRegistry.UnregisterNodeFromRegistry(restoreNode); err != nil { - n.logger.Error("Failed to unregister node from registry", "error", err) + restoreHandler := func(restoreID uuid.UUID, isCallNotifier bool) { + n.MakeRestore(restoreID) + if err := n.restoreNodesRegistry.PublishRestoreCompletion(n.nodeID, restoreID); err != nil { + n.logger.Error( + "Failed to publish restore completion", + "error", + err, + "restoreID", + restoreID, + ) } - - return - case <-ticker.C: - n.sendHeartbeat(&restoreNode) } + + err := n.restoreNodesRegistry.SubscribeNodeForRestoresAssignment( + n.nodeID, + restoreHandler, + ) + if err != nil { + n.logger.Error("Failed to subscribe to restore assignments", "error", err) + panic(err) + } + defer func() { + if err := n.restoreNodesRegistry.UnsubscribeNodeForRestoresAssignments(); err != nil { + n.logger.Error("Failed to unsubscribe from restore assignments", "error", err) + } + }() + + ticker := time.NewTicker(heartbeatTickerInterval) + defer ticker.Stop() + + n.logger.Info("Restore node started", "nodeID", n.nodeID, "throughput", throughputMBs) + + for { + select { 
+			case <-ctx.Done():
+				n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
+
+				if err := n.restoreNodesRegistry.UnregisterNodeFromRegistry(restoreNode); err != nil {
+					n.logger.Error("Failed to unregister node from registry", "error", err)
+				}
+
+				return
+			case <-ticker.C:
+				n.sendHeartbeat(&restoreNode)
+			}
+		}
+	})
+
+	if wasAlreadyRun {
+		panic(fmt.Sprintf("%T.Run() called multiple times", n))
+	}
 }
@@ -176,6 +195,11 @@ func (n *RestorerNode) MakeRestore(restoreID uuid.UUID) {

 	start := time.Now().UTC()

+	// Create cancellable context
+	ctx, cancel := context.WithCancel(context.Background())
+	n.restoreCancelManager.RegisterTask(restore.ID, cancel)
+	defer n.restoreCancelManager.UnregisterTask(restore.ID)
+
 	// Create restoring database from cached credentials
 	restoringToDB := &databases.Database{
 		Type: database.Type,
@@ -204,6 +228,7 @@
 	}

 	err = n.restoreBackupUsecase.Execute(
+		ctx,
 		backupConfig,
 		*restore,
 		database,
@@ -216,6 +241,29 @@
 	if err != nil {
 		errMsg := err.Error()

+		// Check if restore was cancelled
+		isCancelled := strings.Contains(errMsg, "restore cancelled") ||
+			strings.Contains(errMsg, "context canceled") ||
+			errors.Is(err, context.Canceled)
+		isShutdown := strings.Contains(errMsg, "shutdown")
+
+		if isCancelled && !isShutdown {
+			n.logger.Warn("Restore was cancelled by user or system",
+				"restoreId", restore.ID,
+				"isCancelled", isCancelled,
+				"isShutdown", isShutdown,
+			)
+
+			restore.Status = restores_core.RestoreStatusCanceled
+			restore.RestoreDurationMs = time.Since(start).Milliseconds()
+
+			if err := n.restoreRepository.Save(restore); err != nil {
+				n.logger.Error("Failed to save cancelled restore", "error", err)
+			}
+
+			return
+		}
+
 		n.logger.Error("Restore execution failed",
 			"restoreId", restore.ID,
 			"backupId", backup.ID,
diff --git a/backend/internal/features/restores/restoring/scheduler.go b/backend/internal/features/restores/restoring/scheduler.go
index 93668ba..361b3a1 100644
--- a/backend/internal/features/restores/restoring/scheduler.go
+++ b/backend/internal/features/restores/restoring/scheduler.go
@@ -4,6 +4,8 @@ import (
 	"context"
 	"fmt"
 	"log/slog"
+	"sync"
+	"sync/atomic"
 	"time"

 	"github.com/google/uuid"
@@ -34,51 +36,64 @@ type RestoresScheduler struct {
 	restorerNode             *RestorerNode
 	cacheUtil                *cache_utils.CacheUtil[RestoreDatabaseCache]
 	completionSubscriptionID uuid.UUID
+
+	runOnce sync.Once
+	hasRun  atomic.Bool
 }

 func (s *RestoresScheduler) Run(ctx context.Context) {
-	s.lastCheckTime = time.Now().UTC()
+	wasAlreadyRun := s.hasRun.Load()

-	if config.GetEnv().IsManyNodesMode {
-		// wait other nodes to start
-		time.Sleep(schedulerStartupDelay)
-	}
+	s.runOnce.Do(func() {
+		s.hasRun.Store(true)

-	if err := s.failRestoresInProgress(); err != nil {
-		s.logger.Error("Failed to fail restores in progress", "error", err)
-		panic(err)
-	}
+		s.lastCheckTime = time.Now().UTC()

-	err := s.restoreNodesRegistry.SubscribeForRestoresCompletions(s.onRestoreCompleted)
-	if err != nil {
-		s.logger.Error("Failed to subscribe to restore completions", "error", err)
-		panic(err)
-	}
-
-	defer func() {
-		if err := s.restoreNodesRegistry.UnsubscribeForRestoresCompletions(); err != nil {
-			s.logger.Error("Failed to unsubscribe from restore completions", "error", err)
+		if config.GetEnv().IsManyNodesMode {
+			// wait for other nodes to start
+			time.Sleep(schedulerStartupDelay)
 		}
-	}()

-	if ctx.Err() != nil {
-		return
-	}
+	if err :=
s.failRestoresInProgress(); err != nil { + s.logger.Error("Failed to fail restores in progress", "error", err) + panic(err) + } - ticker := time.NewTicker(schedulerTickerInterval) - defer ticker.Stop() + err := s.restoreNodesRegistry.SubscribeForRestoresCompletions(s.onRestoreCompleted) + if err != nil { + s.logger.Error("Failed to subscribe to restore completions", "error", err) + panic(err) + } - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - if err := s.checkDeadNodesAndFailRestores(); err != nil { - s.logger.Error("Failed to check dead nodes and fail restores", "error", err) + defer func() { + if err := s.restoreNodesRegistry.UnsubscribeForRestoresCompletions(); err != nil { + s.logger.Error("Failed to unsubscribe from restore completions", "error", err) } + }() - s.lastCheckTime = time.Now().UTC() + if ctx.Err() != nil { + return } + + ticker := time.NewTicker(schedulerTickerInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := s.checkDeadNodesAndFailRestores(); err != nil { + s.logger.Error("Failed to check dead nodes and fail restores", "error", err) + } + + s.lastCheckTime = time.Now().UTC() + } + } + }) + + if wasAlreadyRun { + panic(fmt.Sprintf("%T.Run() called multiple times", s)) } } diff --git a/backend/internal/features/restores/restoring/scheduler_test.go b/backend/internal/features/restores/restoring/scheduler_test.go index 3f39a3f..874e4b7 100644 --- a/backend/internal/features/restores/restoring/scheduler_test.go +++ b/backend/internal/features/restores/restoring/scheduler_test.go @@ -424,7 +424,8 @@ func Test_StartRestore_RestoreCompletes_DecrementsActiveTaskCount(t *testing.T) cache_utils.ClearAllCache() // Start scheduler so it can handle task completions - schedulerCancel := StartSchedulerForTest(t) + scheduler := CreateTestRestoresScheduler() + schedulerCancel := StartSchedulerForTest(t, scheduler) defer schedulerCancel() restorerNode := CreateTestRestorerNode() @@ -485,7 +486,7 @@ func Test_StartRestore_RestoreCompletes_DecrementsActiveTaskCount(t *testing.T) err = restoreRepository.Save(restore) assert.NoError(t, err) - err = GetRestoresScheduler().StartRestore(restore.ID, nil) + err = scheduler.StartRestore(restore.ID, nil) assert.NoError(t, err) // Wait for restore to complete @@ -524,7 +525,8 @@ func Test_StartRestore_RestoreFails_DecrementsActiveTaskCount(t *testing.T) { cache_utils.ClearAllCache() // Start scheduler so it can handle task completions - schedulerCancel := StartSchedulerForTest(t) + scheduler := CreateTestRestoresScheduler() + schedulerCancel := StartSchedulerForTest(t, scheduler) defer schedulerCancel() restorerNode := CreateTestRestorerNode() @@ -585,7 +587,7 @@ func Test_StartRestore_RestoreFails_DecrementsActiveTaskCount(t *testing.T) { err = restoreRepository.Save(restore) assert.NoError(t, err) - err = GetRestoresScheduler().StartRestore(restore.ID, nil) + err = scheduler.StartRestore(restore.ID, nil) assert.NoError(t, err) // Wait for restore to fail @@ -729,7 +731,8 @@ func Test_StartRestore_CredentialsRemovedAfterRestoreStarts(t *testing.T) { cache_utils.ClearAllCache() // Start scheduler so it can handle task assignments - schedulerCancel := StartSchedulerForTest(t) + scheduler := CreateTestRestoresScheduler() + schedulerCancel := StartSchedulerForTest(t, scheduler) defer schedulerCancel() // Create mock restorer node with credential capture usecase @@ -810,7 +813,7 @@ func Test_StartRestore_CredentialsRemovedAfterRestoreStarts(t *testing.T) { } // Call 
StartRestore to cache credentials and trigger restore - err = GetRestoresScheduler().StartRestore(restore.ID, dbCache) + err = scheduler.StartRestore(restore.ID, dbCache) assert.NoError(t, err) // Wait for mock usecase to be called (with timeout) diff --git a/backend/internal/features/restores/restoring/testing.go b/backend/internal/features/restores/restoring/testing.go index e202295..dc9a059 100644 --- a/backend/internal/features/restores/restoring/testing.go +++ b/backend/internal/features/restores/restoring/testing.go @@ -3,6 +3,8 @@ package restoring import ( "context" "fmt" + "sync" + "sync/atomic" "testing" "time" @@ -17,6 +19,7 @@ import ( restores_core "databasus-backend/internal/features/restores/core" "databasus-backend/internal/features/restores/usecases" "databasus-backend/internal/features/storages" + tasks_cancellation "databasus-backend/internal/features/tasks/cancellation" workspaces_controllers "databasus-backend/internal/features/workspaces/controllers" workspaces_testing "databasus-backend/internal/features/workspaces/testing" "databasus-backend/internal/util/encryption" @@ -36,18 +39,59 @@ func CreateTestRouter() *gin.Engine { func CreateTestRestorerNode() *RestorerNode { return &RestorerNode{ - uuid.New(), - databases.GetDatabaseService(), - backups.GetBackupService(), - encryption.GetFieldEncryptor(), + nodeID: uuid.New(), + databaseService: databases.GetDatabaseService(), + backupService: backups.GetBackupService(), + fieldEncryptor: encryption.GetFieldEncryptor(), + restoreRepository: restoreRepository, + backupConfigService: backups_config.GetBackupConfigService(), + storageService: storages.GetStorageService(), + restoreNodesRegistry: restoreNodesRegistry, + logger: logger.GetLogger(), + restoreBackupUsecase: usecases.GetRestoreBackupUsecase(), + cacheUtil: restoreDatabaseCache, + restoreCancelManager: tasks_cancellation.GetTaskCancelManager(), + lastHeartbeat: time.Time{}, + runOnce: sync.Once{}, + hasRun: atomic.Bool{}, + } +} + +func CreateTestRestorerNodeWithUsecase(usecase restores_core.RestoreBackupUsecase) *RestorerNode { + return &RestorerNode{ + nodeID: uuid.New(), + databaseService: databases.GetDatabaseService(), + backupService: backups.GetBackupService(), + fieldEncryptor: encryption.GetFieldEncryptor(), + restoreRepository: restoreRepository, + backupConfigService: backups_config.GetBackupConfigService(), + storageService: storages.GetStorageService(), + restoreNodesRegistry: restoreNodesRegistry, + logger: logger.GetLogger(), + restoreBackupUsecase: usecase, + cacheUtil: restoreDatabaseCache, + restoreCancelManager: tasks_cancellation.GetTaskCancelManager(), + lastHeartbeat: time.Time{}, + runOnce: sync.Once{}, + hasRun: atomic.Bool{}, + } +} + +func CreateTestRestoresScheduler() *RestoresScheduler { + return &RestoresScheduler{ restoreRepository, - backups_config.GetBackupConfigService(), + backups.GetBackupService(), storages.GetStorageService(), + backups_config.GetBackupConfigService(), restoreNodesRegistry, + time.Now().UTC(), logger.GetLogger(), - usecases.GetRestoreBackupUsecase(), + make(map[uuid.UUID]RestoreToNodeRelation), + restorerNode, restoreDatabaseCache, - time.Time{}, + uuid.Nil, + sync.Once{}, + atomic.Bool{}, } } @@ -128,13 +172,13 @@ func StartRestorerNodeForTest(t *testing.T, restorerNode *RestorerNode) context. // StartSchedulerForTest starts the RestoresScheduler in a goroutine for testing. // The scheduler subscribes to task completions and manages restore lifecycle. 
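// Each test should construct its own scheduler (see CreateTestRestoresScheduler),
// since Run() is guarded against being called twice on one instance.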
// Returns a context cancel function that should be deferred to stop the scheduler. -func StartSchedulerForTest(t *testing.T) context.CancelFunc { +func StartSchedulerForTest(t *testing.T, scheduler *RestoresScheduler) context.CancelFunc { ctx, cancel := context.WithCancel(context.Background()) done := make(chan struct{}) go func() { - GetRestoresScheduler().Run(ctx) + scheduler.Run(ctx) close(done) }() diff --git a/backend/internal/features/restores/service.go b/backend/internal/features/restores/service.go index 564b9f0..812a6f2 100644 --- a/backend/internal/features/restores/service.go +++ b/backend/internal/features/restores/service.go @@ -11,6 +11,7 @@ import ( "databasus-backend/internal/features/restores/restoring" "databasus-backend/internal/features/restores/usecases" "databasus-backend/internal/features/storages" + tasks_cancellation "databasus-backend/internal/features/tasks/cancellation" users_models "databasus-backend/internal/features/users/models" workspaces_services "databasus-backend/internal/features/workspaces/services" "databasus-backend/internal/util/encryption" @@ -35,6 +36,7 @@ type RestoreService struct { auditLogService *audit_logs.AuditLogService fieldEncryptor encryption.FieldEncryptor diskService *disk.DiskService + taskCancelManager *tasks_cancellation.TaskCancelManager } func (s *RestoreService) OnBeforeBackupRemove(backup *backups_core.Backup) error { @@ -340,3 +342,55 @@ func (s *RestoreService) validateDiskSpace( return nil } + +func (s *RestoreService) CancelRestore( + user *users_models.User, + restoreID uuid.UUID, +) error { + restore, err := s.restoreRepository.FindByID(restoreID) + if err != nil { + return err + } + + backup, err := s.backupService.GetBackup(restore.BackupID) + if err != nil { + return err + } + + database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID) + if err != nil { + return err + } + + if database.WorkspaceID == nil { + return errors.New("cannot cancel restore for database without workspace") + } + + canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user) + if err != nil { + return err + } + if !canManage { + return errors.New("insufficient permissions to cancel restore for this database") + } + + if restore.Status != restores_core.RestoreStatusInProgress { + return errors.New("restore is not in progress") + } + + if err := s.taskCancelManager.CancelTask(restoreID); err != nil { + return err + } + + s.auditLogService.WriteAuditLog( + fmt.Sprintf( + "Restore cancelled for database: %s (ID: %s)", + database.Name, + restoreID.String(), + ), + &user.ID, + database.WorkspaceID, + ) + + return nil +} diff --git a/backend/internal/features/restores/usecases/mariadb/restore_backup_uc.go b/backend/internal/features/restores/usecases/mariadb/restore_backup_uc.go index a7feb42..eaf14a1 100644 --- a/backend/internal/features/restores/usecases/mariadb/restore_backup_uc.go +++ b/backend/internal/features/restores/usecases/mariadb/restore_backup_uc.go @@ -36,6 +36,7 @@ type RestoreMariadbBackupUsecase struct { } func (uc *RestoreMariadbBackupUsecase) Execute( + parentCtx context.Context, originalDB *databases.Database, restoringToDB *databases.Database, backupConfig *backups_config.BackupConfig, @@ -79,6 +80,7 @@ func (uc *RestoreMariadbBackupUsecase) Execute( } return uc.restoreFromStorage( + parentCtx, originalDB, tools.GetMariadbExecutable( tools.MariadbExecutableMariadb, @@ -95,6 +97,7 @@ func (uc *RestoreMariadbBackupUsecase) Execute( } func (uc *RestoreMariadbBackupUsecase) restoreFromStorage( + 
parentCtx context.Context, database *databases.Database, mariadbBin string, args []string, @@ -103,7 +106,7 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage( storage *storages.Storage, mdbConfig *mariadbtypes.MariadbDatabase, ) error { - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute) + ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute) defer cancel() go func() { @@ -113,6 +116,9 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage( select { case <-ctx.Done(): return + case <-parentCtx.Done(): + cancel() + return case <-ticker.C: if config.IsShouldShutdown() { cancel() @@ -213,6 +219,15 @@ func (uc *RestoreMariadbBackupUsecase) executeMariadbRestore( waitErr := cmd.Wait() stderrOutput := <-stderrCh + // Check for cancellation + select { + case <-ctx.Done(): + if errors.Is(ctx.Err(), context.Canceled) { + return fmt.Errorf("restore cancelled") + } + default: + } + if config.IsShouldShutdown() { return fmt.Errorf("restore cancelled due to shutdown") } diff --git a/backend/internal/features/restores/usecases/mongodb/restore_backup_uc.go b/backend/internal/features/restores/usecases/mongodb/restore_backup_uc.go index c4551c9..a330682 100644 --- a/backend/internal/features/restores/usecases/mongodb/restore_backup_uc.go +++ b/backend/internal/features/restores/usecases/mongodb/restore_backup_uc.go @@ -36,6 +36,7 @@ type RestoreMongodbBackupUsecase struct { } func (uc *RestoreMongodbBackupUsecase) Execute( + parentCtx context.Context, originalDB *databases.Database, restoringToDB *databases.Database, backupConfig *backups_config.BackupConfig, @@ -76,6 +77,7 @@ func (uc *RestoreMongodbBackupUsecase) Execute( args := uc.buildMongorestoreArgs(mdb, decryptedPassword, sourceDatabase) return uc.restoreFromStorage( + parentCtx, tools.GetMongodbExecutable( tools.MongodbExecutableMongorestore, config.GetEnv().EnvMode, @@ -122,12 +124,13 @@ func (uc *RestoreMongodbBackupUsecase) buildMongorestoreArgs( } func (uc *RestoreMongodbBackupUsecase) restoreFromStorage( + parentCtx context.Context, mongorestoreBin string, args []string, backup *backups_core.Backup, storage *storages.Storage, ) error { - ctx, cancel := context.WithTimeout(context.Background(), restoreTimeout) + ctx, cancel := context.WithTimeout(parentCtx, restoreTimeout) defer cancel() go func() { @@ -137,6 +140,9 @@ func (uc *RestoreMongodbBackupUsecase) restoreFromStorage( select { case <-ctx.Done(): return + case <-parentCtx.Done(): + cancel() + return case <-ticker.C: if config.IsShouldShutdown() { cancel() @@ -218,6 +224,15 @@ func (uc *RestoreMongodbBackupUsecase) executeMongoRestore( waitErr := cmd.Wait() stderrOutput := <-stderrCh + // Check for cancellation + select { + case <-ctx.Done(): + if errors.Is(ctx.Err(), context.Canceled) { + return fmt.Errorf("restore cancelled") + } + default: + } + if config.IsShouldShutdown() { return fmt.Errorf("restore cancelled due to shutdown") } diff --git a/backend/internal/features/restores/usecases/mysql/restore_backup_uc.go b/backend/internal/features/restores/usecases/mysql/restore_backup_uc.go index ab182d3..c33ac3c 100644 --- a/backend/internal/features/restores/usecases/mysql/restore_backup_uc.go +++ b/backend/internal/features/restores/usecases/mysql/restore_backup_uc.go @@ -36,6 +36,7 @@ type RestoreMysqlBackupUsecase struct { } func (uc *RestoreMysqlBackupUsecase) Execute( + parentCtx context.Context, originalDB *databases.Database, restoringToDB *databases.Database, backupConfig *backups_config.BackupConfig, @@ -78,6 +79,7 @@ func (uc 
*RestoreMysqlBackupUsecase) Execute( } return uc.restoreFromStorage( + parentCtx, originalDB, tools.GetMysqlExecutable( my.Version, @@ -94,6 +96,7 @@ func (uc *RestoreMysqlBackupUsecase) Execute( } func (uc *RestoreMysqlBackupUsecase) restoreFromStorage( + parentCtx context.Context, database *databases.Database, mysqlBin string, args []string, @@ -102,7 +105,7 @@ func (uc *RestoreMysqlBackupUsecase) restoreFromStorage( storage *storages.Storage, myConfig *mysqltypes.MysqlDatabase, ) error { - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute) + ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute) defer cancel() go func() { @@ -112,6 +115,9 @@ func (uc *RestoreMysqlBackupUsecase) restoreFromStorage( select { case <-ctx.Done(): return + case <-parentCtx.Done(): + cancel() + return case <-ticker.C: if config.IsShouldShutdown() { cancel() @@ -204,6 +210,15 @@ func (uc *RestoreMysqlBackupUsecase) executeMysqlRestore( waitErr := cmd.Wait() stderrOutput := <-stderrCh + // Check for cancellation + select { + case <-ctx.Done(): + if errors.Is(ctx.Err(), context.Canceled) { + return fmt.Errorf("restore cancelled") + } + default: + } + if config.IsShouldShutdown() { return fmt.Errorf("restore cancelled due to shutdown") } diff --git a/backend/internal/features/restores/usecases/postgresql/restore_backup_uc.go b/backend/internal/features/restores/usecases/postgresql/restore_backup_uc.go index b873db2..dcc1999 100644 --- a/backend/internal/features/restores/usecases/postgresql/restore_backup_uc.go +++ b/backend/internal/features/restores/usecases/postgresql/restore_backup_uc.go @@ -35,6 +35,7 @@ type RestorePostgresqlBackupUsecase struct { } func (uc *RestorePostgresqlBackupUsecase) Execute( + parentCtx context.Context, originalDB *databases.Database, restoringToDB *databases.Database, backupConfig *backups_config.BackupConfig, @@ -73,6 +74,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute( // All PostgreSQL backups are now custom format (-Fc) return uc.restoreCustomType( + parentCtx, originalDB, pgBin, backup, @@ -84,6 +86,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute( // restoreCustomType restores a backup in custom type (-Fc) func (uc *RestorePostgresqlBackupUsecase) restoreCustomType( + parentCtx context.Context, originalDB *databases.Database, pgBin string, backup *backups_core.Backup, @@ -102,15 +105,24 @@ func (uc *RestorePostgresqlBackupUsecase) restoreCustomType( // If excluding extensions, we must use file-based restore (requires TOC file generation) // Also use file-based restore for parallel jobs (multiple CPUs) if isExcludeExtensions || pg.CpuCount > 1 { - return uc.restoreViaFile(originalDB, pgBin, backup, storage, pg, isExcludeExtensions) + return uc.restoreViaFile( + parentCtx, + originalDB, + pgBin, + backup, + storage, + pg, + isExcludeExtensions, + ) } // Single CPU without extension exclusion: stream directly via stdin - return uc.restoreViaStdin(originalDB, pgBin, backup, storage, pg) + return uc.restoreViaStdin(parentCtx, originalDB, pgBin, backup, storage, pg) } // restoreViaStdin streams backup via stdin for single CPU restore func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin( + parentCtx context.Context, originalDB *databases.Database, pgBin string, backup *backups_core.Backup, @@ -133,10 +145,10 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin( "--no-acl", } - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute) + ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute) defer cancel() - 
// Monitor for shutdown and cancel context if needed + // Monitor for shutdown and parent cancellation go func() { ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() @@ -145,6 +157,9 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin( select { case <-ctx.Done(): return + case <-parentCtx.Done(): + cancel() + return case <-ticker.C: if config.IsShouldShutdown() { cancel() @@ -296,6 +311,15 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin( stderrOutput := <-stderrCh copyErr := <-copyErrCh + // Check for cancellation + select { + case <-ctx.Done(): + if errors.Is(ctx.Err(), context.Canceled) { + return fmt.Errorf("restore cancelled") + } + default: + } + // Check for shutdown before finalizing if config.IsShouldShutdown() { return fmt.Errorf("restore cancelled due to shutdown") @@ -307,6 +331,15 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin( } if waitErr != nil { + // Check for cancellation again + select { + case <-ctx.Done(): + if errors.Is(ctx.Err(), context.Canceled) { + return fmt.Errorf("restore cancelled") + } + default: + } + if config.IsShouldShutdown() { return fmt.Errorf("restore cancelled due to shutdown") } @@ -319,6 +352,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin( // restoreViaFile downloads backup and uses parallel jobs for multi-CPU restore func (uc *RestorePostgresqlBackupUsecase) restoreViaFile( + parentCtx context.Context, originalDB *databases.Database, pgBin string, backup *backups_core.Backup, @@ -354,6 +388,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaFile( } return uc.restoreFromStorage( + parentCtx, originalDB, pgBin, args, @@ -367,6 +402,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaFile( // restoreFromStorage restores backup data from storage using pg_restore func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage( + parentCtx context.Context, database *databases.Database, pgBin string, args []string, @@ -386,10 +422,10 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage( isExcludeExtensions, ) - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute) + ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute) defer cancel() - // Monitor for shutdown and cancel context if needed + // Monitor for shutdown and parent cancellation go func() { ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() @@ -398,6 +434,9 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage( select { case <-ctx.Done(): return + case <-parentCtx.Done(): + cancel() + return case <-ticker.C: if config.IsShouldShutdown() { cancel() @@ -624,12 +663,30 @@ func (uc *RestorePostgresqlBackupUsecase) executePgRestore( waitErr := cmd.Wait() stderrOutput := <-stderrCh + // Check for cancellation + select { + case <-ctx.Done(): + if errors.Is(ctx.Err(), context.Canceled) { + return fmt.Errorf("restore cancelled") + } + default: + } + // Check for shutdown before finalizing if config.IsShouldShutdown() { return fmt.Errorf("restore cancelled due to shutdown") } if waitErr != nil { + // Check for cancellation again + select { + case <-ctx.Done(): + if errors.Is(ctx.Err(), context.Canceled) { + return fmt.Errorf("restore cancelled") + } + default: + } + if config.IsShouldShutdown() { return fmt.Errorf("restore cancelled due to shutdown") } diff --git a/backend/internal/features/restores/usecases/restore_backup_uc.go b/backend/internal/features/restores/usecases/restore_backup_uc.go index 13814d0..c205f2f 100644 --- 
a/backend/internal/features/restores/usecases/restore_backup_uc.go +++ b/backend/internal/features/restores/usecases/restore_backup_uc.go @@ -1,6 +1,7 @@ package usecases import ( + "context" "errors" backups_core "databasus-backend/internal/features/backups/backups/core" @@ -22,6 +23,7 @@ type RestoreBackupUsecase struct { } func (uc *RestoreBackupUsecase) Execute( + ctx context.Context, backupConfig *backups_config.BackupConfig, restore restores_core.Restore, originalDB *databases.Database, @@ -33,6 +35,7 @@ func (uc *RestoreBackupUsecase) Execute( switch originalDB.Type { case databases.DatabaseTypePostgres: return uc.restorePostgresqlBackupUsecase.Execute( + ctx, originalDB, restoringToDB, backupConfig, @@ -43,6 +46,7 @@ func (uc *RestoreBackupUsecase) Execute( ) case databases.DatabaseTypeMysql: return uc.restoreMysqlBackupUsecase.Execute( + ctx, originalDB, restoringToDB, backupConfig, @@ -52,6 +56,7 @@ func (uc *RestoreBackupUsecase) Execute( ) case databases.DatabaseTypeMariadb: return uc.restoreMariadbBackupUsecase.Execute( + ctx, originalDB, restoringToDB, backupConfig, @@ -61,6 +66,7 @@ func (uc *RestoreBackupUsecase) Execute( ) case databases.DatabaseTypeMongodb: return uc.restoreMongodbBackupUsecase.Execute( + ctx, originalDB, restoringToDB, backupConfig, diff --git a/backend/internal/features/storages/di.go b/backend/internal/features/storages/di.go index 522b9e9..57822b8 100644 --- a/backend/internal/features/storages/di.go +++ b/backend/internal/features/storages/di.go @@ -1,9 +1,13 @@ package storages import ( + "sync" + "sync/atomic" + audit_logs "databasus-backend/internal/features/audit_logs" workspaces_services "databasus-backend/internal/features/workspaces/services" "databasus-backend/internal/util/encryption" + "databasus-backend/internal/util/logger" ) var storageRepository = &StorageRepository{} @@ -27,6 +31,21 @@ func GetStorageController() *StorageController { return storageController } +var ( + setupOnce sync.Once + isSetup atomic.Bool +) + func SetupDependencies() { - workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(storageService) + wasAlreadySetup := isSetup.Load() + + setupOnce.Do(func() { + workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(storageService) + + isSetup.Store(true) + }) + + if wasAlreadySetup { + logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call") + } } diff --git a/backend/internal/features/tasks/cancellation/di.go b/backend/internal/features/tasks/cancellation/di.go index 5d1c83a..1763521 100644 --- a/backend/internal/features/tasks/cancellation/di.go +++ b/backend/internal/features/tasks/cancellation/di.go @@ -2,9 +2,11 @@ package task_cancellation import ( "context" + "sync" + "sync/atomic" + cache_utils "databasus-backend/internal/util/cache" "databasus-backend/internal/util/logger" - "sync" "github.com/google/uuid" ) @@ -20,6 +22,21 @@ func GetTaskCancelManager() *TaskCancelManager { return taskCancelManager } +var ( + setupOnce sync.Once + isSetup atomic.Bool +) + func SetupDependencies() { - taskCancelManager.StartSubscription() + wasAlreadySetup := isSetup.Load() + + setupOnce.Do(func() { + taskCancelManager.StartSubscription() + + isSetup.Store(true) + }) + + if wasAlreadySetup { + logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call") + } } diff --git a/backend/internal/features/test_once_protection.go b/backend/internal/features/test_once_protection.go new file mode 100644 index 0000000..7b096f8 --- /dev/null 
+++ b/backend/internal/features/test_once_protection.go @@ -0,0 +1,159 @@ +package features + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "databasus-backend/internal/features/audit_logs" + "databasus-backend/internal/features/backups/backups" + "databasus-backend/internal/features/backups/backups/backuping" + backups_config "databasus-backend/internal/features/backups/config" + "databasus-backend/internal/features/databases" + healthcheck_config "databasus-backend/internal/features/healthcheck/config" + "databasus-backend/internal/features/notifiers" + "databasus-backend/internal/features/restores" + "databasus-backend/internal/features/restores/restoring" + "databasus-backend/internal/features/storages" + task_cancellation "databasus-backend/internal/features/tasks/cancellation" +) + +// Test_SetupDependencies_CalledTwice_LogsWarning verifies SetupDependencies is idempotent +func Test_SetupDependencies_CalledTwice_LogsWarning(t *testing.T) { + // Call each SetupDependencies twice - should not panic, only log warnings + audit_logs.SetupDependencies() + audit_logs.SetupDependencies() + + backups.SetupDependencies() + backups.SetupDependencies() + + backups_config.SetupDependencies() + backups_config.SetupDependencies() + + databases.SetupDependencies() + databases.SetupDependencies() + + healthcheck_config.SetupDependencies() + healthcheck_config.SetupDependencies() + + notifiers.SetupDependencies() + notifiers.SetupDependencies() + + restores.SetupDependencies() + restores.SetupDependencies() + + storages.SetupDependencies() + storages.SetupDependencies() + + task_cancellation.SetupDependencies() + task_cancellation.SetupDependencies() + + // If we reach here without panic, test passes + t.Log("All SetupDependencies calls completed successfully (idempotent)") +} + +// Test_SetupDependencies_ConcurrentCalls_Safe verifies thread safety +func Test_SetupDependencies_ConcurrentCalls_Safe(t *testing.T) { + var wg sync.WaitGroup + + // Call SetupDependencies concurrently from 10 goroutines + for range 10 { + wg.Add(1) + go func() { + defer wg.Done() + audit_logs.SetupDependencies() + }() + } + + wg.Wait() + t.Log("Concurrent SetupDependencies calls completed successfully") +} + +// Test_BackgroundService_Run_CalledTwice_Panics verifies Run() panics on duplicate calls +func Test_BackgroundService_Run_CalledTwice_Panics(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Create a test background service + backgroundService := audit_logs.GetAuditLogBackgroundService() + + // Start first Run() in goroutine + go func() { + backgroundService.Run(ctx) + }() + + // Give first call time to initialize + time.Sleep(100 * time.Millisecond) + + // Second call should panic + defer func() { + if r := recover(); r != nil { + expectedMsg := "*audit_logs.AuditLogBackgroundService.Run() called multiple times" + panicMsg := fmt.Sprintf("%v", r) + if panicMsg == expectedMsg { + t.Logf("Successfully caught panic: %v", r) + } else { + t.Errorf("Expected panic message '%s', got '%s'", expectedMsg, panicMsg) + } + } else { + t.Error("Expected panic on second Run() call, but did not panic") + } + }() + + backgroundService.Run(ctx) +} + +// Test_BackupsScheduler_Run_CalledTwice_Panics verifies scheduler panics on duplicate calls +func Test_BackupsScheduler_Run_CalledTwice_Panics(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + scheduler := backuping.GetBackupsScheduler() + + // Start first Run() in goroutine + go 
func() { + scheduler.Run(ctx) + }() + + // Give first call time to initialize + time.Sleep(100 * time.Millisecond) + + // Second call should panic + defer func() { + if r := recover(); r != nil { + t.Logf("Successfully caught panic: %v", r) + } else { + t.Error("Expected panic on second Run() call, but did not panic") + } + }() + + scheduler.Run(ctx) +} + +// Test_RestoresScheduler_Run_CalledTwice_Panics verifies restore scheduler panics on duplicate calls +func Test_RestoresScheduler_Run_CalledTwice_Panics(t *testing.T) { + ctx := t.Context() + + scheduler := restoring.GetRestoresScheduler() + + // Start first Run() in goroutine + go func() { + scheduler.Run(ctx) + }() + + // Give first call time to initialize + time.Sleep(100 * time.Millisecond) + + // Second call should panic + defer func() { + if r := recover(); r != nil { + t.Logf("Successfully caught panic: %v", r) + } else { + t.Error("Expected panic on second Run() call, but did not panic") + } + }() + + scheduler.Run(ctx) +} diff --git a/frontend/src/entity/restores/api/restoreApi.ts b/frontend/src/entity/restores/api/restoreApi.ts index 119ac08..6334d07 100644 --- a/frontend/src/entity/restores/api/restoreApi.ts +++ b/frontend/src/entity/restores/api/restoreApi.ts @@ -46,4 +46,8 @@ export const restoreApi = { requestOptions, ); }, + + async cancelRestore(restoreId: string) { + return apiHelper.fetchPostRaw(`${getApplicationServer()}/api/v1/restores/cancel/${restoreId}`); + }, }; diff --git a/frontend/src/entity/restores/model/RestoreStatus.ts b/frontend/src/entity/restores/model/RestoreStatus.ts index 35dbc5a..2038ecd 100644 --- a/frontend/src/entity/restores/model/RestoreStatus.ts +++ b/frontend/src/entity/restores/model/RestoreStatus.ts @@ -2,4 +2,5 @@ export enum RestoreStatus { IN_PROGRESS = 'IN_PROGRESS', COMPLETED = 'COMPLETED', FAILED = 'FAILED', + CANCELED = 'CANCELED', } diff --git a/frontend/src/features/restores/ui/RestoresComponent.tsx b/frontend/src/features/restores/ui/RestoresComponent.tsx index 11c361e..0c080b2 100644 --- a/frontend/src/features/restores/ui/RestoresComponent.tsx +++ b/frontend/src/features/restores/ui/RestoresComponent.tsx @@ -1,5 +1,10 @@ -import { CopyOutlined, ExclamationCircleOutlined, SyncOutlined } from '@ant-design/icons'; -import { CheckCircleOutlined } from '@ant-design/icons'; +import { + CheckCircleOutlined, + CloseCircleOutlined, + CopyOutlined, + ExclamationCircleOutlined, + SyncOutlined, +} from '@ant-design/icons'; import { App, Button, Modal, Spin, Tooltip } from 'antd'; import dayjs from 'dayjs'; import { useEffect, useRef, useState } from 'react'; @@ -8,6 +13,7 @@ import type { Backup } from '../../../entity/backups'; import { type Database, DatabaseType } from '../../../entity/databases'; import { type Restore, RestoreStatus, restoreApi } from '../../../entity/restores'; import { getUserTimeFormat } from '../../../shared/time'; +import { ConfirmationComponent } from '../../../shared/ui'; import { EditDatabaseSpecificDataComponent } from '../../databases/ui/edit/EditDatabaseSpecificDataComponent'; interface Props { @@ -70,6 +76,10 @@ export const RestoresComponent = ({ database, backup }: Props) => { const [isShowRestore, setIsShowRestore] = useState(false); + const [cancellingRestoreId, setCancellingRestoreId] = useState(); + const [showCancelConfirmation, setShowCancelConfirmation] = useState(false); + const [restoreToCancelId, setRestoreToCancelId] = useState(); + const isReloadInProgress = useRef(false); const loadRestores = async () => { @@ -103,6 +113,18 @@ export const 
RestoresComponent = ({ database, backup }: Props) => { } }; + const cancelRestore = async (restoreId: string) => { + setCancellingRestoreId(restoreId); + try { + await restoreApi.cancelRestore(restoreId); + await loadRestores(); + } catch (e) { + alert((e as Error).message); + } finally { + setCancellingRestoreId(undefined); + } + }; + useEffect(() => { setIsLoading(true); loadRestores().finally(() => setIsLoading(false)); @@ -190,40 +212,77 @@ export const RestoresComponent = ({ database, backup }: Props) => { return (
-        <div>
-          <div>Status</div>
-
-          {restore.status === RestoreStatus.FAILED && (
-            <Tooltip>
-              <div
-                onClick={() => setShowingRestoreError(restore)}
-              >
-                <ExclamationCircleOutlined />
-                <div>Failed</div>
-              </div>
-            </Tooltip>
-          )}
-
-          {restore.status === RestoreStatus.COMPLETED && (
-            <div>
-              <CheckCircleOutlined />
-              <div>Successful</div>
-            </div>
-          )}
-
-          {restore.status === RestoreStatus.IN_PROGRESS && (
-            <div>
-              <SyncOutlined spin />
-              In progress
-            </div>
-          )}
-        </div>
+        <div>
+          <div>
+            <div>Status</div>
+
+            {restore.status === RestoreStatus.FAILED && (
+              <Tooltip>
+                <div
+                  onClick={() => setShowingRestoreError(restore)}
+                >
+                  <ExclamationCircleOutlined />
+                  <div>Failed</div>
+                </div>
+              </Tooltip>
+            )}
+
+            {restore.status === RestoreStatus.COMPLETED && (
+              <div>
+                <CheckCircleOutlined />
+                <div>Successful</div>
+              </div>
+            )}
+
+            {restore.status === RestoreStatus.CANCELED && (
+              <div>
+                <CloseCircleOutlined />
+                <div>Canceled</div>
+              </div>
+            )}
+
+            {restore.status === RestoreStatus.IN_PROGRESS && (
+              <div>
+                <SyncOutlined spin />
+                In progress
+              </div>
+            )}
+          </div>
+
+          {restore.status === RestoreStatus.IN_PROGRESS && (
+            <div>
+              {cancellingRestoreId === restore.id ? (
+                <Spin />
+              ) : (
+                <Tooltip>
+                  <CloseCircleOutlined
+                    onClick={() => {
+                      if (cancellingRestoreId) return;
+                      setRestoreToCancelId(restore.id);
+                      setShowCancelConfirmation(true);
+                    }}
+                    style={{
+                      color: '#ff0000',
+                      fontSize: 16,
+                      opacity: cancellingRestoreId ? 0.2 : 1,
+                    }}
+                  />
+                </Tooltip>
+              )}
+            </div>
+          )}
+        </div>
@@ -289,6 +348,25 @@ export const RestoresComponent = ({ database, backup }: Props) => {
         )}
+
+      {showCancelConfirmation && (
+        <ConfirmationComponent
+          onAccept={() => {
+            setShowCancelConfirmation(false);
+            if (restoreToCancelId) {
+              cancelRestore(restoreToCancelId);
+            }
+            setRestoreToCancelId(undefined);
+          }}
+          onDecline={() => {
+            setShowCancelConfirmation(false);
+            setRestoreToCancelId(undefined);
+          }}
+          description="⚠️ Warning: Cancelling this restore will likely leave your database in a corrupted or incomplete state. You will need to recreate the database before attempting another restore.
+
+Are you sure you want to cancel?"
+          actionText="Yes, cancel restore"
+          actionButtonColor="red"
+        />
+      )}
     </div>
   );
 };
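Editor's note on the Run()-once guard used throughout the hunks above: sync.Once.Do blocks every later caller until the first call's function returns, and in these services that function returns only when ctx is cancelled. A second Run() issued while the first is still looping therefore parks inside runOnce.Do and never reaches the wasAlreadyRun panic; the panic only fires after the first Run() has already exited. The *_Run_CalledTwice_Panics tests above call Run() again while the first call is still live, so as written they can hang until the test timeout instead of observing the panic. A CompareAndSwap-based guard fails fast in both the sequential and the concurrent case; a minimal sketch, not the project's code:

```go
package sketch

import (
	"context"
	"fmt"
	"sync/atomic"
	"time"
)

// guardedService shows the CompareAndSwap variant of the run-once guard.
type guardedService struct {
	hasRun atomic.Bool
}

func (s *guardedService) Run(ctx context.Context) {
	// CompareAndSwap flips hasRun exactly once; every later caller,
	// concurrent or sequential, loses the race and panics immediately
	// instead of blocking inside a sync.Once whose function never returns.
	if !s.hasRun.CompareAndSwap(false, true) {
		panic(fmt.Sprintf("%T.Run() called multiple times", s))
	}

	ticker := time.NewTicker(1 * time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			// periodic work goes here
		}
	}
}
```

If the Once-based form is kept instead, hoisting the hasRun check and panic above runOnce.Do at least makes sequential reuse fail fast, though a concurrent second call can still block.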
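For readers of the CancelRestore and MakeRestore hunks: TaskCancelManager is used here but its internals are not shown. The call sites imply a registry mapping a task ID to the context.CancelFunc created in MakeRestore, with CancelTask invoking it from the API path so the ctx watched by the per-database restore monitors gets cancelled. Below is a minimal single-node sketch of that contract with hypothetical names, not the project's implementation (the real manager also calls StartSubscription(), which suggests cancellations additionally propagate between nodes through the cache layer):

```go
package sketch

import (
	"context"
	"errors"
	"sync"

	"github.com/google/uuid"
)

// taskCancelManager is a hypothetical single-node stand-in for the
// TaskCancelManager contract implied by the diff.
type taskCancelManager struct {
	mu      sync.Mutex
	cancels map[uuid.UUID]context.CancelFunc
}

func newTaskCancelManager() *taskCancelManager {
	return &taskCancelManager{cancels: make(map[uuid.UUID]context.CancelFunc)}
}

// RegisterTask stores the cancel func for a running task.
func (m *taskCancelManager) RegisterTask(id uuid.UUID, cancel context.CancelFunc) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.cancels[id] = cancel
}

// UnregisterTask removes the entry once the task finishes.
func (m *taskCancelManager) UnregisterTask(id uuid.UUID) {
	m.mu.Lock()
	defer m.mu.Unlock()
	delete(m.cancels, id)
}

// CancelTask cancels the task's context, unblocking the ctx.Done()
// paths that the restore monitors watch.
func (m *taskCancelManager) CancelTask(id uuid.UUID) error {
	m.mu.Lock()
	cancel, ok := m.cancels[id]
	m.mu.Unlock()
	if !ok {
		return errors.New("no running task with that ID")
	}
	cancel()
	return nil
}
```

Usage mirrors the diff: RegisterTask(restore.ID, cancel) before Execute, a deferred UnregisterTask on exit, and CancelRestore calling CancelTask after the permission and status checks.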