diff --git a/AGENTS.md b/AGENTS.md index 1bd7e4c..ce8f642 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -7,6 +7,7 @@ This is NOT a strict set of rules, but a set of recommendations to help you writ ## Table of Contents +- [Engineering Philosophy](#engineering-philosophy) - [Backend Guidelines](#backend-guidelines) - [Code Style](#code-style) - [Comments](#comments) @@ -22,6 +23,67 @@ This is NOT a strict set of rules, but a set of recommendations to help you writ --- +## Engineering Philosophy + +**Think like a skeptical senior engineer and code reviewer. Don't just do what was asked—also think about what should have been asked.** + +⚠️ **Balance vigilance with pragmatism:** Catch real issues, not theoretical ones. Don't let perfect be the enemy of good. + +### Task Context Assessment: + +**First, assess the task scope:** + +- **Trivial** (typos, formatting, simple field adds): Apply directly with minimal analysis +- **Standard** (CRUD, typical features): Brief assumption check, proceed +- **Complex** (architecture, security, performance-critical): Full analysis required +- **Unclear** (ambiguous requirements): Always clarify assumptions first + +### For Non-Trivial Tasks: + +1. **Restate the objective and list assumptions** (explicit + implicit) + - If any assumption is shaky, call it out clearly + - Distinguish between what's specified and what you're inferring + +2. **Propose appropriate solutions:** + - For complex tasks: 2–3 viable approaches (including a simpler baseline) + - Recommend one with clear tradeoffs + - Consider: complexity, maintainability, performance, future extensibility + +3. **Identify risks proactively:** + - Edge cases and boundary conditions + - Security/privacy pitfalls + - Performance risks and scalability concerns + - Operational concerns (deployment, observability, rollback, monitoring) + +4. 
**Handle ambiguity:** + - If requirements are ambiguous, make a reasonable default and proceed + - Clearly label your assumptions + - Document what would change under alternative assumptions + +5. **Deliver quality:** + - Provide a solution that is correct, testable, and maintainable + - Include minimal tests or validation steps + - Follow project testing philosophy: prefer controller tests over unit tests + - Follow all project guidelines from this document + +6. **Self-review before finalizing:** + - Ask: "What could go wrong?" + - Patch the answer accordingly + - Verify edge cases are handled + +### Application Guidelines: + +**Scale your response to the task:** + +- **Trivial changes:** Steps 5-6 only (deliver quality + self-review) +- **Standard features:** Steps 1, 5-6 (restate + deliver + review) +- **Complex/risky changes:** All steps 1-6 +- **Ambiguous requests:** Steps 1, 4 mandatory + +**Be proportionally thorough—brief for simple tasks, comprehensive for risky ones. Avoid analysis paralysis.** + +--- + ## Backend Guidelines ### Code Style diff --git a/backend/cmd/main.go b/backend/cmd/main.go index bdc3ac1..bec1b38 100644 --- a/backend/cmd/main.go +++ b/backend/cmd/main.go @@ -25,10 +25,10 @@ import ( healthcheck_config "databasus-backend/internal/features/healthcheck/config" "databasus-backend/internal/features/notifiers" "databasus-backend/internal/features/restores" + "databasus-backend/internal/features/restores/restoring" "databasus-backend/internal/features/storages" system_healthcheck "databasus-backend/internal/features/system/healthcheck" task_cancellation "databasus-backend/internal/features/tasks/cancellation" - task_registry "databasus-backend/internal/features/tasks/registry" users_controllers "databasus-backend/internal/features/users/controllers" users_middleware "databasus-backend/internal/features/users/middleware" users_services "databasus-backend/internal/features/users/services" @@ -273,7 +273,7 @@ func runBackgroundTasks(log 
*slog.Logger) { }) go runWithPanicLogging(log, "restore background service", func() { - restores.GetRestoreBackgroundService().Run(ctx) + restoring.GetRestoresScheduler().Run(ctx) }) go runWithPanicLogging(log, "healthcheck attempt background service", func() { @@ -288,21 +288,29 @@ func runBackgroundTasks(log *slog.Logger) { backups_download.GetDownloadTokenBackgroundService().Run(ctx) }) - go runWithPanicLogging(log, "task nodes registry background service", func() { - task_registry.GetTaskNodesRegistry().Run(ctx) + go runWithPanicLogging(log, "backup nodes registry background service", func() { + backuping.GetBackupNodesRegistry().Run(ctx) + }) + + go runWithPanicLogging(log, "restore nodes registry background service", func() { + restoring.GetRestoreNodesRegistry().Run(ctx) }) } else { log.Info("Skipping primary node tasks as not primary node") } - if config.GetEnv().IsBackupNode { + if config.GetEnv().IsProcessingNode { log.Info("Starting backup node background tasks...") go runWithPanicLogging(log, "backup node", func() { backuping.GetBackuperNode().Run(ctx) }) + + go runWithPanicLogging(log, "restore node", func() { + restoring.GetRestorerNode().Run(ctx) + }) } else { - log.Info("Skipping backup node tasks as not backup node") + log.Info("Skipping backup/restore node tasks as not backup node") } } diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 5b4aa14..38e5906 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -9,7 +9,6 @@ import ( "strings" "sync" - "github.com/google/uuid" "github.com/ilyakaznacheev/cleanenv" "github.com/joho/godotenv" ) @@ -32,10 +31,9 @@ type EnvVariables struct { ShowDbInstallationVerificationLogs bool `env:"SHOW_DB_INSTALLATION_VERIFICATION_LOGS"` - NodeID string IsManyNodesMode bool `env:"IS_MANY_NODES_MODE"` IsPrimaryNode bool `env:"IS_PRIMARY_NODE"` - IsBackupNode bool `env:"IS_BACKUP_NODE"` + IsProcessingNode bool `env:"IS_PROCESSING_NODE"` 
NodeNetworkThroughputMBs int `env:"NODE_NETWORK_THROUGHPUT_MBPS"` DataFolder string @@ -230,14 +228,13 @@ func loadEnvVariables() { env.ShowDbInstallationVerificationLogs, ) - env.NodeID = uuid.New().String() if env.NodeNetworkThroughputMBs == 0 { env.NodeNetworkThroughputMBs = 125 // 1 Gbit/s } if !env.IsManyNodesMode { env.IsPrimaryNode = true - env.IsBackupNode = true + env.IsProcessingNode = true } // Valkey diff --git a/backend/internal/features/backups/backups/backuping/backuper.go b/backend/internal/features/backups/backups/backuping/backuper.go index 222ee26..2c40a4d 100644 --- a/backend/internal/features/backups/backups/backuping/backuper.go +++ b/backend/internal/features/backups/backups/backuping/backuper.go @@ -8,7 +8,6 @@ import ( "databasus-backend/internal/features/databases" "databasus-backend/internal/features/storages" tasks_cancellation "databasus-backend/internal/features/tasks/cancellation" - task_registry "databasus-backend/internal/features/tasks/registry" workspaces_services "databasus-backend/internal/features/workspaces/services" util_encryption "databasus-backend/internal/util/encryption" "errors" @@ -35,7 +34,7 @@ type BackuperNode struct { storageService *storages.StorageService notificationSender backups_core.NotificationSender backupCancelManager *tasks_cancellation.TaskCancelManager - tasksRegistry *task_registry.TaskNodesRegistry + backupNodesRegistry *BackupNodesRegistry logger *slog.Logger createBackupUseCase backups_core.CreateBackupUsecase nodeID uuid.UUID @@ -48,19 +47,20 @@ func (n *BackuperNode) Run(ctx context.Context) { throughputMBs := config.GetEnv().NodeNetworkThroughputMBs - backupNode := task_registry.TaskNode{ + backupNode := BackupNode{ ID: n.nodeID, ThroughputMBs: throughputMBs, + LastHeartbeat: time.Now().UTC(), } - if err := n.tasksRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil { + if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil { 
n.logger.Error("Failed to register node in registry", "error", err) panic(err) } backupHandler := func(backupID uuid.UUID, isCallNotifier bool) { n.MakeBackup(backupID, isCallNotifier) - if err := n.tasksRegistry.PublishTaskCompletion(n.nodeID.String(), backupID); err != nil { + if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil { n.logger.Error( "Failed to publish backup completion", "error", @@ -71,12 +71,13 @@ func (n *BackuperNode) Run(ctx context.Context) { } } - if err := n.tasksRegistry.SubscribeNodeForTasksAssignment(n.nodeID.String(), backupHandler); err != nil { + err := n.backupNodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID, backupHandler) + if err != nil { n.logger.Error("Failed to subscribe to backup assignments", "error", err) panic(err) } defer func() { - if err := n.tasksRegistry.UnsubscribeNodeForTasksAssignments(); err != nil { + if err := n.backupNodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil { n.logger.Error("Failed to unsubscribe from backup assignments", "error", err) } }() @@ -91,7 +92,7 @@ func (n *BackuperNode) Run(ctx context.Context) { case <-ctx.Done(): n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID) - if err := n.tasksRegistry.UnregisterNodeFromRegistry(backupNode); err != nil { + if err := n.backupNodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil { n.logger.Error("Failed to unregister node from registry", "error", err) } @@ -357,9 +358,9 @@ func (n *BackuperNode) SendBackupNotification( } } -func (n *BackuperNode) sendHeartbeat(backupNode *task_registry.TaskNode) { +func (n *BackuperNode) sendHeartbeat(backupNode *BackupNode) { n.lastHeartbeat = time.Now().UTC() - if err := n.tasksRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *backupNode); err != nil { + if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *backupNode); err != nil { n.logger.Error("Failed to send heartbeat", "error", err) } } 
diff --git a/backend/internal/features/backups/backups/backuping/di.go b/backend/internal/features/backups/backups/backuping/di.go index ff2b342..12fa010 100644 --- a/backend/internal/features/backups/backups/backuping/di.go +++ b/backend/internal/features/backups/backups/backuping/di.go @@ -1,7 +1,6 @@ package backuping import ( - "databasus-backend/internal/config" backups_core "databasus-backend/internal/features/backups/backups/core" "databasus-backend/internal/features/backups/backups/usecases" backups_config "databasus-backend/internal/features/backups/config" @@ -9,8 +8,8 @@ import ( "databasus-backend/internal/features/notifiers" "databasus-backend/internal/features/storages" tasks_cancellation "databasus-backend/internal/features/tasks/cancellation" - task_registry "databasus-backend/internal/features/tasks/registry" workspaces_services "databasus-backend/internal/features/workspaces/services" + cache_utils "databasus-backend/internal/util/cache" "databasus-backend/internal/util/encryption" "databasus-backend/internal/util/logger" "time" @@ -22,16 +21,16 @@ var backupRepository = &backups_core.BackupRepository{} var taskCancelManager = tasks_cancellation.GetTaskCancelManager() -var nodesRegistry = task_registry.GetTaskNodesRegistry() +var backupNodesRegistry = &BackupNodesRegistry{ + cache_utils.GetValkeyClient(), + logger.GetLogger(), + cache_utils.DefaultCacheTimeout, + cache_utils.NewPubSubManager(), + cache_utils.NewPubSubManager(), +} func getNodeID() uuid.UUID { - nodeIDStr := config.GetEnv().NodeID - nodeID, err := uuid.Parse(nodeIDStr) - if err != nil { - logger.GetLogger().Error("Failed to parse node ID from config", "error", err) - panic(err) - } - return nodeID + return uuid.New() } var backuperNode = &BackuperNode{ @@ -43,7 +42,7 @@ var backuperNode = &BackuperNode{ storages.GetStorageService(), notifiers.GetNotifierService(), taskCancelManager, - nodesRegistry, + backupNodesRegistry, logger.GetLogger(), usecases.GetCreateBackupUsecase(), 
getNodeID(), @@ -51,15 +50,15 @@ var backuperNode = &BackuperNode{ } var backupsScheduler = &BackupsScheduler{ - backupRepository, - backups_config.GetBackupConfigService(), - storages.GetStorageService(), - taskCancelManager, - nodesRegistry, - time.Now().UTC(), - logger.GetLogger(), - make(map[uuid.UUID]BackupToNodeRelation), - backuperNode, + backupRepository: backupRepository, + backupConfigService: backups_config.GetBackupConfigService(), + storageService: storages.GetStorageService(), + taskCancelManager: taskCancelManager, + backupNodesRegistry: backupNodesRegistry, + lastBackupTime: time.Now().UTC(), + logger: logger.GetLogger(), + backupToNodeRelations: make(map[uuid.UUID]BackupToNodeRelation), + backuperNode: backuperNode, } func GetBackupsScheduler() *BackupsScheduler { @@ -69,3 +68,7 @@ func GetBackupsScheduler() *BackupsScheduler { func GetBackuperNode() *BackuperNode { return backuperNode } + +func GetBackupNodesRegistry() *BackupNodesRegistry { + return backupNodesRegistry +} diff --git a/backend/internal/features/backups/backups/backuping/dto.go b/backend/internal/features/backups/backups/backuping/dto.go index c8f2fa3..93fd822 100644 --- a/backend/internal/features/backups/backups/backuping/dto.go +++ b/backend/internal/features/backups/backups/backuping/dto.go @@ -1,8 +1,34 @@ package backuping -import "github.com/google/uuid" +import ( + "time" + + "github.com/google/uuid" +) type BackupToNodeRelation struct { NodeID uuid.UUID `json:"nodeId"` BackupsIDs []uuid.UUID `json:"backupsIds"` } + +type BackupNode struct { + ID uuid.UUID `json:"id"` + ThroughputMBs int `json:"throughputMBs"` + LastHeartbeat time.Time `json:"lastHeartbeat"` +} + +type BackupNodeStats struct { + ID uuid.UUID `json:"id"` + ActiveBackups int `json:"activeBackups"` +} + +type BackupSubmitMessage struct { + NodeID uuid.UUID `json:"nodeId"` + BackupID uuid.UUID `json:"backupId"` + IsCallNotifier bool `json:"isCallNotifier"` +} + +type BackupCompletionMessage struct { + NodeID 
uuid.UUID `json:"nodeId"` + BackupID uuid.UUID `json:"backupId"` +} diff --git a/backend/internal/features/tasks/registry/registry.go b/backend/internal/features/backups/backups/backuping/registry.go similarity index 63% rename from backend/internal/features/tasks/registry/registry.go rename to backend/internal/features/backups/backups/backuping/registry.go index e80cb87..53365a5 100644 --- a/backend/internal/features/tasks/registry/registry.go +++ b/backend/internal/features/backups/backups/backuping/registry.go @@ -1,4 +1,4 @@ -package task_registry +package backuping import ( "context" @@ -15,45 +15,41 @@ import ( ) const ( - nodeInfoKeyPrefix = "node:" - nodeInfoKeySuffix = ":info" - nodeActiveTasksPrefix = "node:" - nodeActiveTasksSuffix = ":active_tasks" - taskSubmitChannel = "task:submit" - taskCompletionChannel = "task:completion" + nodeInfoKeyPrefix = "backup:node:" + nodeInfoKeySuffix = ":info" + nodeActiveBackupsPrefix = "backup:node:" + nodeActiveBackupsSuffix = ":active_backups" + backupSubmitChannel = "backup:submit" + backupCompletionChannel = "backup:completion" deadNodeThreshold = 2 * time.Minute cleanupTickerInterval = 1 * time.Second ) -// TaskNodesRegistry helps to sync tasks scheduler (backuping or restoring) -// and task nodes which are used for network-intensive tasks processing +// BackupNodesRegistry helps to sync backups scheduler and backup nodes. 
// // Features: // - Track node availability and load level -// - Assign from scheduler to node tasks needed to be processed -// - Notify scheduler from node about task completion +// - Assign from scheduler to node backups needed to be processed +// - Notify scheduler from node about backup completion // // Important things to remember: -// - Node can contain different tasks types so when task is assigned -// or node's tasks cleaned - should be performed DB check in DB -// that task with this ID exists for this task type at all -// - Nodes without heathbeat for more than 2 minutes are not included +// - Nodes without heartbeat for more than 2 minutes are not included // in available nodes list and stats // // Cleanup dead nodes performed on 2 levels: // - List and stats functions do not return dead nodes // - Periodically dead nodes are cleaned up in cache (to not // accumulate too many dead nodes in cache) -type TaskNodesRegistry struct { +type BackupNodesRegistry struct { client valkey.Client logger *slog.Logger timeout time.Duration - pubsubTasks *cache_utils.PubSubManager + pubsubBackups *cache_utils.PubSubManager pubsubCompletions *cache_utils.PubSubManager } -func (r *TaskNodesRegistry) Run(ctx context.Context) { +func (r *BackupNodesRegistry) Run(ctx context.Context) { if err := r.cleanupDeadNodes(); err != nil { r.logger.Error("Failed to cleanup dead nodes on startup", "error", err) } @@ -72,7 +68,7 @@ func (r *TaskNodesRegistry) Run(ctx context.Context) { } } -func (r *TaskNodesRegistry) GetAvailableNodes() ([]TaskNode, error) { +func (r *BackupNodesRegistry) GetAvailableNodes() ([]BackupNode, error) { ctx, cancel := context.WithTimeout(context.Background(), r.timeout) defer cancel() @@ -104,7 +100,7 @@ func (r *TaskNodesRegistry) GetAvailableNodes() ([]TaskNode, error) { } if len(allKeys) == 0 { - return []TaskNode{}, nil + return []BackupNode{}, nil } keyDataMap, err := r.pipelineGetKeys(allKeys) @@ -113,14 +109,15 @@ func (r *TaskNodesRegistry) 
GetAvailableNodes() ([]TaskNode, error) { } threshold := time.Now().UTC().Add(-deadNodeThreshold) - var nodes []TaskNode + var nodes []BackupNode + for key, data := range keyDataMap { // Skip if the key doesn't exist (data is empty) if len(data) == 0 { continue } - var node TaskNode + var node BackupNode if err := json.Unmarshal(data, &node); err != nil { r.logger.Warn("Failed to unmarshal node data", "key", key, "error", err) continue @@ -141,13 +138,13 @@ func (r *TaskNodesRegistry) GetAvailableNodes() ([]TaskNode, error) { return nodes, nil } -func (r *TaskNodesRegistry) GetNodesStats() ([]TaskNodeStats, error) { +func (r *BackupNodesRegistry) GetBackupNodesStats() ([]BackupNodeStats, error) { ctx, cancel := context.WithTimeout(context.Background(), r.timeout) defer cancel() var allKeys []string cursor := uint64(0) - pattern := nodeActiveTasksPrefix + "*" + nodeActiveTasksSuffix + pattern := nodeActiveBackupsPrefix + "*" + nodeActiveBackupsSuffix for { result := r.client.Do( @@ -156,7 +153,7 @@ func (r *TaskNodesRegistry) GetNodesStats() ([]TaskNodeStats, error) { ) if result.Error() != nil { - return nil, fmt.Errorf("failed to scan active tasks keys: %w", result.Error()) + return nil, fmt.Errorf("failed to scan active backups keys: %w", result.Error()) } scanResult, err := result.AsScanEntry() @@ -173,18 +170,18 @@ func (r *TaskNodesRegistry) GetNodesStats() ([]TaskNodeStats, error) { } if len(allKeys) == 0 { - return []TaskNodeStats{}, nil + return []BackupNodeStats{}, nil } keyDataMap, err := r.pipelineGetKeys(allKeys) if err != nil { - return nil, fmt.Errorf("failed to pipeline get active tasks keys: %w", err) + return nil, fmt.Errorf("failed to pipeline get active backups keys: %w", err) } var nodeInfoKeys []string nodeIDToStatsKey := make(map[string]string) for key := range keyDataMap { - nodeID := r.extractNodeIDFromKey(key, nodeActiveTasksPrefix, nodeActiveTasksSuffix) + nodeID := r.extractNodeIDFromKey(key, nodeActiveBackupsPrefix, 
nodeActiveBackupsSuffix) nodeIDStr := nodeID.String() infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeIDStr, nodeInfoKeySuffix) nodeInfoKeys = append(nodeInfoKeys, infoKey) @@ -197,14 +194,14 @@ func (r *TaskNodesRegistry) GetNodesStats() ([]TaskNodeStats, error) { } threshold := time.Now().UTC().Add(-deadNodeThreshold) - var stats []TaskNodeStats + var stats []BackupNodeStats for infoKey, nodeData := range nodeInfoMap { // Skip if the info key doesn't exist (nodeData is empty) if len(nodeData) == 0 { continue } - var node TaskNode + var node BackupNode if err := json.Unmarshal(nodeData, &node); err != nil { r.logger.Warn("Failed to unmarshal node data", "key", infoKey, "error", err) continue @@ -223,13 +220,13 @@ func (r *TaskNodesRegistry) GetNodesStats() ([]TaskNodeStats, error) { tasksData := keyDataMap[statsKey] count, err := r.parseIntFromBytes(tasksData) if err != nil { - r.logger.Warn("Failed to parse active tasks count", "key", statsKey, "error", err) + r.logger.Warn("Failed to parse active backups count", "key", statsKey, "error", err) continue } - stat := TaskNodeStats{ - ID: node.ID, - ActiveTasks: int(count), + stat := BackupNodeStats{ + ID: node.ID, + ActiveBackups: int(count), } stats = append(stats, stat) } @@ -237,16 +234,16 @@ func (r *TaskNodesRegistry) GetNodesStats() ([]TaskNodeStats, error) { return stats, nil } -func (r *TaskNodesRegistry) IncrementTasksInProgress(nodeID string) error { +func (r *BackupNodesRegistry) IncrementBackupsInProgress(nodeID uuid.UUID) error { ctx, cancel := context.WithTimeout(context.Background(), r.timeout) defer cancel() - key := fmt.Sprintf("%s%s%s", nodeActiveTasksPrefix, nodeID, nodeActiveTasksSuffix) + key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID.String(), nodeActiveBackupsSuffix) result := r.client.Do(ctx, r.client.B().Incr().Key(key).Build()) if result.Error() != nil { return fmt.Errorf( - "failed to increment tasks in progress for node %s: %w", + "failed to increment backups in 
progress for node %s: %w", nodeID, result.Error(), ) @@ -255,16 +252,16 @@ func (r *TaskNodesRegistry) IncrementTasksInProgress(nodeID string) error { return nil } -func (r *TaskNodesRegistry) DecrementTasksInProgress(nodeID string) error { +func (r *BackupNodesRegistry) DecrementBackupsInProgress(nodeID uuid.UUID) error { ctx, cancel := context.WithTimeout(context.Background(), r.timeout) defer cancel() - key := fmt.Sprintf("%s%s%s", nodeActiveTasksPrefix, nodeID, nodeActiveTasksSuffix) + key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID.String(), nodeActiveBackupsSuffix) result := r.client.Do(ctx, r.client.B().Decr().Key(key).Build()) if result.Error() != nil { return fmt.Errorf( - "failed to decrement tasks in progress for node %s: %w", + "failed to decrement backups in progress for node %s: %w", nodeID, result.Error(), ) @@ -279,13 +276,13 @@ func (r *TaskNodesRegistry) DecrementTasksInProgress(nodeID string) error { setCtx, setCancel := context.WithTimeout(context.Background(), r.timeout) r.client.Do(setCtx, r.client.B().Set().Key(key).Value("0").Build()) setCancel() - r.logger.Warn("Active tasks counter went below 0, reset to 0", "nodeID", nodeID) + r.logger.Warn("Active backups counter went below 0, reset to 0", "nodeID", nodeID) } return nil } -func (r *TaskNodesRegistry) HearthbeatNodeInRegistry(now time.Time, node TaskNode) error { +func (r *BackupNodesRegistry) HearthbeatNodeInRegistry(now time.Time, backupNode BackupNode) error { if now.IsZero() { return fmt.Errorf("cannot register node with zero heartbeat timestamp") } @@ -293,36 +290,36 @@ func (r *TaskNodesRegistry) HearthbeatNodeInRegistry(now time.Time, node TaskNod ctx, cancel := context.WithTimeout(context.Background(), r.timeout) defer cancel() - node.LastHeartbeat = now + backupNode.LastHeartbeat = now - data, err := json.Marshal(node) + data, err := json.Marshal(backupNode) if err != nil { - return fmt.Errorf("failed to marshal node: %w", err) + return fmt.Errorf("failed to marshal 
backup node: %w", err) } - key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node.ID.String(), nodeInfoKeySuffix) + key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, backupNode.ID.String(), nodeInfoKeySuffix) result := r.client.Do( ctx, r.client.B().Set().Key(key).Value(string(data)).Build(), ) if result.Error() != nil { - return fmt.Errorf("failed to register node %s: %w", node.ID, result.Error()) + return fmt.Errorf("failed to register node %s: %w", backupNode.ID, result.Error()) } return nil } -func (r *TaskNodesRegistry) UnregisterNodeFromRegistry(node TaskNode) error { +func (r *BackupNodesRegistry) UnregisterNodeFromRegistry(backupNode BackupNode) error { ctx, cancel := context.WithTimeout(context.Background(), r.timeout) defer cancel() - infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node.ID.String(), nodeInfoKeySuffix) + infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, backupNode.ID.String(), nodeInfoKeySuffix) counterKey := fmt.Sprintf( "%s%s%s", - nodeActiveTasksPrefix, - node.ID.String(), - nodeActiveTasksSuffix, + nodeActiveBackupsPrefix, + backupNode.ID.String(), + nodeActiveBackupsSuffix, ) result := r.client.Do( @@ -331,49 +328,49 @@ func (r *TaskNodesRegistry) UnregisterNodeFromRegistry(node TaskNode) error { ) if result.Error() != nil { - return fmt.Errorf("failed to unregister node %s: %w", node.ID, result.Error()) + return fmt.Errorf("failed to unregister node %s: %w", backupNode.ID, result.Error()) } - r.logger.Info("Unregistered node from registry", "nodeID", node.ID) + r.logger.Info("Unregistered node from registry", "nodeID", backupNode.ID) return nil } -func (r *TaskNodesRegistry) AssignTaskToNode( - targetNodeID string, - taskID uuid.UUID, +func (r *BackupNodesRegistry) AssignBackupToNode( + targetNodeID uuid.UUID, + backupID uuid.UUID, isCallNotifier bool, ) error { ctx := context.Background() - message := TaskSubmitMessage{ + message := BackupSubmitMessage{ NodeID: targetNodeID, - TaskID: taskID.String(), + BackupID: backupID, 
IsCallNotifier: isCallNotifier, } messageJSON, err := json.Marshal(message) if err != nil { - return fmt.Errorf("failed to marshal task submit message: %w", err) + return fmt.Errorf("failed to marshal backup submit message: %w", err) } - err = r.pubsubTasks.Publish(ctx, taskSubmitChannel, string(messageJSON)) + err = r.pubsubBackups.Publish(ctx, backupSubmitChannel, string(messageJSON)) if err != nil { - return fmt.Errorf("failed to publish task submit message: %w", err) + return fmt.Errorf("failed to publish backup submit message: %w", err) } return nil } -func (r *TaskNodesRegistry) SubscribeNodeForTasksAssignment( - nodeID string, - handler func(taskID uuid.UUID, isCallNotifier bool), +func (r *BackupNodesRegistry) SubscribeNodeForBackupsAssignment( + nodeID uuid.UUID, + handler func(backupID uuid.UUID, isCallNotifier bool), ) error { ctx := context.Background() wrappedHandler := func(message string) { - var msg TaskSubmitMessage + var msg BackupSubmitMessage if err := json.Unmarshal([]byte(message), &msg); err != nil { - r.logger.Warn("Failed to unmarshal task submit message", "error", err) + r.logger.Warn("Failed to unmarshal backup submit message", "error", err) return } @@ -381,108 +378,84 @@ func (r *TaskNodesRegistry) SubscribeNodeForTasksAssignment( return } - taskID, err := uuid.Parse(msg.TaskID) - if err != nil { - r.logger.Warn( - "Failed to parse task ID from message", - "taskId", - msg.TaskID, - "error", - err, - ) - return - } - - handler(taskID, msg.IsCallNotifier) + handler(msg.BackupID, msg.IsCallNotifier) } - err := r.pubsubTasks.Subscribe(ctx, taskSubmitChannel, wrappedHandler) + err := r.pubsubBackups.Subscribe(ctx, backupSubmitChannel, wrappedHandler) if err != nil { - return fmt.Errorf("failed to subscribe to task submit channel: %w", err) + return fmt.Errorf("failed to subscribe to backup submit channel: %w", err) } - r.logger.Info("Subscribed to task submit channel", "nodeID", nodeID) + r.logger.Info("Subscribed to backup submit channel", 
"nodeID", nodeID) return nil } -func (r *TaskNodesRegistry) UnsubscribeNodeForTasksAssignments() error { - err := r.pubsubTasks.Close() +func (r *BackupNodesRegistry) UnsubscribeNodeForBackupsAssignments() error { + err := r.pubsubBackups.Close() if err != nil { - return fmt.Errorf("failed to unsubscribe from task submit channel: %w", err) + return fmt.Errorf("failed to unsubscribe from backup submit channel: %w", err) } - r.logger.Info("Unsubscribed from task submit channel") + r.logger.Info("Unsubscribed from backup submit channel") return nil } -func (r *TaskNodesRegistry) PublishTaskCompletion(nodeID string, taskID uuid.UUID) error { +func (r *BackupNodesRegistry) PublishBackupCompletion(nodeID uuid.UUID, backupID uuid.UUID) error { ctx := context.Background() - message := TaskCompletionMessage{ - NodeID: nodeID, - TaskID: taskID.String(), + message := BackupCompletionMessage{ + NodeID: nodeID, + BackupID: backupID, } messageJSON, err := json.Marshal(message) if err != nil { - return fmt.Errorf("failed to marshal task completion message: %w", err) + return fmt.Errorf("failed to marshal backup completion message: %w", err) } - err = r.pubsubCompletions.Publish(ctx, taskCompletionChannel, string(messageJSON)) + err = r.pubsubCompletions.Publish(ctx, backupCompletionChannel, string(messageJSON)) if err != nil { - return fmt.Errorf("failed to publish task completion message: %w", err) + return fmt.Errorf("failed to publish backup completion message: %w", err) } return nil } -func (r *TaskNodesRegistry) SubscribeForTasksCompletions( - handler func(nodeID string, taskID uuid.UUID), +func (r *BackupNodesRegistry) SubscribeForBackupsCompletions( + handler func(nodeID uuid.UUID, backupID uuid.UUID), ) error { ctx := context.Background() wrappedHandler := func(message string) { - var msg TaskCompletionMessage + var msg BackupCompletionMessage if err := json.Unmarshal([]byte(message), &msg); err != nil { - r.logger.Warn("Failed to unmarshal task completion message", 
"error", err) + r.logger.Warn("Failed to unmarshal backup completion message", "error", err) return } - taskID, err := uuid.Parse(msg.TaskID) - if err != nil { - r.logger.Warn( - "Failed to parse task ID from completion message", - "taskId", - msg.TaskID, - "error", - err, - ) - return - } - - handler(msg.NodeID, taskID) + handler(msg.NodeID, msg.BackupID) } - err := r.pubsubCompletions.Subscribe(ctx, taskCompletionChannel, wrappedHandler) + err := r.pubsubCompletions.Subscribe(ctx, backupCompletionChannel, wrappedHandler) if err != nil { - return fmt.Errorf("failed to subscribe to task completion channel: %w", err) + return fmt.Errorf("failed to subscribe to backup completion channel: %w", err) } - r.logger.Info("Subscribed to task completion channel") + r.logger.Info("Subscribed to backup completion channel") return nil } -func (r *TaskNodesRegistry) UnsubscribeForTasksCompletions() error { +func (r *BackupNodesRegistry) UnsubscribeForBackupsCompletions() error { err := r.pubsubCompletions.Close() if err != nil { - return fmt.Errorf("failed to unsubscribe from task completion channel: %w", err) + return fmt.Errorf("failed to unsubscribe from backup completion channel: %w", err) } - r.logger.Info("Unsubscribed from task completion channel") + r.logger.Info("Unsubscribed from backup completion channel") return nil } -func (r *TaskNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uuid.UUID { +func (r *BackupNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uuid.UUID { nodeIDStr := strings.TrimPrefix(key, prefix) nodeIDStr = strings.TrimSuffix(nodeIDStr, suffix) @@ -495,7 +468,7 @@ func (r *TaskNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uui return nodeID } -func (r *TaskNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, error) { +func (r *BackupNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, error) { if len(keys) == 0 { return make(map[string][]byte), nil } @@ -529,7 +502,7 @@ func (r 
*TaskNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, e return keyDataMap, nil } -func (r *TaskNodesRegistry) parseIntFromBytes(data []byte) (int64, error) { +func (r *BackupNodesRegistry) parseIntFromBytes(data []byte) (int64, error) { str := string(data) var count int64 _, err := fmt.Sscanf(str, "%d", &count) @@ -539,7 +512,7 @@ func (r *TaskNodesRegistry) parseIntFromBytes(data []byte) (int64, error) { return count, nil } -func (r *TaskNodesRegistry) cleanupDeadNodes() error { +func (r *BackupNodesRegistry) cleanupDeadNodes() error { ctx, cancel := context.WithTimeout(context.Background(), r.timeout) defer cancel() @@ -583,13 +556,12 @@ func (r *TaskNodesRegistry) cleanupDeadNodes() error { var deadNodeKeys []string for key, data := range keyDataMap { - // Skip if the key doesn't exist (data is empty) if len(data) == 0 { continue } - var node TaskNode + var node BackupNode if err := json.Unmarshal(data, &node); err != nil { r.logger.Warn("Failed to unmarshal node data during cleanup", "key", key, "error", err) continue @@ -603,7 +575,12 @@ func (r *TaskNodesRegistry) cleanupDeadNodes() error { if node.LastHeartbeat.Before(threshold) { nodeID := node.ID.String() infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeID, nodeInfoKeySuffix) - statsKey := fmt.Sprintf("%s%s%s", nodeActiveTasksPrefix, nodeID, nodeActiveTasksSuffix) + statsKey := fmt.Sprintf( + "%s%s%s", + nodeActiveBackupsPrefix, + nodeID, + nodeActiveBackupsSuffix, + ) deadNodeKeys = append(deadNodeKeys, infoKey, statsKey) r.logger.Info( diff --git a/backend/internal/features/backups/backups/backuping/registry_test.go b/backend/internal/features/backups/backups/backuping/registry_test.go new file mode 100644 index 0000000..a2f382b --- /dev/null +++ b/backend/internal/features/backups/backups/backuping/registry_test.go @@ -0,0 +1,1130 @@ +package backuping + +import ( + "context" + "encoding/json" + "fmt" + "testing" + "time" + + cache_utils "databasus-backend/internal/util/cache" + 
"databasus-backend/internal/util/logger" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func Test_HearthbeatNodeInRegistry_RegistersNodeWithTTL(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node := createTestBackupNode() + defer cleanupTestNode(registry, node) + + err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node) + assert.NoError(t, err) + + nodes, err := registry.GetAvailableNodes() + assert.NoError(t, err) + assert.Len(t, nodes, 1) + assert.Equal(t, node.ID, nodes[0].ID) + assert.Equal(t, node.ThroughputMBs, nodes[0].ThroughputMBs) +} + +func Test_UnregisterNodeFromRegistry_RemovesNodeAndCounter(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node := createTestBackupNode() + + err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node) + assert.NoError(t, err) + + err = registry.IncrementBackupsInProgress(node.ID) + assert.NoError(t, err) + + err = registry.UnregisterNodeFromRegistry(node) + assert.NoError(t, err) + + nodes, err := registry.GetAvailableNodes() + assert.NoError(t, err) + assert.Empty(t, nodes) + + stats, err := registry.GetBackupNodesStats() + assert.NoError(t, err) + assert.Empty(t, stats) +} + +func Test_GetAvailableNodes_ReturnsAllRegisteredNodes(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node1 := createTestBackupNode() + node2 := createTestBackupNode() + node3 := createTestBackupNode() + defer cleanupTestNode(registry, node1) + defer cleanupTestNode(registry, node2) + defer cleanupTestNode(registry, node3) + + err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1) + assert.NoError(t, err) + err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2) + assert.NoError(t, err) + err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3) + assert.NoError(t, err) + + nodes, err := registry.GetAvailableNodes() + assert.NoError(t, err) + assert.Len(t, nodes, 3) + + nodeIDs 
:= make(map[uuid.UUID]bool) + for _, node := range nodes { + nodeIDs[node.ID] = true + } + assert.True(t, nodeIDs[node1.ID]) + assert.True(t, nodeIDs[node2.ID]) + assert.True(t, nodeIDs[node3.ID]) +} + +func Test_GetAvailableNodes_WhenNoNodesExist_ReturnsEmptySlice(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + + nodes, err := registry.GetAvailableNodes() + assert.NoError(t, err) + assert.NotNil(t, nodes) + assert.Empty(t, nodes) +} + +func Test_IncrementBackupsInProgress_IncrementsCounter(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node := createTestBackupNode() + defer cleanupTestNode(registry, node) + + err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node) + assert.NoError(t, err) + + err = registry.IncrementBackupsInProgress(node.ID) + assert.NoError(t, err) + + stats, err := registry.GetBackupNodesStats() + assert.NoError(t, err) + assert.Len(t, stats, 1) + assert.Equal(t, node.ID, stats[0].ID) + assert.Equal(t, 1, stats[0].ActiveBackups) + + err = registry.IncrementBackupsInProgress(node.ID) + assert.NoError(t, err) + + stats, err = registry.GetBackupNodesStats() + assert.NoError(t, err) + assert.Len(t, stats, 1) + assert.Equal(t, 2, stats[0].ActiveBackups) +} + +func Test_DecrementBackupsInProgress_DecrementsCounter(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node := createTestBackupNode() + defer cleanupTestNode(registry, node) + + err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node) + assert.NoError(t, err) + + err = registry.IncrementBackupsInProgress(node.ID) + assert.NoError(t, err) + err = registry.IncrementBackupsInProgress(node.ID) + assert.NoError(t, err) + err = registry.IncrementBackupsInProgress(node.ID) + assert.NoError(t, err) + + stats, err := registry.GetBackupNodesStats() + assert.NoError(t, err) + assert.Equal(t, 3, stats[0].ActiveBackups) + + err = registry.DecrementBackupsInProgress(node.ID) + 
assert.NoError(t, err) + + stats, err = registry.GetBackupNodesStats() + assert.NoError(t, err) + assert.Equal(t, 2, stats[0].ActiveBackups) + + err = registry.DecrementBackupsInProgress(node.ID) + assert.NoError(t, err) + + stats, err = registry.GetBackupNodesStats() + assert.NoError(t, err) + assert.Equal(t, 1, stats[0].ActiveBackups) +} + +func Test_DecrementBackupsInProgress_WhenNegative_ResetsToZero(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node := createTestBackupNode() + defer cleanupTestNode(registry, node) + + err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node) + assert.NoError(t, err) + + err = registry.DecrementBackupsInProgress(node.ID) + assert.NoError(t, err) + + stats, err := registry.GetBackupNodesStats() + assert.NoError(t, err) + assert.Len(t, stats, 1) + assert.Equal(t, 0, stats[0].ActiveBackups) +} + +func Test_GetBackupNodesStats_ReturnsStatsForAllNodes(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node1 := createTestBackupNode() + node2 := createTestBackupNode() + node3 := createTestBackupNode() + defer cleanupTestNode(registry, node1) + defer cleanupTestNode(registry, node2) + defer cleanupTestNode(registry, node3) + + err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1) + assert.NoError(t, err) + err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2) + assert.NoError(t, err) + err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3) + assert.NoError(t, err) + + err = registry.IncrementBackupsInProgress(node1.ID) + assert.NoError(t, err) + + err = registry.IncrementBackupsInProgress(node2.ID) + assert.NoError(t, err) + err = registry.IncrementBackupsInProgress(node2.ID) + assert.NoError(t, err) + + err = registry.IncrementBackupsInProgress(node3.ID) + assert.NoError(t, err) + err = registry.IncrementBackupsInProgress(node3.ID) + assert.NoError(t, err) + err = registry.IncrementBackupsInProgress(node3.ID) + 
assert.NoError(t, err) + + stats, err := registry.GetBackupNodesStats() + assert.NoError(t, err) + assert.Len(t, stats, 3) + + statsMap := make(map[uuid.UUID]int) + for _, stat := range stats { + statsMap[stat.ID] = stat.ActiveBackups + } + + assert.Equal(t, 1, statsMap[node1.ID]) + assert.Equal(t, 2, statsMap[node2.ID]) + assert.Equal(t, 3, statsMap[node3.ID]) +} + +func Test_GetBackupNodesStats_WhenNoStats_ReturnsEmptySlice(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + + stats, err := registry.GetBackupNodesStats() + assert.NoError(t, err) + assert.NotNil(t, stats) + assert.Empty(t, stats) +} + +func Test_MultipleNodes_RegisteredAndQueriedCorrectly(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node1 := createTestBackupNode() + node1.ThroughputMBs = 50 + node2 := createTestBackupNode() + node2.ThroughputMBs = 100 + node3 := createTestBackupNode() + node3.ThroughputMBs = 150 + defer cleanupTestNode(registry, node1) + defer cleanupTestNode(registry, node2) + defer cleanupTestNode(registry, node3) + + err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1) + assert.NoError(t, err) + err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2) + assert.NoError(t, err) + err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3) + assert.NoError(t, err) + + nodes, err := registry.GetAvailableNodes() + assert.NoError(t, err) + assert.Len(t, nodes, 3) + + nodeMap := make(map[uuid.UUID]BackupNode) + for _, node := range nodes { + nodeMap[node.ID] = node + } + + assert.Equal(t, 50, nodeMap[node1.ID].ThroughputMBs) + assert.Equal(t, 100, nodeMap[node2.ID].ThroughputMBs) + assert.Equal(t, 150, nodeMap[node3.ID].ThroughputMBs) +} + +func Test_BackupCounters_TrackedSeparatelyPerNode(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node1 := createTestBackupNode() + node2 := createTestBackupNode() + defer cleanupTestNode(registry, node1) + defer 
cleanupTestNode(registry, node2) + + err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1) + assert.NoError(t, err) + err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2) + assert.NoError(t, err) + + err = registry.IncrementBackupsInProgress(node1.ID) + assert.NoError(t, err) + err = registry.IncrementBackupsInProgress(node1.ID) + assert.NoError(t, err) + + err = registry.IncrementBackupsInProgress(node2.ID) + assert.NoError(t, err) + + stats, err := registry.GetBackupNodesStats() + assert.NoError(t, err) + assert.Len(t, stats, 2) + + statsMap := make(map[uuid.UUID]int) + for _, stat := range stats { + statsMap[stat.ID] = stat.ActiveBackups + } + + assert.Equal(t, 2, statsMap[node1.ID]) + assert.Equal(t, 1, statsMap[node2.ID]) + + err = registry.DecrementBackupsInProgress(node1.ID) + assert.NoError(t, err) + + stats, err = registry.GetBackupNodesStats() + assert.NoError(t, err) + + statsMap = make(map[uuid.UUID]int) + for _, stat := range stats { + statsMap[stat.ID] = stat.ActiveBackups + } + + assert.Equal(t, 1, statsMap[node1.ID]) + assert.Equal(t, 1, statsMap[node2.ID]) +} + +func Test_GetAvailableNodes_SkipsInvalidJsonData(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node := createTestBackupNode() + defer cleanupTestNode(registry, node) + + err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node) + assert.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), registry.timeout) + defer cancel() + + invalidKey := nodeInfoKeyPrefix + uuid.New().String() + nodeInfoKeySuffix + registry.client.Do( + ctx, + registry.client.B().Set().Key(invalidKey).Value("invalid json data").Build(), + ) + defer func() { + cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), registry.timeout) + defer cleanupCancel() + registry.client.Do(cleanupCtx, registry.client.B().Del().Key(invalidKey).Build()) + }() + + nodes, err := registry.GetAvailableNodes() + assert.NoError(t, err) + 
assert.Len(t, nodes, 1) + assert.Equal(t, node.ID, nodes[0].ID) +} + +func Test_PipelineGetKeys_HandlesEmptyKeysList(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + + keyDataMap, err := registry.pipelineGetKeys([]string{}) + assert.NoError(t, err) + assert.NotNil(t, keyDataMap) + assert.Empty(t, keyDataMap) +} + +func Test_HearthbeatNodeInRegistry_UpdatesLastHeartbeat(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node := createTestBackupNode() + originalHeartbeat := node.LastHeartbeat + defer cleanupTestNode(registry, node) + + time.Sleep(10 * time.Millisecond) + + err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node) + assert.NoError(t, err) + + nodes, err := registry.GetAvailableNodes() + assert.NoError(t, err) + assert.Len(t, nodes, 1) + assert.True(t, nodes[0].LastHeartbeat.After(originalHeartbeat)) +} + +func Test_HearthbeatNodeInRegistry_RejectsZeroTimestamp(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node := createTestBackupNode() + + err := registry.HearthbeatNodeInRegistry(time.Time{}, node) + assert.Error(t, err) + assert.Contains(t, err.Error(), "zero heartbeat timestamp") + + nodes, err := registry.GetAvailableNodes() + assert.NoError(t, err) + assert.Len(t, nodes, 0) +} + +func Test_GetAvailableNodes_ExcludesStaleNodesFromCache(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node1 := createTestBackupNode() + node2 := createTestBackupNode() + node3 := createTestBackupNode() + defer cleanupTestNode(registry, node1) + defer cleanupTestNode(registry, node2) + defer cleanupTestNode(registry, node3) + + err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1) + assert.NoError(t, err) + err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2) + assert.NoError(t, err) + err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3) + assert.NoError(t, err) + + ctx, cancel := 
context.WithTimeout(context.Background(), registry.timeout) + defer cancel() + + key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix) + result := registry.client.Do(ctx, registry.client.B().Get().Key(key).Build()) + assert.NoError(t, result.Error()) + + data, err := result.AsBytes() + assert.NoError(t, err) + + var node BackupNode + err = json.Unmarshal(data, &node) + assert.NoError(t, err) + + node.LastHeartbeat = time.Now().UTC().Add(-3 * time.Minute) + modifiedData, err := json.Marshal(node) + assert.NoError(t, err) + + setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout) + defer setCancel() + setResult := registry.client.Do( + setCtx, + registry.client.B().Set().Key(key).Value(string(modifiedData)).Build(), + ) + assert.NoError(t, setResult.Error()) + + nodes, err := registry.GetAvailableNodes() + assert.NoError(t, err) + assert.Len(t, nodes, 2) + + nodeIDs := make(map[uuid.UUID]bool) + for _, n := range nodes { + nodeIDs[n.ID] = true + } + assert.True(t, nodeIDs[node1.ID]) + assert.False(t, nodeIDs[node2.ID]) + assert.True(t, nodeIDs[node3.ID]) +} + +func Test_GetBackupNodesStats_ExcludesStaleNodesFromCache(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node1 := createTestBackupNode() + node2 := createTestBackupNode() + node3 := createTestBackupNode() + defer cleanupTestNode(registry, node1) + defer cleanupTestNode(registry, node2) + defer cleanupTestNode(registry, node3) + + err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1) + assert.NoError(t, err) + err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2) + assert.NoError(t, err) + err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3) + assert.NoError(t, err) + + err = registry.IncrementBackupsInProgress(node1.ID) + assert.NoError(t, err) + err = registry.IncrementBackupsInProgress(node2.ID) + assert.NoError(t, err) + err = registry.IncrementBackupsInProgress(node3.ID) + 
assert.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), registry.timeout) + defer cancel() + + key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix) + result := registry.client.Do(ctx, registry.client.B().Get().Key(key).Build()) + assert.NoError(t, result.Error()) + + data, err := result.AsBytes() + assert.NoError(t, err) + + var node BackupNode + err = json.Unmarshal(data, &node) + assert.NoError(t, err) + + node.LastHeartbeat = time.Now().UTC().Add(-3 * time.Minute) + modifiedData, err := json.Marshal(node) + assert.NoError(t, err) + + setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout) + defer setCancel() + setResult := registry.client.Do( + setCtx, + registry.client.B().Set().Key(key).Value(string(modifiedData)).Build(), + ) + assert.NoError(t, setResult.Error()) + + stats, err := registry.GetBackupNodesStats() + assert.NoError(t, err) + assert.Len(t, stats, 2) + + statsMap := make(map[uuid.UUID]int) + for _, stat := range stats { + statsMap[stat.ID] = stat.ActiveBackups + } + + assert.Equal(t, 1, statsMap[node1.ID]) + _, hasNode2 := statsMap[node2.ID] + assert.False(t, hasNode2) + assert.Equal(t, 1, statsMap[node3.ID]) +} + +func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) { + cache_utils.ClearAllCache() + + registry := createTestRegistry() + node1 := createTestBackupNode() + node2 := createTestBackupNode() + defer cleanupTestNode(registry, node1) + defer cleanupTestNode(registry, node2) + + err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1) + assert.NoError(t, err) + err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2) + assert.NoError(t, err) + + err = registry.IncrementBackupsInProgress(node1.ID) + assert.NoError(t, err) + err = registry.IncrementBackupsInProgress(node2.ID) + assert.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), registry.timeout) + defer cancel() + + key := fmt.Sprintf("%s%s%s", 
nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix) + result := registry.client.Do(ctx, registry.client.B().Get().Key(key).Build()) + assert.NoError(t, result.Error()) + + data, err := result.AsBytes() + assert.NoError(t, err) + + var node BackupNode + err = json.Unmarshal(data, &node) + assert.NoError(t, err) + + node.LastHeartbeat = time.Now().UTC().Add(-3 * time.Minute) + modifiedData, err := json.Marshal(node) + assert.NoError(t, err) + + setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout) + defer setCancel() + setResult := registry.client.Do( + setCtx, + registry.client.B().Set().Key(key).Value(string(modifiedData)).Build(), + ) + assert.NoError(t, setResult.Error()) + + err = registry.cleanupDeadNodes() + assert.NoError(t, err) + + checkCtx, checkCancel := context.WithTimeout(context.Background(), registry.timeout) + defer checkCancel() + + infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix) + infoResult := registry.client.Do(checkCtx, registry.client.B().Get().Key(infoKey).Build()) + assert.Error(t, infoResult.Error()) + + counterKey := fmt.Sprintf( + "%s%s%s", + nodeActiveBackupsPrefix, + node2.ID.String(), + nodeActiveBackupsSuffix, + ) + counterCtx, counterCancel := context.WithTimeout(context.Background(), registry.timeout) + defer counterCancel() + counterResult := registry.client.Do( + counterCtx, + registry.client.B().Get().Key(counterKey).Build(), + ) + assert.Error(t, counterResult.Error()) + + activeInfoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node1.ID.String(), nodeInfoKeySuffix) + activeCtx, activeCancel := context.WithTimeout(context.Background(), registry.timeout) + defer activeCancel() + activeResult := registry.client.Do( + activeCtx, + registry.client.B().Get().Key(activeInfoKey).Build(), + ) + assert.NoError(t, activeResult.Error()) + + nodes, err := registry.GetAvailableNodes() + assert.NoError(t, err) + assert.Len(t, nodes, 1) + assert.Equal(t, node1.ID, 
nodes[0].ID) + + stats, err := registry.GetBackupNodesStats() + assert.NoError(t, err) + assert.Len(t, stats, 1) + assert.Equal(t, node1.ID, stats[0].ID) +} + +func createTestRegistry() *BackupNodesRegistry { + return &BackupNodesRegistry{ + cache_utils.GetValkeyClient(), + logger.GetLogger(), + cache_utils.DefaultCacheTimeout, + cache_utils.NewPubSubManager(), + cache_utils.NewPubSubManager(), + } +} + +func createTestBackupNode() BackupNode { + return BackupNode{ + ID: uuid.New(), + ThroughputMBs: 100, + LastHeartbeat: time.Now().UTC(), + } +} + +func cleanupTestNode(registry *BackupNodesRegistry, node BackupNode) { + registry.UnregisterNodeFromRegistry(node) +} + +func Test_AssignBackupTonode_PublishesJsonMessageToChannel(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node := createTestBackupNode() + backupID := uuid.New() + + err := registry.AssignBackupToNode(node.ID, backupID, true) + assert.NoError(t, err) +} + +func Test_SubscribeNodeForBackupsAssignment_ReceivesSubmittedBackupsForMatchingNode(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node := createTestBackupNode() + backupID := uuid.New() + defer registry.UnsubscribeNodeForBackupsAssignments() + + receivedBackupID := make(chan uuid.UUID, 1) + handler := func(id uuid.UUID, isCallNotifier bool) { + receivedBackupID <- id + } + + err := registry.SubscribeNodeForBackupsAssignment(node.ID, handler) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + err = registry.AssignBackupToNode(node.ID, backupID, true) + assert.NoError(t, err) + + select { + case received := <-receivedBackupID: + assert.Equal(t, backupID, received) + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for backup message") + } +} + +func Test_SubscribeNodeForBackupsAssignment_FiltersOutBackupsForDifferentNode(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node1 := createTestBackupNode() + node2 := 
createTestBackupNode() + backupID := uuid.New() + defer registry.UnsubscribeNodeForBackupsAssignments() + + receivedBackupID := make(chan uuid.UUID, 1) + handler := func(id uuid.UUID, isCallNotifier bool) { + receivedBackupID <- id + } + + err := registry.SubscribeNodeForBackupsAssignment(node1.ID, handler) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + err = registry.AssignBackupToNode(node2.ID, backupID, false) + assert.NoError(t, err) + + select { + case <-receivedBackupID: + t.Fatal("Should not receive backup for different node") + case <-time.After(500 * time.Millisecond): + } +} + +func Test_SubscribeNodeForBackupsAssignment_ParsesJsonAndBackupIdCorrectly(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node := createTestBackupNode() + backupID1 := uuid.New() + backupID2 := uuid.New() + defer registry.UnsubscribeNodeForBackupsAssignments() + + receivedBackups := make(chan uuid.UUID, 2) + handler := func(id uuid.UUID, isCallNotifier bool) { + receivedBackups <- id + } + + err := registry.SubscribeNodeForBackupsAssignment(node.ID, handler) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + err = registry.AssignBackupToNode(node.ID, backupID1, true) + assert.NoError(t, err) + + err = registry.AssignBackupToNode(node.ID, backupID2, false) + assert.NoError(t, err) + + received1 := <-receivedBackups + received2 := <-receivedBackups + + receivedIDs := []uuid.UUID{received1, received2} + assert.Contains(t, receivedIDs, backupID1) + assert.Contains(t, receivedIDs, backupID2) +} + +func Test_SubscribeNodeForBackupsAssignment_HandlesInvalidJson(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node := createTestBackupNode() + defer registry.UnsubscribeNodeForBackupsAssignments() + + receivedBackupID := make(chan uuid.UUID, 1) + handler := func(id uuid.UUID, isCallNotifier bool) { + receivedBackupID <- id + } + + err := 
registry.SubscribeNodeForBackupsAssignment(node.ID, handler) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + ctx := context.Background() + err = registry.pubsubBackups.Publish(ctx, "backup:submit", "invalid json") + assert.NoError(t, err) + + select { + case <-receivedBackupID: + t.Fatal("Should not receive backup for invalid JSON") + case <-time.After(500 * time.Millisecond): + } +} + +func Test_UnsubscribeNodeForBackupsAssignments_StopsReceivingMessages(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node := createTestBackupNode() + backupID1 := uuid.New() + backupID2 := uuid.New() + + receivedBackupID := make(chan uuid.UUID, 2) + handler := func(id uuid.UUID, isCallNotifier bool) { + receivedBackupID <- id + } + + err := registry.SubscribeNodeForBackupsAssignment(node.ID, handler) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + err = registry.AssignBackupToNode(node.ID, backupID1, true) + assert.NoError(t, err) + + received := <-receivedBackupID + assert.Equal(t, backupID1, received) + + err = registry.UnsubscribeNodeForBackupsAssignments() + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + err = registry.AssignBackupToNode(node.ID, backupID2, false) + assert.NoError(t, err) + + select { + case <-receivedBackupID: + t.Fatal("Should not receive backup after unsubscribe") + case <-time.After(500 * time.Millisecond): + } +} + +func Test_SubscribeNodeForBackupsAssignment_WhenAlreadySubscribed_ReturnsError(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node := createTestBackupNode() + defer registry.UnsubscribeNodeForBackupsAssignments() + + handler := func(id uuid.UUID, isCallNotifier bool) {} + + err := registry.SubscribeNodeForBackupsAssignment(node.ID, handler) + assert.NoError(t, err) + + err = registry.SubscribeNodeForBackupsAssignment(node.ID, handler) + assert.Error(t, err) + assert.Contains(t, err.Error(), "already subscribed") +} + 
+func Test_MultipleNodes_EachReceivesOnlyTheirBackups(t *testing.T) { + cache_utils.ClearAllCache() + registry1 := createTestRegistry() + registry2 := createTestRegistry() + registry3 := createTestRegistry() + + node1 := createTestBackupNode() + node2 := createTestBackupNode() + node3 := createTestBackupNode() + + backupID1 := uuid.New() + backupID2 := uuid.New() + backupID3 := uuid.New() + + defer registry1.UnsubscribeNodeForBackupsAssignments() + defer registry2.UnsubscribeNodeForBackupsAssignments() + defer registry3.UnsubscribeNodeForBackupsAssignments() + + receivedBackups1 := make(chan uuid.UUID, 3) + receivedBackups2 := make(chan uuid.UUID, 3) + receivedBackups3 := make(chan uuid.UUID, 3) + + handler1 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups1 <- id } + handler2 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups2 <- id } + handler3 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups3 <- id } + + err := registry1.SubscribeNodeForBackupsAssignment(node1.ID, handler1) + assert.NoError(t, err) + + err = registry2.SubscribeNodeForBackupsAssignment(node2.ID, handler2) + assert.NoError(t, err) + + err = registry3.SubscribeNodeForBackupsAssignment(node3.ID, handler3) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + submitRegistry := createTestRegistry() + err = submitRegistry.AssignBackupToNode(node1.ID, backupID1, true) + assert.NoError(t, err) + + err = submitRegistry.AssignBackupToNode(node2.ID, backupID2, false) + assert.NoError(t, err) + + err = submitRegistry.AssignBackupToNode(node3.ID, backupID3, true) + assert.NoError(t, err) + + select { + case received := <-receivedBackups1: + assert.Equal(t, backupID1, received) + case <-time.After(2 * time.Second): + t.Fatal("Node 1 timeout waiting for backup message") + } + + select { + case received := <-receivedBackups2: + assert.Equal(t, backupID2, received) + case <-time.After(2 * time.Second): + t.Fatal("Node 2 timeout waiting for backup message") + } + + 
select { + case received := <-receivedBackups3: + assert.Equal(t, backupID3, received) + case <-time.After(2 * time.Second): + t.Fatal("Node 3 timeout waiting for backup message") + } + + select { + case <-receivedBackups1: + t.Fatal("Node 1 should not receive additional backups") + case <-time.After(300 * time.Millisecond): + } + + select { + case <-receivedBackups2: + t.Fatal("Node 2 should not receive additional backups") + case <-time.After(300 * time.Millisecond): + } + + select { + case <-receivedBackups3: + t.Fatal("Node 3 should not receive additional backups") + case <-time.After(300 * time.Millisecond): + } +} + +func Test_PublishBackupCompletion_PublishesMessageToChannel(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node := createTestBackupNode() + backupID := uuid.New() + + err := registry.PublishBackupCompletion(node.ID, backupID) + assert.NoError(t, err) +} + +func Test_SubscribeForBackupsCompletions_ReceivesCompletedBackups(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node := createTestBackupNode() + backupID := uuid.New() + defer registry.UnsubscribeForBackupsCompletions() + + receivedBackupID := make(chan uuid.UUID, 1) + receivedNodeID := make(chan uuid.UUID, 1) + handler := func(nodeID uuid.UUID, backupID uuid.UUID) { + receivedNodeID <- nodeID + receivedBackupID <- backupID + } + + err := registry.SubscribeForBackupsCompletions(handler) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + err = registry.PublishBackupCompletion(node.ID, backupID) + assert.NoError(t, err) + + select { + case receivedNode := <-receivedNodeID: + assert.Equal(t, node.ID, receivedNode) + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for node ID") + } + + select { + case received := <-receivedBackupID: + assert.Equal(t, backupID, received) + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for backup completion message") + } +} + +func 
Test_SubscribeForBackupsCompletions_ParsesJsonCorrectly(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node := createTestBackupNode() + backupID1 := uuid.New() + backupID2 := uuid.New() + defer registry.UnsubscribeForBackupsCompletions() + + receivedBackups := make(chan uuid.UUID, 2) + handler := func(nodeID uuid.UUID, backupID uuid.UUID) { + receivedBackups <- backupID + } + + err := registry.SubscribeForBackupsCompletions(handler) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + err = registry.PublishBackupCompletion(node.ID, backupID1) + assert.NoError(t, err) + + err = registry.PublishBackupCompletion(node.ID, backupID2) + assert.NoError(t, err) + + received1 := <-receivedBackups + received2 := <-receivedBackups + + receivedIDs := []uuid.UUID{received1, received2} + assert.Contains(t, receivedIDs, backupID1) + assert.Contains(t, receivedIDs, backupID2) +} + +func Test_SubscribeForBackupsCompletions_HandlesInvalidJson(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + defer registry.UnsubscribeForBackupsCompletions() + + receivedBackupID := make(chan uuid.UUID, 1) + handler := func(nodeID uuid.UUID, backupID uuid.UUID) { + receivedBackupID <- backupID + } + + err := registry.SubscribeForBackupsCompletions(handler) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + ctx := context.Background() + err = registry.pubsubCompletions.Publish(ctx, "backup:completion", "invalid json") + assert.NoError(t, err) + + select { + case <-receivedBackupID: + t.Fatal("Should not receive backup for invalid JSON") + case <-time.After(500 * time.Millisecond): + } +} + +func Test_UnsubscribeForBackupsCompletions_StopsReceivingMessages(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node := createTestBackupNode() + backupID1 := uuid.New() + backupID2 := uuid.New() + + receivedBackupID := make(chan uuid.UUID, 2) + handler := func(nodeID uuid.UUID, 
backupID uuid.UUID) { + receivedBackupID <- backupID + } + + err := registry.SubscribeForBackupsCompletions(handler) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + err = registry.PublishBackupCompletion(node.ID, backupID1) + assert.NoError(t, err) + + received := <-receivedBackupID + assert.Equal(t, backupID1, received) + + err = registry.UnsubscribeForBackupsCompletions() + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + err = registry.PublishBackupCompletion(node.ID, backupID2) + assert.NoError(t, err) + + select { + case <-receivedBackupID: + t.Fatal("Should not receive backup after unsubscribe") + case <-time.After(500 * time.Millisecond): + } +} + +func Test_SubscribeForBackupsCompletions_WhenAlreadySubscribed_ReturnsError(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + defer registry.UnsubscribeForBackupsCompletions() + + handler := func(nodeID uuid.UUID, backupID uuid.UUID) {} + + err := registry.SubscribeForBackupsCompletions(handler) + assert.NoError(t, err) + + err = registry.SubscribeForBackupsCompletions(handler) + assert.Error(t, err) + assert.Contains(t, err.Error(), "already subscribed") +} + +func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) { + cache_utils.ClearAllCache() + registry1 := createTestRegistry() + registry2 := createTestRegistry() + registry3 := createTestRegistry() + + node1 := createTestBackupNode() + node2 := createTestBackupNode() + node3 := createTestBackupNode() + + backupID1 := uuid.New() + backupID2 := uuid.New() + backupID3 := uuid.New() + + defer registry1.UnsubscribeForBackupsCompletions() + defer registry2.UnsubscribeForBackupsCompletions() + defer registry3.UnsubscribeForBackupsCompletions() + + receivedBackups1 := make(chan uuid.UUID, 3) + receivedBackups2 := make(chan uuid.UUID, 3) + receivedBackups3 := make(chan uuid.UUID, 3) + + handler1 := func(nodeID uuid.UUID, backupID uuid.UUID) { receivedBackups1 <- backupID } + 
handler2 := func(nodeID uuid.UUID, backupID uuid.UUID) { receivedBackups2 <- backupID } + handler3 := func(nodeID uuid.UUID, backupID uuid.UUID) { receivedBackups3 <- backupID } + + err := registry1.SubscribeForBackupsCompletions(handler1) + assert.NoError(t, err) + + err = registry2.SubscribeForBackupsCompletions(handler2) + assert.NoError(t, err) + + err = registry3.SubscribeForBackupsCompletions(handler3) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + publishRegistry := createTestRegistry() + err = publishRegistry.PublishBackupCompletion(node1.ID, backupID1) + assert.NoError(t, err) + + err = publishRegistry.PublishBackupCompletion(node2.ID, backupID2) + assert.NoError(t, err) + + err = publishRegistry.PublishBackupCompletion(node3.ID, backupID3) + assert.NoError(t, err) + + receivedAll1 := []uuid.UUID{} + receivedAll2 := []uuid.UUID{} + receivedAll3 := []uuid.UUID{} + + for i := 0; i < 3; i++ { + select { + case received := <-receivedBackups1: + receivedAll1 = append(receivedAll1, received) + case <-time.After(2 * time.Second): + t.Fatal("Subscriber 1 timeout waiting for completion message") + } + } + + for i := 0; i < 3; i++ { + select { + case received := <-receivedBackups2: + receivedAll2 = append(receivedAll2, received) + case <-time.After(2 * time.Second): + t.Fatal("Subscriber 2 timeout waiting for completion message") + } + } + + for i := 0; i < 3; i++ { + select { + case received := <-receivedBackups3: + receivedAll3 = append(receivedAll3, received) + case <-time.After(2 * time.Second): + t.Fatal("Subscriber 3 timeout waiting for completion message") + } + } + + assert.Contains(t, receivedAll1, backupID1) + assert.Contains(t, receivedAll1, backupID2) + assert.Contains(t, receivedAll1, backupID3) + + assert.Contains(t, receivedAll2, backupID1) + assert.Contains(t, receivedAll2, backupID2) + assert.Contains(t, receivedAll2, backupID3) + + assert.Contains(t, receivedAll3, backupID1) + assert.Contains(t, receivedAll3, backupID2) + 
assert.Contains(t, receivedAll3, backupID3) +} diff --git a/backend/internal/features/backups/backups/backuping/scheduler.go b/backend/internal/features/backups/backups/backuping/scheduler.go index 49e6fd4..a910ba4 100644 --- a/backend/internal/features/backups/backups/backuping/scheduler.go +++ b/backend/internal/features/backups/backups/backuping/scheduler.go @@ -7,7 +7,6 @@ import ( backups_config "databasus-backend/internal/features/backups/config" "databasus-backend/internal/features/storages" task_cancellation "databasus-backend/internal/features/tasks/cancellation" - task_registry "databasus-backend/internal/features/tasks/registry" "databasus-backend/internal/util/encryption" "databasus-backend/internal/util/period" "fmt" @@ -28,7 +27,7 @@ type BackupsScheduler struct { backupConfigService *backups_config.BackupConfigService storageService *storages.StorageService taskCancelManager *task_cancellation.TaskCancelManager - tasksRegistry *task_registry.TaskNodesRegistry + backupNodesRegistry *BackupNodesRegistry lastBackupTime time.Time logger *slog.Logger @@ -50,12 +49,14 @@ func (s *BackupsScheduler) Run(ctx context.Context) { panic(err) } - if err := s.tasksRegistry.SubscribeForTasksCompletions(s.onBackupCompleted); err != nil { + err := s.backupNodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted) + if err != nil { s.logger.Error("Failed to subscribe to backup completions", "error", err) panic(err) } + defer func() { - if err := s.tasksRegistry.UnsubscribeForTasksCompletions(); err != nil { + if err := s.backupNodesRegistry.UnsubscribeForBackupsCompletions(); err != nil { s.logger.Error("Failed to unsubscribe from backup completions", "error", err) } }() @@ -180,7 +181,7 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool return } - if err := s.tasksRegistry.IncrementTasksInProgress(leastBusyNodeID.String()); err != nil { + if err := s.backupNodesRegistry.IncrementBackupsInProgress(*leastBusyNodeID); err != nil { 
s.logger.Error( "Failed to increment backups in progress", "nodeId", @@ -193,7 +194,7 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool return } - if err := s.tasksRegistry.AssignTaskToNode(leastBusyNodeID.String(), backup.ID, isCallNotifier); err != nil { + if err := s.backupNodesRegistry.AssignBackupToNode(*leastBusyNodeID, backup.ID, isCallNotifier); err != nil { s.logger.Error( "Failed to submit backup", "nodeId", @@ -203,7 +204,7 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool "error", err, ) - if decrementErr := s.tasksRegistry.DecrementTasksInProgress(leastBusyNodeID.String()); decrementErr != nil { + if decrementErr := s.backupNodesRegistry.DecrementBackupsInProgress(*leastBusyNodeID); decrementErr != nil { s.logger.Error( "Failed to decrement backups in progress after submit failure", "nodeId", @@ -398,7 +399,7 @@ func (s *BackupsScheduler) runPendingBackups() error { } func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) { - nodes, err := s.tasksRegistry.GetAvailableNodes() + nodes, err := s.backupNodesRegistry.GetAvailableNodes() if err != nil { return nil, fmt.Errorf("failed to get available nodes: %w", err) } @@ -407,17 +408,17 @@ func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) { return nil, fmt.Errorf("no nodes available") } - stats, err := s.tasksRegistry.GetNodesStats() + stats, err := s.backupNodesRegistry.GetBackupNodesStats() if err != nil { return nil, fmt.Errorf("failed to get backup nodes stats: %w", err) } statsMap := make(map[uuid.UUID]int) for _, stat := range stats { - statsMap[stat.ID] = stat.ActiveTasks + statsMap[stat.ID] = stat.ActiveBackups } - var bestNode *task_registry.TaskNode + var bestNode *BackupNode var bestScore float64 = -1 for i := range nodes { @@ -445,21 +446,9 @@ func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) { return &bestNode.ID, nil } -func (s *BackupsScheduler) 
onBackupCompleted(nodeIDStr string, backupID uuid.UUID) { - nodeID, err := uuid.Parse(nodeIDStr) - if err != nil { - s.logger.Error( - "Failed to parse node ID from completion message", - "nodeId", - nodeIDStr, - "error", - err, - ) - return - } - +func (s *BackupsScheduler) onBackupCompleted(nodeID uuid.UUID, backupID uuid.UUID) { // Verify this task is actually a backup (registry contains multiple task types) - _, err = s.backupRepository.FindByID(backupID) + _, err := s.backupRepository.FindByID(backupID) if err != nil { // Not a backup task, ignore it return @@ -505,7 +494,7 @@ func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUI s.backupToNodeRelations[nodeID] = relation } - if err := s.tasksRegistry.DecrementTasksInProgress(nodeIDStr); err != nil { + if err := s.backupNodesRegistry.DecrementBackupsInProgress(nodeID); err != nil { s.logger.Error( "Failed to decrement backups in progress", "nodeId", @@ -519,7 +508,7 @@ func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUI } func (s *BackupsScheduler) checkDeadNodesAndFailBackups() error { - nodes, err := s.tasksRegistry.GetAvailableNodes() + nodes, err := s.backupNodesRegistry.GetAvailableNodes() if err != nil { return fmt.Errorf("failed to get available nodes: %w", err) } @@ -575,7 +564,7 @@ func (s *BackupsScheduler) checkDeadNodesAndFailBackups() error { continue } - if err := s.tasksRegistry.DecrementTasksInProgress(nodeID.String()); err != nil { + if err := s.backupNodesRegistry.DecrementBackupsInProgress(nodeID); err != nil { s.logger.Error( "Failed to decrement backups in progress for dead node", "nodeId", diff --git a/backend/internal/features/backups/backups/backuping/scheduler_test.go b/backend/internal/features/backups/backups/backuping/scheduler_test.go index d31ab34..ac3fc8a 100644 --- a/backend/internal/features/backups/backups/backuping/scheduler_test.go +++ b/backend/internal/features/backups/backups/backuping/scheduler_test.go @@ -7,7 +7,6 
@@ import ( "databasus-backend/internal/features/intervals" "databasus-backend/internal/features/notifiers" "databasus-backend/internal/features/storages" - task_registry "databasus-backend/internal/features/tasks/registry" users_enums "databasus-backend/internal/features/users/enums" users_testing "databasus-backend/internal/features/users/testing" workspaces_testing "databasus-backend/internal/features/workspaces/testing" @@ -466,7 +465,7 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist // Clean up mock node if mockNodeID != uuid.Nil { - nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: mockNodeID}) + backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: mockNodeID}) } cache_utils.ClearAllCache() }() @@ -502,12 +501,12 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status) // Verify Valkey counter was incremented when backup was assigned - stats, err := nodesRegistry.GetNodesStats() + stats, err := backupNodesRegistry.GetBackupNodesStats() assert.NoError(t, err) foundStat := false for _, stat := range stats { if stat.ID == mockNodeID { - assert.Equal(t, 1, stat.ActiveTasks) + assert.Equal(t, 1, stat.ActiveBackups) foundStat = true break } @@ -532,11 +531,11 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist assert.Contains(t, *backups[0].FailMessage, "node unavailability") // Verify Valkey counter was decremented after backup failed - stats, err = nodesRegistry.GetNodesStats() + stats, err = backupNodesRegistry.GetBackupNodesStats() assert.NoError(t, err) for _, stat := range stats { if stat.ID == mockNodeID { - assert.Equal(t, 0, stat.ActiveTasks) + assert.Equal(t, 0, stat.ActiveBackups) } } @@ -569,7 +568,7 @@ func Test_OnBackupCompleted_WhenTaskIsNotBackup_SkipsProcessing(t *testing.T) { // Clean up mock node if mockNodeID != uuid.Nil { - 
nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: mockNodeID}) + backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: mockNodeID}) } cache_utils.ClearAllCache() }() @@ -605,12 +604,12 @@ func Test_OnBackupCompleted_WhenTaskIsNotBackup_SkipsProcessing(t *testing.T) { assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status) // Get initial state of the registry - initialStats, err := nodesRegistry.GetNodesStats() + initialStats, err := backupNodesRegistry.GetBackupNodesStats() assert.NoError(t, err) var initialActiveTasks int for _, stat := range initialStats { if stat.ID == mockNodeID { - initialActiveTasks = stat.ActiveTasks + initialActiveTasks = stat.ActiveBackups break } } @@ -618,16 +617,16 @@ func Test_OnBackupCompleted_WhenTaskIsNotBackup_SkipsProcessing(t *testing.T) { // Call onBackupCompleted with a random UUID (not a backup ID) nonBackupTaskID := uuid.New() - GetBackupsScheduler().onBackupCompleted(mockNodeID.String(), nonBackupTaskID) + GetBackupsScheduler().onBackupCompleted(mockNodeID, nonBackupTaskID) time.Sleep(100 * time.Millisecond) // Verify: Active tasks counter should remain the same (not decremented) - stats, err := nodesRegistry.GetNodesStats() + stats, err := backupNodesRegistry.GetBackupNodesStats() assert.NoError(t, err) for _, stat := range stats { if stat.ID == mockNodeID { - assert.Equal(t, initialActiveTasks, stat.ActiveTasks, + assert.Equal(t, initialActiveTasks, stat.ActiveBackups, "Active tasks should not change for non-backup task") } } @@ -658,9 +657,9 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) { defer func() { // Clean up all mock nodes - nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node1ID}) - nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node2ID}) - nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node3ID}) + backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node1ID}) + 
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node2ID}) + backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node3ID}) cache_utils.ClearAllCache() }() @@ -672,17 +671,17 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) { assert.NoError(t, err) for range 5 { - err = nodesRegistry.IncrementTasksInProgress(node1ID.String()) + err = backupNodesRegistry.IncrementBackupsInProgress(node1ID) assert.NoError(t, err) } for range 2 { - err = nodesRegistry.IncrementTasksInProgress(node2ID.String()) + err = backupNodesRegistry.IncrementBackupsInProgress(node2ID) assert.NoError(t, err) } for range 8 { - err = nodesRegistry.IncrementTasksInProgress(node3ID.String()) + err = backupNodesRegistry.IncrementBackupsInProgress(node3ID) assert.NoError(t, err) } @@ -701,8 +700,8 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) { defer func() { // Clean up all mock nodes - nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node100MBsID}) - nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node50MBsID}) + backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node100MBsID}) + backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node50MBsID}) cache_utils.ClearAllCache() }() @@ -712,11 +711,11 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) { assert.NoError(t, err) for range 10 { - err = nodesRegistry.IncrementTasksInProgress(node100MBsID.String()) + err = backupNodesRegistry.IncrementBackupsInProgress(node100MBsID) assert.NoError(t, err) } - err = nodesRegistry.IncrementTasksInProgress(node50MBsID.String()) + err = backupNodesRegistry.IncrementBackupsInProgress(node50MBsID) assert.NoError(t, err) leastBusyNodeID, err := GetBackupsScheduler().calculateLeastBusyNode() @@ -880,12 +879,12 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T assert.NoError(t, err) // Get initial active task count - stats, err := 
nodesRegistry.GetNodesStats() + stats, err := backupNodesRegistry.GetBackupNodesStats() assert.NoError(t, err) var initialActiveTasks int for _, stat := range stats { if stat.ID == backuperNode.nodeID { - initialActiveTasks = stat.ActiveTasks + initialActiveTasks = stat.ActiveBackups break } } @@ -913,12 +912,12 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T assert.True(t, decreased, "Active task count should have decreased after backup completion") // Verify final active task count equals initial count - finalStats, err := nodesRegistry.GetNodesStats() + finalStats, err := backupNodesRegistry.GetBackupNodesStats() assert.NoError(t, err) for _, stat := range finalStats { if stat.ID == backuperNode.nodeID { - t.Logf("Final active tasks: %d", stat.ActiveTasks) - assert.Equal(t, initialActiveTasks, stat.ActiveTasks, + t.Logf("Final active tasks: %d", stat.ActiveBackups) + assert.Equal(t, initialActiveTasks, stat.ActiveBackups, "Active task count should return to initial value after backup completion") break } @@ -982,12 +981,12 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) { assert.NoError(t, err) // Get initial active task count - stats, err := nodesRegistry.GetNodesStats() + stats, err := backupNodesRegistry.GetBackupNodesStats() assert.NoError(t, err) var initialActiveTasks int for _, stat := range stats { if stat.ID == backuperNode.nodeID { - initialActiveTasks = stat.ActiveTasks + initialActiveTasks = stat.ActiveBackups break } } @@ -1019,12 +1018,12 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) { assert.True(t, decreased, "Active task count should have decreased after backup failure") // Verify final active task count equals initial count - finalStats, err := nodesRegistry.GetNodesStats() + finalStats, err := backupNodesRegistry.GetBackupNodesStats() assert.NoError(t, err) for _, stat := range finalStats { if stat.ID == backuperNode.nodeID { - t.Logf("Final 
active tasks: %d", stat.ActiveTasks) - assert.Equal(t, initialActiveTasks, stat.ActiveTasks, + t.Logf("Final active tasks: %d", stat.ActiveBackups) + assert.Equal(t, initialActiveTasks, stat.ActiveBackups, "Active task count should return to initial value after backup failure") break } diff --git a/backend/internal/features/backups/backups/backuping/testing.go b/backend/internal/features/backups/backups/backuping/testing.go index 197502e..747471a 100644 --- a/backend/internal/features/backups/backups/backuping/testing.go +++ b/backend/internal/features/backups/backups/backuping/testing.go @@ -12,7 +12,6 @@ import ( "databasus-backend/internal/features/databases" "databasus-backend/internal/features/notifiers" "databasus-backend/internal/features/storages" - task_registry "databasus-backend/internal/features/tasks/registry" workspaces_controllers "databasus-backend/internal/features/workspaces/controllers" workspaces_services "databasus-backend/internal/features/workspaces/services" workspaces_testing "databasus-backend/internal/features/workspaces/testing" @@ -44,7 +43,7 @@ func CreateTestBackuperNode() *BackuperNode { storages.GetStorageService(), notifiers.GetNotifierService(), taskCancelManager, - nodesRegistry, + backupNodesRegistry, logger.GetLogger(), usecases.GetCreateBackupUsecase(), uuid.New(), @@ -114,7 +113,7 @@ func StartBackuperNodeForTest(t *testing.T, backuperNode *BackuperNode) context. 
// Poll registry for node presence instead of fixed sleep deadline := time.Now().UTC().Add(5 * time.Second) for time.Now().UTC().Before(deadline) { - nodes, err := nodesRegistry.GetAvailableNodes() + nodes, err := backupNodesRegistry.GetAvailableNodes() if err == nil { for _, node := range nodes { if node.ID == backuperNode.nodeID { @@ -175,7 +174,7 @@ func StopBackuperNodeForTest(t *testing.T, cancel context.CancelFunc, backuperNo // Wait for node to unregister from registry deadline := time.Now().UTC().Add(2 * time.Second) for time.Now().UTC().Before(deadline) { - nodes, err := nodesRegistry.GetAvailableNodes() + nodes, err := backupNodesRegistry.GetAvailableNodes() if err == nil { found := false for _, node := range nodes { @@ -196,13 +195,13 @@ func StopBackuperNodeForTest(t *testing.T, cancel context.CancelFunc, backuperNo } func CreateMockNodeInRegistry(nodeID uuid.UUID, throughputMBs int, lastHeartbeat time.Time) error { - backupNode := task_registry.TaskNode{ + backupNode := BackupNode{ ID: nodeID, ThroughputMBs: throughputMBs, LastHeartbeat: lastHeartbeat, } - return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode) + return backupNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode) } func UpdateNodeHeartbeatDirectly( @@ -210,17 +209,17 @@ func UpdateNodeHeartbeatDirectly( throughputMBs int, lastHeartbeat time.Time, ) error { - backupNode := task_registry.TaskNode{ + backupNode := BackupNode{ ID: nodeID, ThroughputMBs: throughputMBs, LastHeartbeat: lastHeartbeat, } - return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode) + return backupNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode) } -func GetNodeFromRegistry(nodeID uuid.UUID) (*task_registry.TaskNode, error) { - nodes, err := nodesRegistry.GetAvailableNodes() +func GetNodeFromRegistry(nodeID uuid.UUID) (*BackupNode, error) { + nodes, err := backupNodesRegistry.GetAvailableNodes() if err != nil { return nil, err } @@ -246,7 +245,7 @@ func 
WaitForActiveTasksDecrease( deadline := time.Now().UTC().Add(timeout) for time.Now().UTC().Before(deadline) { - stats, err := nodesRegistry.GetNodesStats() + stats, err := backupNodesRegistry.GetBackupNodesStats() if err != nil { t.Logf("WaitForActiveTasksDecrease: error getting node stats: %v", err) time.Sleep(500 * time.Millisecond) @@ -257,14 +256,14 @@ func WaitForActiveTasksDecrease( if stat.ID == nodeID { t.Logf( "WaitForActiveTasksDecrease: current active tasks = %d (initial = %d)", - stat.ActiveTasks, + stat.ActiveBackups, initialCount, ) - if stat.ActiveTasks < initialCount { + if stat.ActiveBackups < initialCount { t.Logf( "WaitForActiveTasksDecrease: active tasks decreased from %d to %d", initialCount, - stat.ActiveTasks, + stat.ActiveBackups, ) return true } diff --git a/backend/internal/features/backups/backups/testing.go b/backend/internal/features/backups/backups/testing.go index 90549b0..152a8ea 100644 --- a/backend/internal/features/backups/backups/testing.go +++ b/backend/internal/features/backups/backups/testing.go @@ -75,3 +75,23 @@ func WaitForBackupCompletion( t.Logf("WaitForBackupCompletion: timeout waiting for backup to complete") } + +// CreateTestBackup creates a simple test backup record for testing purposes +func CreateTestBackup(databaseID, storageID uuid.UUID) *backups_core.Backup { + backup := &backups_core.Backup{ + ID: uuid.New(), + DatabaseID: databaseID, + StorageID: storageID, + Status: backups_core.BackupStatusCompleted, + BackupSizeMb: 10.5, + BackupDurationMs: 1000, + CreatedAt: time.Now().UTC(), + } + + repo := &backups_core.BackupRepository{} + if err := repo.Save(backup); err != nil { + panic(err) + } + + return backup +} diff --git a/backend/internal/features/restores/background_service.go b/backend/internal/features/restores/background_service.go deleted file mode 100644 index a58f995..0000000 --- a/backend/internal/features/restores/background_service.go +++ /dev/null @@ -1,38 +0,0 @@ -package restores - -import ( - 
"context" - "databasus-backend/internal/features/restores/enums" - "log/slog" -) - -type RestoreBackgroundService struct { - restoreRepository *RestoreRepository - logger *slog.Logger -} - -func (s *RestoreBackgroundService) Run(ctx context.Context) { - if err := s.failRestoresInProgress(); err != nil { - s.logger.Error("Failed to fail restores in progress", "error", err) - panic(err) - } -} - -func (s *RestoreBackgroundService) failRestoresInProgress() error { - restoresInProgress, err := s.restoreRepository.FindByStatus(enums.RestoreStatusInProgress) - if err != nil { - return err - } - - for _, restore := range restoresInProgress { - failMessage := "Restore failed due to application restart" - restore.Status = enums.RestoreStatusFailed - restore.FailMessage = &failMessage - - if err := s.restoreRepository.Save(restore); err != nil { - return err - } - } - - return nil -} diff --git a/backend/internal/features/restores/controller.go b/backend/internal/features/restores/controller.go index 2c5ca0c..f74c3af 100644 --- a/backend/internal/features/restores/controller.go +++ b/backend/internal/features/restores/controller.go @@ -1,6 +1,7 @@ package restores import ( + restores_core "databasus-backend/internal/features/restores/core" users_middleware "databasus-backend/internal/features/users/middleware" "net/http" @@ -23,7 +24,7 @@ func (c *RestoreController) RegisterRoutes(router *gin.RouterGroup) { // @Tags restores // @Produce json // @Param backupId path string true "Backup ID" -// @Success 200 {array} models.Restore +// @Success 200 {array} restores_core.Restore // @Failure 400 // @Failure 401 // @Router /restores/{backupId} [get] @@ -71,7 +72,7 @@ func (c *RestoreController) RestoreBackup(ctx *gin.Context) { return } - var requestDTO RestoreBackupRequest + var requestDTO restores_core.RestoreBackupRequest if err := ctx.ShouldBindJSON(&requestDTO); err != nil { ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return diff --git 
a/backend/internal/features/restores/controller_test.go b/backend/internal/features/restores/controller_test.go index f73ebc9..64b99a0 100644 --- a/backend/internal/features/restores/controller_test.go +++ b/backend/internal/features/restores/controller_test.go @@ -18,20 +18,18 @@ import ( "databasus-backend/internal/config" audit_logs "databasus-backend/internal/features/audit_logs" - "databasus-backend/internal/features/backups/backups" backups_core "databasus-backend/internal/features/backups/backups/core" backups_config "databasus-backend/internal/features/backups/config" "databasus-backend/internal/features/databases" "databasus-backend/internal/features/databases/databases/mysql" "databasus-backend/internal/features/databases/databases/postgresql" - "databasus-backend/internal/features/restores/models" + restores_core "databasus-backend/internal/features/restores/core" "databasus-backend/internal/features/storages" local_storage "databasus-backend/internal/features/storages/models/local" users_dto "databasus-backend/internal/features/users/dto" users_enums "databasus-backend/internal/features/users/enums" users_services "databasus-backend/internal/features/users/services" users_testing "databasus-backend/internal/features/users/testing" - workspaces_controllers "databasus-backend/internal/features/workspaces/controllers" workspaces_models "databasus-backend/internal/features/workspaces/models" workspaces_testing "databasus-backend/internal/features/workspaces/testing" util_encryption "databasus-backend/internal/util/encryption" @@ -46,7 +44,7 @@ func Test_GetRestores_WhenUserIsWorkspaceMember_RestoresReturned(t *testing.T) { database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router) - var restores []*models.Restore + var restores []*restores_core.Restore test_utils.MakeGetRequestAndUnmarshal( t, router, @@ -90,7 +88,7 @@ func Test_GetRestores_WhenUserIsGlobalAdmin_RestoresReturned(t *testing.T) { admin := 
users_testing.CreateTestUser(users_enums.UserRoleAdmin) - var restores []*models.Restore + var restores []*restores_core.Restore test_utils.MakeGetRequestAndUnmarshal( t, router, @@ -105,12 +103,16 @@ func Test_GetRestores_WhenUserIsGlobalAdmin_RestoresReturned(t *testing.T) { func Test_RestoreBackup_WhenUserIsWorkspaceMember_RestoreInitiated(t *testing.T) { router := createTestRouter() + + _, cleanup := SetupMockRestoreNode(t) + defer cleanup() + owner := users_testing.CreateTestUser(users_enums.UserRoleMember) workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router) _, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router) - request := RestoreBackupRequest{ + request := restores_core.RestoreBackupRequest{ PostgresqlDatabase: &postgresql.PostgresqlDatabase{ Version: tools.PostgresqlVersion16, Host: "localhost", @@ -141,7 +143,7 @@ func Test_RestoreBackup_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember) - request := RestoreBackupRequest{ + request := restores_core.RestoreBackupRequest{ PostgresqlDatabase: &postgresql.PostgresqlDatabase{ Version: tools.PostgresqlVersion16, Host: "localhost", @@ -165,12 +167,16 @@ func Test_RestoreBackup_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing func Test_RestoreBackup_WithIsExcludeExtensions_FlagPassedCorrectly(t *testing.T) { router := createTestRouter() + + _, cleanup := SetupMockRestoreNode(t) + defer cleanup() + owner := users_testing.CreateTestUser(users_enums.UserRoleMember) workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router) _, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router) - request := RestoreBackupRequest{ + request := restores_core.RestoreBackupRequest{ PostgresqlDatabase: &postgresql.PostgresqlDatabase{ Version: tools.PostgresqlVersion16, Host: "localhost", @@ -195,12 +201,16 @@ func 
Test_RestoreBackup_WithIsExcludeExtensions_FlagPassedCorrectly(t *testing.T func Test_RestoreBackup_AuditLogWritten(t *testing.T) { router := createTestRouter() + + _, cleanup := SetupMockRestoreNode(t) + defer cleanup() + owner := users_testing.CreateTestUser(users_enums.UserRoleMember) workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router) database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router) - request := RestoreBackupRequest{ + request := restores_core.RestoreBackupRequest{ PostgresqlDatabase: &postgresql.PostgresqlDatabase{ Version: tools.PostgresqlVersion16, Host: "localhost", @@ -272,15 +282,22 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { router := createTestRouter() + + // Setup mock node for tests that skip disk validation and reach scheduler + if !tc.expectDiskValidated { + _, cleanup := SetupMockRestoreNode(t) + defer cleanup() + } + owner := users_testing.CreateTestUser(users_enums.UserRoleMember) workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router) var backup *backups_core.Backup - var request RestoreBackupRequest + var request restores_core.RestoreBackupRequest if tc.dbType == databases.DatabaseTypePostgres { _, backup = createTestDatabaseWithBackupForRestore(workspace, owner, router) - request = RestoreBackupRequest{ + request = restores_core.RestoreBackupRequest{ PostgresqlDatabase: &postgresql.PostgresqlDatabase{ Version: tools.PostgresqlVersion16, Host: "localhost", @@ -310,7 +327,7 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) { assert.NoError(t, err) backup = createTestBackup(mysqlDB, owner) - request = RestoreBackupRequest{ + request = restores_core.RestoreBackupRequest{ MysqlDatabase: &mysql.MysqlDatabase{ Version: tools.MysqlVersion80, Host: "localhost", @@ -354,15 +371,7 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) { } func 
createTestRouter() *gin.Engine { - router := workspaces_testing.CreateTestRouter( - workspaces_controllers.GetWorkspaceController(), - workspaces_controllers.GetMembershipController(), - databases.GetDatabaseController(), - backups_config.GetBackupConfigController(), - backups.GetBackupController(), - GetRestoreController(), - ) - return router + return CreateTestRouter() } func createTestDatabaseWithBackupForRestore( diff --git a/backend/internal/features/restores/dto.go b/backend/internal/features/restores/core/dto.go similarity index 96% rename from backend/internal/features/restores/dto.go rename to backend/internal/features/restores/core/dto.go index 471058a..b6104e9 100644 --- a/backend/internal/features/restores/dto.go +++ b/backend/internal/features/restores/core/dto.go @@ -1,4 +1,4 @@ -package restores +package restores_core import ( "databasus-backend/internal/features/databases/databases/mariadb" diff --git a/backend/internal/features/restores/enums/enums.go b/backend/internal/features/restores/core/enums.go similarity index 89% rename from backend/internal/features/restores/enums/enums.go rename to backend/internal/features/restores/core/enums.go index cffcc64..db1c472 100644 --- a/backend/internal/features/restores/enums/enums.go +++ b/backend/internal/features/restores/core/enums.go @@ -1,4 +1,4 @@ -package enums +package restores_core type RestoreStatus string diff --git a/backend/internal/features/restores/core/interfaces.go b/backend/internal/features/restores/core/interfaces.go new file mode 100644 index 0000000..0445727 --- /dev/null +++ b/backend/internal/features/restores/core/interfaces.go @@ -0,0 +1,20 @@ +package restores_core + +import ( + backups_core "databasus-backend/internal/features/backups/backups/core" + backups_config "databasus-backend/internal/features/backups/config" + "databasus-backend/internal/features/databases" + "databasus-backend/internal/features/storages" +) + +type RestoreBackupUsecase interface { + Execute( + 
backupConfig *backups_config.BackupConfig, + restore Restore, + originalDB *databases.Database, + restoringToDB *databases.Database, + backup *backups_core.Backup, + storage *storages.Storage, + isExcludeExtensions bool, + ) error +} diff --git a/backend/internal/features/restores/core/model.go b/backend/internal/features/restores/core/model.go new file mode 100644 index 0000000..d3d5d00 --- /dev/null +++ b/backend/internal/features/restores/core/model.go @@ -0,0 +1,30 @@ +package restores_core + +import ( + backups_core "databasus-backend/internal/features/backups/backups/core" + "databasus-backend/internal/features/databases/databases/mariadb" + "databasus-backend/internal/features/databases/databases/mongodb" + "databasus-backend/internal/features/databases/databases/mysql" + "databasus-backend/internal/features/databases/databases/postgresql" + "time" + + "github.com/google/uuid" +) + +type Restore struct { + ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"` + Status RestoreStatus `json:"status" gorm:"column:status;type:text;not null"` + + BackupID uuid.UUID `json:"backupId" gorm:"column:backup_id;type:uuid;not null"` + Backup *backups_core.Backup + + PostgresqlDatabase *postgresql.PostgresqlDatabase `json:"postgresqlDatabase" gorm:"-"` + MysqlDatabase *mysql.MysqlDatabase `json:"mysqlDatabase" gorm:"-"` + MariadbDatabase *mariadb.MariadbDatabase `json:"mariadbDatabase" gorm:"-"` + MongodbDatabase *mongodb.MongodbDatabase `json:"mongodbDatabase" gorm:"-"` + + FailMessage *string `json:"failMessage" gorm:"column:fail_message"` + + RestoreDurationMs int64 `json:"restoreDurationMs" gorm:"column:restore_duration_ms;default:0"` + CreatedAt time.Time `json:"createdAt" gorm:"column:created_at;default:now()"` +} diff --git a/backend/internal/features/restores/repository.go b/backend/internal/features/restores/core/repository.go similarity index 62% rename from backend/internal/features/restores/repository.go rename to 
backend/internal/features/restores/core/repository.go index f01f5ad..b9f5eba 100644 --- a/backend/internal/features/restores/repository.go +++ b/backend/internal/features/restores/core/repository.go @@ -1,8 +1,6 @@ -package restores +package restores_core import ( - "databasus-backend/internal/features/restores/enums" - "databasus-backend/internal/features/restores/models" "databasus-backend/internal/storage" "github.com/google/uuid" @@ -10,24 +8,24 @@ import ( type RestoreRepository struct{} -func (r *RestoreRepository) Save(restore *models.Restore) error { +func (r *RestoreRepository) Save(restore *Restore) error { db := storage.GetDb() isNew := restore.ID == uuid.Nil if isNew { restore.ID = uuid.New() return db.Create(restore). - Omit("Backup"). + Omit("Backup", "PostgresqlDatabase", "MysqlDatabase", "MariadbDatabase", "MongodbDatabase"). Error } return db.Save(restore). - Omit("Backup"). + Omit("Backup", "PostgresqlDatabase", "MysqlDatabase", "MariadbDatabase", "MongodbDatabase"). Error } -func (r *RestoreRepository) FindByBackupID(backupID uuid.UUID) ([]*models.Restore, error) { - var restores []*models.Restore +func (r *RestoreRepository) FindByBackupID(backupID uuid.UUID) ([]*Restore, error) { + var restores []*Restore if err := storage. GetDb(). @@ -41,8 +39,8 @@ func (r *RestoreRepository) FindByBackupID(backupID uuid.UUID) ([]*models.Restor return restores, nil } -func (r *RestoreRepository) FindByID(id uuid.UUID) (*models.Restore, error) { - var restore models.Restore +func (r *RestoreRepository) FindByID(id uuid.UUID) (*Restore, error) { + var restore Restore if err := storage. GetDb(). 
@@ -55,8 +53,8 @@ func (r *RestoreRepository) FindByID(id uuid.UUID) (*models.Restore, error) { return &restore, nil } -func (r *RestoreRepository) FindByStatus(status enums.RestoreStatus) ([]*models.Restore, error) { - var restores []*models.Restore +func (r *RestoreRepository) FindByStatus(status RestoreStatus) ([]*Restore, error) { + var restores []*Restore if err := storage. GetDb(). @@ -71,5 +69,5 @@ func (r *RestoreRepository) FindByStatus(status enums.RestoreStatus) ([]*models. } func (r *RestoreRepository) DeleteByID(id uuid.UUID) error { - return storage.GetDb().Delete(&models.Restore{}, "id = ?", id).Error + return storage.GetDb().Delete(&Restore{}, "id = ?", id).Error } diff --git a/backend/internal/features/restores/di.go b/backend/internal/features/restores/di.go index 7bafd87..2e1230b 100644 --- a/backend/internal/features/restores/di.go +++ b/backend/internal/features/restores/di.go @@ -6,6 +6,7 @@ import ( backups_config "databasus-backend/internal/features/backups/config" "databasus-backend/internal/features/databases" "databasus-backend/internal/features/disk" + restores_core "databasus-backend/internal/features/restores/core" "databasus-backend/internal/features/restores/usecases" "databasus-backend/internal/features/storages" workspaces_services "databasus-backend/internal/features/workspaces/services" @@ -13,7 +14,7 @@ import ( "databasus-backend/internal/util/logger" ) -var restoreRepository = &RestoreRepository{} +var restoreRepository = &restores_core.RestoreRepository{} var restoreService = &RestoreService{ backups.GetBackupService(), restoreRepository, @@ -31,19 +32,10 @@ var restoreController = &RestoreController{ restoreService, } -var restoreBackgroundService = &RestoreBackgroundService{ - restoreRepository, - logger.GetLogger(), -} - func GetRestoreController() *RestoreController { return restoreController } -func GetRestoreBackgroundService() *RestoreBackgroundService { - return restoreBackgroundService -} - func SetupDependencies() { 
backups.GetBackupService().AddBackupRemoveListener(restoreService) } diff --git a/backend/internal/features/restores/models/model.go b/backend/internal/features/restores/models/model.go deleted file mode 100644 index 7acb3a7..0000000 --- a/backend/internal/features/restores/models/model.go +++ /dev/null @@ -1,22 +0,0 @@ -package models - -import ( - backups_core "databasus-backend/internal/features/backups/backups/core" - "databasus-backend/internal/features/restores/enums" - "time" - - "github.com/google/uuid" -) - -type Restore struct { - ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"` - Status enums.RestoreStatus `json:"status" gorm:"column:status;type:text;not null"` - - BackupID uuid.UUID `json:"backupId" gorm:"column:backup_id;type:uuid;not null"` - Backup *backups_core.Backup - - FailMessage *string `json:"failMessage" gorm:"column:fail_message"` - - RestoreDurationMs int64 `json:"restoreDurationMs" gorm:"column:restore_duration_ms;default:0"` - CreatedAt time.Time `json:"createdAt" gorm:"column:created_at;default:now()"` -} diff --git a/backend/internal/features/restores/restoring/di.go b/backend/internal/features/restores/restoring/di.go new file mode 100644 index 0000000..ca76927 --- /dev/null +++ b/backend/internal/features/restores/restoring/di.go @@ -0,0 +1,73 @@ +package restoring + +import ( + "time" + + "github.com/google/uuid" + + "databasus-backend/internal/features/backups/backups" + backups_config "databasus-backend/internal/features/backups/config" + "databasus-backend/internal/features/databases" + restores_core "databasus-backend/internal/features/restores/core" + "databasus-backend/internal/features/restores/usecases" + "databasus-backend/internal/features/storages" + cache_utils "databasus-backend/internal/util/cache" + "databasus-backend/internal/util/encryption" + "databasus-backend/internal/util/logger" +) + +var restoreRepository = &restores_core.RestoreRepository{} + +var restoreNodesRegistry = &RestoreNodesRegistry{ + 
cache_utils.GetValkeyClient(), + logger.GetLogger(), + cache_utils.DefaultCacheTimeout, + cache_utils.NewPubSubManager(), + cache_utils.NewPubSubManager(), +} + +var restoreDatabaseCache = cache_utils.NewCacheUtil[RestoreDatabaseCache]( + cache_utils.GetValkeyClient(), + "restore_db:", +) + +var restorerNode = &RestorerNode{ + uuid.New(), + databases.GetDatabaseService(), + backups.GetBackupService(), + encryption.GetFieldEncryptor(), + restoreRepository, + backups_config.GetBackupConfigService(), + storages.GetStorageService(), + restoreNodesRegistry, + logger.GetLogger(), + usecases.GetRestoreBackupUsecase(), + restoreDatabaseCache, + time.Time{}, +} + +var restoresScheduler = &RestoresScheduler{ + restoreRepository: restoreRepository, + backupService: backups.GetBackupService(), + storageService: storages.GetStorageService(), + backupConfigService: backups_config.GetBackupConfigService(), + restoreNodesRegistry: restoreNodesRegistry, + lastCheckTime: time.Now().UTC(), + logger: logger.GetLogger(), + restoreToNodeRelations: make(map[uuid.UUID]RestoreToNodeRelation), + restorerNode: restorerNode, + cacheUtil: restoreDatabaseCache, + completionSubscriptionID: uuid.Nil, +} + +func GetRestoresScheduler() *RestoresScheduler { + return restoresScheduler +} + +func GetRestorerNode() *RestorerNode { + return restorerNode +} + +func GetRestoreNodesRegistry() *RestoreNodesRegistry { + return restoreNodesRegistry +} diff --git a/backend/internal/features/restores/restoring/dto.go b/backend/internal/features/restores/restoring/dto.go new file mode 100644 index 0000000..eb4a3cb --- /dev/null +++ b/backend/internal/features/restores/restoring/dto.go @@ -0,0 +1,45 @@ +package restoring + +import ( + "databasus-backend/internal/features/databases/databases/mariadb" + "databasus-backend/internal/features/databases/databases/mongodb" + "databasus-backend/internal/features/databases/databases/mysql" + "databasus-backend/internal/features/databases/databases/postgresql" + "time" + + 
"github.com/google/uuid" +) + +type RestoreDatabaseCache struct { + PostgresqlDatabase *postgresql.PostgresqlDatabase `json:"postgresqlDatabase,omitempty"` + MysqlDatabase *mysql.MysqlDatabase `json:"mysqlDatabase,omitempty"` + MariadbDatabase *mariadb.MariadbDatabase `json:"mariadbDatabase,omitempty"` + MongodbDatabase *mongodb.MongodbDatabase `json:"mongodbDatabase,omitempty"` +} + +type RestoreToNodeRelation struct { + NodeID uuid.UUID `json:"nodeId"` + RestoreIDs []uuid.UUID `json:"restoreIds"` +} + +type RestoreNode struct { + ID uuid.UUID `json:"id"` + ThroughputMBs int `json:"throughputMBs"` + LastHeartbeat time.Time `json:"lastHeartbeat"` +} + +type RestoreNodeStats struct { + ID uuid.UUID `json:"id"` + ActiveRestores int `json:"activeRestores"` +} + +type RestoreSubmitMessage struct { + NodeID uuid.UUID `json:"nodeId"` + RestoreID uuid.UUID `json:"restoreId"` + IsCallNotifier bool `json:"isCallNotifier"` +} + +type RestoreCompletionMessage struct { + NodeID uuid.UUID `json:"nodeId"` + RestoreID uuid.UUID `json:"restoreId"` +} diff --git a/backend/internal/features/restores/restoring/mocks.go b/backend/internal/features/restores/restoring/mocks.go new file mode 100644 index 0000000..df02d75 --- /dev/null +++ b/backend/internal/features/restores/restoring/mocks.go @@ -0,0 +1,61 @@ +package restoring + +import ( + "errors" + + backups_core "databasus-backend/internal/features/backups/backups/core" + backups_config "databasus-backend/internal/features/backups/config" + "databasus-backend/internal/features/databases" + restores_core "databasus-backend/internal/features/restores/core" + "databasus-backend/internal/features/storages" +) + +type MockSuccessRestoreUsecase struct{} + +func (uc *MockSuccessRestoreUsecase) Execute( + backupConfig *backups_config.BackupConfig, + restore restores_core.Restore, + originalDB *databases.Database, + restoringToDB *databases.Database, + backup *backups_core.Backup, + storage *storages.Storage, + isExcludeExtensions bool, +) 
error { + return nil +} + +type MockFailedRestoreUsecase struct{} + +func (uc *MockFailedRestoreUsecase) Execute( + backupConfig *backups_config.BackupConfig, + restore restores_core.Restore, + originalDB *databases.Database, + restoringToDB *databases.Database, + backup *backups_core.Backup, + storage *storages.Storage, + isExcludeExtensions bool, +) error { + return errors.New("restore failed") +} + +type MockCaptureCredentialsRestoreUsecase struct { + CalledChan chan *databases.Database + ShouldSucceed bool +} + +func (uc *MockCaptureCredentialsRestoreUsecase) Execute( + backupConfig *backups_config.BackupConfig, + restore restores_core.Restore, + originalDB *databases.Database, + restoringToDB *databases.Database, + backup *backups_core.Backup, + storage *storages.Storage, + isExcludeExtensions bool, +) error { + uc.CalledChan <- restoringToDB + + if uc.ShouldSucceed { + return nil + } + return errors.New("mock restore failed") +} diff --git a/backend/internal/features/restores/restoring/registry.go b/backend/internal/features/restores/restoring/registry.go new file mode 100644 index 0000000..afbe9fc --- /dev/null +++ b/backend/internal/features/restores/restoring/registry.go @@ -0,0 +1,634 @@ +package restoring + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "strings" + "time" + + cache_utils "databasus-backend/internal/util/cache" + + "github.com/google/uuid" + "github.com/valkey-io/valkey-go" +) + +const ( + nodeInfoKeyPrefix = "restore:node:" + nodeInfoKeySuffix = ":info" + nodeActiveRestoresPrefix = "restore:node:" + nodeActiveRestoresSuffix = ":active_restores" + restoreSubmitChannel = "restore:submit" + restoreCompletionChannel = "restore:completion" + + deadNodeThreshold = 2 * time.Minute + cleanupTickerInterval = 1 * time.Second +) + +// RestoreNodesRegistry helps to sync restores scheduler and restore nodes. 
+// +// Features: +// - Track node availability and load level +// - Assign restores that need to be processed from the scheduler to a node +// - Notify the scheduler from a node when a restore completes +// +// Important things to remember: +// - Nodes without a heartbeat for more than 2 minutes are not included +// in the available nodes list and stats +// +// Cleanup of dead nodes is performed on 2 levels: +// - List and stats functions do not return dead nodes +// - Dead nodes are periodically cleaned up in the cache (so that +// dead nodes do not accumulate in the cache) +type RestoreNodesRegistry struct { + client valkey.Client + logger *slog.Logger + timeout time.Duration + pubsubRestores *cache_utils.PubSubManager + pubsubCompletions *cache_utils.PubSubManager +} + +func (r *RestoreNodesRegistry) Run(ctx context.Context) { + if err := r.cleanupDeadNodes(); err != nil { + r.logger.Error("Failed to cleanup dead nodes on startup", "error", err) + } + + ticker := time.NewTicker(cleanupTickerInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := r.cleanupDeadNodes(); err != nil { + r.logger.Error("Failed to cleanup dead nodes", "error", err) + } + } + } +} + +func (r *RestoreNodesRegistry) GetAvailableNodes() ([]RestoreNode, error) { + ctx, cancel := context.WithTimeout(context.Background(), r.timeout) + defer cancel() + + var allKeys []string + cursor := uint64(0) + pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix + + for { + result := r.client.Do( + ctx, + r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(), + ) + + if result.Error() != nil { + return nil, fmt.Errorf("failed to scan node keys: %w", result.Error()) + } + + scanResult, err := result.AsScanEntry() + if err != nil { + return nil, fmt.Errorf("failed to parse scan result: %w", err) + } + + allKeys = append(allKeys, scanResult.Elements...)
+ + cursor = scanResult.Cursor + if cursor == 0 { + break + } + } + + if len(allKeys) == 0 { + return []RestoreNode{}, nil + } + + keyDataMap, err := r.pipelineGetKeys(allKeys) + if err != nil { + return nil, fmt.Errorf("failed to pipeline get node keys: %w", err) + } + + threshold := time.Now().UTC().Add(-deadNodeThreshold) + var nodes []RestoreNode + + for key, data := range keyDataMap { + // Skip if the key doesn't exist (data is empty) + if len(data) == 0 { + continue + } + + var node RestoreNode + if err := json.Unmarshal(data, &node); err != nil { + r.logger.Warn("Failed to unmarshal node data", "key", key, "error", err) + continue + } + + // Skip nodes with zero/uninitialized heartbeat + if node.LastHeartbeat.IsZero() { + continue + } + + if node.LastHeartbeat.Before(threshold) { + continue + } + + nodes = append(nodes, node) + } + + return nodes, nil +} + +func (r *RestoreNodesRegistry) GetRestoreNodesStats() ([]RestoreNodeStats, error) { + ctx, cancel := context.WithTimeout(context.Background(), r.timeout) + defer cancel() + + var allKeys []string + cursor := uint64(0) + pattern := nodeActiveRestoresPrefix + "*" + nodeActiveRestoresSuffix + + for { + result := r.client.Do( + ctx, + r.client.B().Scan().Cursor(cursor).Match(pattern).Count(100).Build(), + ) + + if result.Error() != nil { + return nil, fmt.Errorf("failed to scan active restores keys: %w", result.Error()) + } + + scanResult, err := result.AsScanEntry() + if err != nil { + return nil, fmt.Errorf("failed to parse scan result: %w", err) + } + + allKeys = append(allKeys, scanResult.Elements...) 
+ + cursor = scanResult.Cursor + if cursor == 0 { + break + } + } + + if len(allKeys) == 0 { + return []RestoreNodeStats{}, nil + } + + keyDataMap, err := r.pipelineGetKeys(allKeys) + if err != nil { + return nil, fmt.Errorf("failed to pipeline get active restores keys: %w", err) + } + + var nodeInfoKeys []string + nodeIDToStatsKey := make(map[string]string) + for key := range keyDataMap { + nodeID := r.extractNodeIDFromKey(key, nodeActiveRestoresPrefix, nodeActiveRestoresSuffix) + nodeIDStr := nodeID.String() + infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeIDStr, nodeInfoKeySuffix) + nodeInfoKeys = append(nodeInfoKeys, infoKey) + nodeIDToStatsKey[infoKey] = key + } + + nodeInfoMap, err := r.pipelineGetKeys(nodeInfoKeys) + if err != nil { + return nil, fmt.Errorf("failed to pipeline get node info keys: %w", err) + } + + threshold := time.Now().UTC().Add(-deadNodeThreshold) + var stats []RestoreNodeStats + for infoKey, nodeData := range nodeInfoMap { + // Skip if the info key doesn't exist (nodeData is empty) + if len(nodeData) == 0 { + continue + } + + var node RestoreNode + if err := json.Unmarshal(nodeData, &node); err != nil { + r.logger.Warn("Failed to unmarshal node data", "key", infoKey, "error", err) + continue + } + + // Skip nodes with zero/uninitialized heartbeat + if node.LastHeartbeat.IsZero() { + continue + } + + if node.LastHeartbeat.Before(threshold) { + continue + } + + statsKey := nodeIDToStatsKey[infoKey] + tasksData := keyDataMap[statsKey] + count, err := r.parseIntFromBytes(tasksData) + if err != nil { + r.logger.Warn("Failed to parse active restores count", "key", statsKey, "error", err) + continue + } + + stat := RestoreNodeStats{ + ID: node.ID, + ActiveRestores: int(count), + } + stats = append(stats, stat) + } + + return stats, nil +} + +func (r *RestoreNodesRegistry) IncrementRestoresInProgress(nodeID uuid.UUID) error { + ctx, cancel := context.WithTimeout(context.Background(), r.timeout) + defer cancel() + + key := fmt.Sprintf( + 
"%s%s%s", + nodeActiveRestoresPrefix, + nodeID.String(), + nodeActiveRestoresSuffix, + ) + result := r.client.Do(ctx, r.client.B().Incr().Key(key).Build()) + + if result.Error() != nil { + return fmt.Errorf( + "failed to increment restores in progress for node %s: %w", + nodeID, + result.Error(), + ) + } + + return nil +} + +func (r *RestoreNodesRegistry) DecrementRestoresInProgress(nodeID uuid.UUID) error { + ctx, cancel := context.WithTimeout(context.Background(), r.timeout) + defer cancel() + + key := fmt.Sprintf( + "%s%s%s", + nodeActiveRestoresPrefix, + nodeID.String(), + nodeActiveRestoresSuffix, + ) + result := r.client.Do(ctx, r.client.B().Decr().Key(key).Build()) + + if result.Error() != nil { + return fmt.Errorf( + "failed to decrement restores in progress for node %s: %w", + nodeID, + result.Error(), + ) + } + + newValue, err := result.AsInt64() + if err != nil { + return fmt.Errorf("failed to parse decremented value for node %s: %w", nodeID, err) + } + + if newValue < 0 { + setCtx, setCancel := context.WithTimeout(context.Background(), r.timeout) + r.client.Do(setCtx, r.client.B().Set().Key(key).Value("0").Build()) + setCancel() + r.logger.Warn("Active restores counter went below 0, reset to 0", "nodeID", nodeID) + } + + return nil +} + +func (r *RestoreNodesRegistry) HearthbeatNodeInRegistry( + now time.Time, + restoreNode RestoreNode, +) error { + if now.IsZero() { + return fmt.Errorf("cannot register node with zero heartbeat timestamp") + } + + ctx, cancel := context.WithTimeout(context.Background(), r.timeout) + defer cancel() + + restoreNode.LastHeartbeat = now + + data, err := json.Marshal(restoreNode) + if err != nil { + return fmt.Errorf("failed to marshal restore node: %w", err) + } + + key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, restoreNode.ID.String(), nodeInfoKeySuffix) + result := r.client.Do( + ctx, + r.client.B().Set().Key(key).Value(string(data)).Build(), + ) + + if result.Error() != nil { + return fmt.Errorf("failed to register node 
%s: %w", restoreNode.ID, result.Error()) + } + + return nil +} + +func (r *RestoreNodesRegistry) UnregisterNodeFromRegistry(restoreNode RestoreNode) error { + ctx, cancel := context.WithTimeout(context.Background(), r.timeout) + defer cancel() + + infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, restoreNode.ID.String(), nodeInfoKeySuffix) + counterKey := fmt.Sprintf( + "%s%s%s", + nodeActiveRestoresPrefix, + restoreNode.ID.String(), + nodeActiveRestoresSuffix, + ) + + result := r.client.Do( + ctx, + r.client.B().Del().Key(infoKey, counterKey).Build(), + ) + + if result.Error() != nil { + return fmt.Errorf("failed to unregister node %s: %w", restoreNode.ID, result.Error()) + } + + r.logger.Info("Unregistered node from registry", "nodeID", restoreNode.ID) + return nil +} + +func (r *RestoreNodesRegistry) AssignRestoreToNode( + targetNodeID uuid.UUID, + restoreID uuid.UUID, + isCallNotifier bool, +) error { + ctx := context.Background() + + message := RestoreSubmitMessage{ + NodeID: targetNodeID, + RestoreID: restoreID, + IsCallNotifier: isCallNotifier, + } + + messageJSON, err := json.Marshal(message) + if err != nil { + return fmt.Errorf("failed to marshal restore submit message: %w", err) + } + + err = r.pubsubRestores.Publish(ctx, restoreSubmitChannel, string(messageJSON)) + if err != nil { + return fmt.Errorf("failed to publish restore submit message: %w", err) + } + + return nil +} + +func (r *RestoreNodesRegistry) SubscribeNodeForRestoresAssignment( + nodeID uuid.UUID, + handler func(restoreID uuid.UUID, isCallNotifier bool), +) error { + ctx := context.Background() + + wrappedHandler := func(message string) { + var msg RestoreSubmitMessage + if err := json.Unmarshal([]byte(message), &msg); err != nil { + r.logger.Warn("Failed to unmarshal restore submit message", "error", err) + return + } + + if msg.NodeID != nodeID { + return + } + + handler(msg.RestoreID, msg.IsCallNotifier) + } + + err := r.pubsubRestores.Subscribe(ctx, restoreSubmitChannel, 
wrappedHandler) + if err != nil { + return fmt.Errorf("failed to subscribe to restore submit channel: %w", err) + } + + r.logger.Info("Subscribed to restore submit channel", "nodeID", nodeID) + return nil +} + +func (r *RestoreNodesRegistry) UnsubscribeNodeForRestoresAssignments() error { + err := r.pubsubRestores.Close() + if err != nil { + return fmt.Errorf("failed to unsubscribe from restore submit channel: %w", err) + } + + r.logger.Info("Unsubscribed from restore submit channel") + return nil +} + +func (r *RestoreNodesRegistry) PublishRestoreCompletion( + nodeID uuid.UUID, + restoreID uuid.UUID, +) error { + ctx := context.Background() + + message := RestoreCompletionMessage{ + NodeID: nodeID, + RestoreID: restoreID, + } + + messageJSON, err := json.Marshal(message) + if err != nil { + return fmt.Errorf("failed to marshal restore completion message: %w", err) + } + + err = r.pubsubCompletions.Publish(ctx, restoreCompletionChannel, string(messageJSON)) + if err != nil { + return fmt.Errorf("failed to publish restore completion message: %w", err) + } + + return nil +} + +func (r *RestoreNodesRegistry) SubscribeForRestoresCompletions( + handler func(nodeID uuid.UUID, restoreID uuid.UUID), +) error { + ctx := context.Background() + + wrappedHandler := func(message string) { + var msg RestoreCompletionMessage + if err := json.Unmarshal([]byte(message), &msg); err != nil { + r.logger.Warn("Failed to unmarshal restore completion message", "error", err) + return + } + + handler(msg.NodeID, msg.RestoreID) + } + + err := r.pubsubCompletions.Subscribe(ctx, restoreCompletionChannel, wrappedHandler) + if err != nil { + return fmt.Errorf("failed to subscribe to restore completion channel: %w", err) + } + + r.logger.Info("Subscribed to restore completion channel") + return nil +} + +func (r *RestoreNodesRegistry) UnsubscribeForRestoresCompletions() error { + err := r.pubsubCompletions.Close() + if err != nil { + return fmt.Errorf("failed to unsubscribe from restore 
completion channel: %w", err) + } + + r.logger.Info("Unsubscribed from restore completion channel") + return nil +} + +func (r *RestoreNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uuid.UUID { + nodeIDStr := strings.TrimPrefix(key, prefix) + nodeIDStr = strings.TrimSuffix(nodeIDStr, suffix) + + nodeID, err := uuid.Parse(nodeIDStr) + if err != nil { + r.logger.Warn("Failed to parse node ID from key", "key", key, "error", err) + return uuid.Nil + } + + return nodeID +} + +func (r *RestoreNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, error) { + if len(keys) == 0 { + return make(map[string][]byte), nil + } + + ctx, cancel := context.WithTimeout(context.Background(), r.timeout) + defer cancel() + + commands := make([]valkey.Completed, 0, len(keys)) + for _, key := range keys { + commands = append(commands, r.client.B().Get().Key(key).Build()) + } + + results := r.client.DoMulti(ctx, commands...) + + keyDataMap := make(map[string][]byte, len(keys)) + for i, result := range results { + if result.Error() != nil { + r.logger.Warn("Failed to get key in pipeline", "key", keys[i], "error", result.Error()) + continue + } + + data, err := result.AsBytes() + if err != nil { + r.logger.Warn("Failed to parse key data in pipeline", "key", keys[i], "error", err) + continue + } + + keyDataMap[keys[i]] = data + } + + return keyDataMap, nil +} + +func (r *RestoreNodesRegistry) parseIntFromBytes(data []byte) (int64, error) { + str := string(data) + var count int64 + _, err := fmt.Sscanf(str, "%d", &count) + if err != nil { + return 0, fmt.Errorf("failed to parse integer from bytes: %w", err) + } + return count, nil +} + +func (r *RestoreNodesRegistry) cleanupDeadNodes() error { + ctx, cancel := context.WithTimeout(context.Background(), r.timeout) + defer cancel() + + var allKeys []string + cursor := uint64(0) + pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix + + for { + result := r.client.Do( + ctx, + 
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(), + ) + + if result.Error() != nil { + return fmt.Errorf("failed to scan node keys: %w", result.Error()) + } + + scanResult, err := result.AsScanEntry() + if err != nil { + return fmt.Errorf("failed to parse scan result: %w", err) + } + + allKeys = append(allKeys, scanResult.Elements...) + + cursor = scanResult.Cursor + if cursor == 0 { + break + } + } + + if len(allKeys) == 0 { + return nil + } + + keyDataMap, err := r.pipelineGetKeys(allKeys) + if err != nil { + return fmt.Errorf("failed to pipeline get node keys: %w", err) + } + + threshold := time.Now().UTC().Add(-deadNodeThreshold) + var deadNodeKeys []string + + for key, data := range keyDataMap { + // Skip if the key doesn't exist (data is empty) + if len(data) == 0 { + continue + } + + var node RestoreNode + if err := json.Unmarshal(data, &node); err != nil { + r.logger.Warn("Failed to unmarshal node data during cleanup", "key", key, "error", err) + continue + } + + // Skip nodes with zero/uninitialized heartbeat + if node.LastHeartbeat.IsZero() { + continue + } + + if node.LastHeartbeat.Before(threshold) { + nodeID := node.ID.String() + infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeID, nodeInfoKeySuffix) + statsKey := fmt.Sprintf( + "%s%s%s", + nodeActiveRestoresPrefix, + nodeID, + nodeActiveRestoresSuffix, + ) + + deadNodeKeys = append(deadNodeKeys, infoKey, statsKey) + r.logger.Info( + "Marking node for cleanup", + "nodeID", nodeID, + "lastHeartbeat", node.LastHeartbeat, + "threshold", threshold, + ) + } + } + + if len(deadNodeKeys) == 0 { + return nil + } + + delCtx, delCancel := context.WithTimeout(context.Background(), r.timeout) + defer delCancel() + + result := r.client.Do( + delCtx, + r.client.B().Del().Key(deadNodeKeys...).Build(), + ) + + if result.Error() != nil { + return fmt.Errorf("failed to delete dead node keys: %w", result.Error()) + } + + deletedCount, err := result.AsInt64() + if err != nil { + return 
fmt.Errorf("failed to parse deleted count: %w", err) + } + + r.logger.Info("Cleaned up dead nodes", "deletedKeysCount", deletedCount) + return nil +} diff --git a/backend/internal/features/tasks/registry/registry_test.go b/backend/internal/features/restores/restoring/registry_test.go similarity index 56% rename from backend/internal/features/tasks/registry/registry_test.go rename to backend/internal/features/restores/restoring/registry_test.go index 026bd9c..fd6169c 100644 --- a/backend/internal/features/tasks/registry/registry_test.go +++ b/backend/internal/features/restores/restoring/registry_test.go @@ -1,4 +1,4 @@ -package task_registry +package restoring import ( "context" @@ -17,9 +17,8 @@ import ( func Test_HearthbeatNodeInRegistry_RegistersNodeWithTTL(t *testing.T) { cache_utils.ClearAllCache() registry := createTestRegistry() - node := createTestTaskNode() + node := createTestRestoreNode() defer cleanupTestNode(registry, node) - defer cache_utils.ClearAllCache() err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node) assert.NoError(t, err) @@ -33,14 +32,13 @@ func Test_HearthbeatNodeInRegistry_RegistersNodeWithTTL(t *testing.T) { func Test_UnregisterNodeFromRegistry_RemovesNodeAndCounter(t *testing.T) { cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() registry := createTestRegistry() - node := createTestTaskNode() + node := createTestRestoreNode() err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node) assert.NoError(t, err) - err = registry.IncrementTasksInProgress(node.ID.String()) + err = registry.IncrementRestoresInProgress(node.ID) assert.NoError(t, err) err = registry.UnregisterNodeFromRegistry(node) @@ -50,18 +48,17 @@ func Test_UnregisterNodeFromRegistry_RemovesNodeAndCounter(t *testing.T) { assert.NoError(t, err) assert.Empty(t, nodes) - stats, err := registry.GetNodesStats() + stats, err := registry.GetRestoreNodesStats() assert.NoError(t, err) assert.Empty(t, stats) } func 
Test_GetAvailableNodes_ReturnsAllRegisteredNodes(t *testing.T) { cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() registry := createTestRegistry() - node1 := createTestTaskNode() - node2 := createTestTaskNode() - node3 := createTestTaskNode() + node1 := createTestRestoreNode() + node2 := createTestRestoreNode() + node3 := createTestRestoreNode() defer cleanupTestNode(registry, node1) defer cleanupTestNode(registry, node2) defer cleanupTestNode(registry, node3) @@ -88,7 +85,6 @@ func Test_GetAvailableNodes_ReturnsAllRegisteredNodes(t *testing.T) { func Test_GetAvailableNodes_WhenNoNodesExist_ReturnsEmptySlice(t *testing.T) { cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() registry := createTestRegistry() nodes, err := registry.GetAvailableNodes() @@ -97,96 +93,92 @@ func Test_GetAvailableNodes_WhenNoNodesExist_ReturnsEmptySlice(t *testing.T) { assert.Empty(t, nodes) } -func Test_IncrementTasksInProgress_IncrementsCounter(t *testing.T) { +func Test_IncrementRestoresInProgress_IncrementsCounter(t *testing.T) { cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() registry := createTestRegistry() - node := createTestTaskNode() + node := createTestRestoreNode() defer cleanupTestNode(registry, node) err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node) assert.NoError(t, err) - err = registry.IncrementTasksInProgress(node.ID.String()) + err = registry.IncrementRestoresInProgress(node.ID) assert.NoError(t, err) - stats, err := registry.GetNodesStats() + stats, err := registry.GetRestoreNodesStats() assert.NoError(t, err) assert.Len(t, stats, 1) assert.Equal(t, node.ID, stats[0].ID) - assert.Equal(t, 1, stats[0].ActiveTasks) + assert.Equal(t, 1, stats[0].ActiveRestores) - err = registry.IncrementTasksInProgress(node.ID.String()) + err = registry.IncrementRestoresInProgress(node.ID) assert.NoError(t, err) - stats, err = registry.GetNodesStats() + stats, err = registry.GetRestoreNodesStats() assert.NoError(t, err) 
assert.Len(t, stats, 1) - assert.Equal(t, 2, stats[0].ActiveTasks) + assert.Equal(t, 2, stats[0].ActiveRestores) } -func Test_DecrementTasksInProgress_DecrementsCounter(t *testing.T) { +func Test_DecrementRestoresInProgress_DecrementsCounter(t *testing.T) { cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() registry := createTestRegistry() - node := createTestTaskNode() + node := createTestRestoreNode() defer cleanupTestNode(registry, node) err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node) assert.NoError(t, err) - err = registry.IncrementTasksInProgress(node.ID.String()) + err = registry.IncrementRestoresInProgress(node.ID) assert.NoError(t, err) - err = registry.IncrementTasksInProgress(node.ID.String()) + err = registry.IncrementRestoresInProgress(node.ID) assert.NoError(t, err) - err = registry.IncrementTasksInProgress(node.ID.String()) + err = registry.IncrementRestoresInProgress(node.ID) assert.NoError(t, err) - stats, err := registry.GetNodesStats() + stats, err := registry.GetRestoreNodesStats() assert.NoError(t, err) - assert.Equal(t, 3, stats[0].ActiveTasks) + assert.Equal(t, 3, stats[0].ActiveRestores) - err = registry.DecrementTasksInProgress(node.ID.String()) + err = registry.DecrementRestoresInProgress(node.ID) assert.NoError(t, err) - stats, err = registry.GetNodesStats() + stats, err = registry.GetRestoreNodesStats() assert.NoError(t, err) - assert.Equal(t, 2, stats[0].ActiveTasks) + assert.Equal(t, 2, stats[0].ActiveRestores) - err = registry.DecrementTasksInProgress(node.ID.String()) + err = registry.DecrementRestoresInProgress(node.ID) assert.NoError(t, err) - stats, err = registry.GetNodesStats() + stats, err = registry.GetRestoreNodesStats() assert.NoError(t, err) - assert.Equal(t, 1, stats[0].ActiveTasks) + assert.Equal(t, 1, stats[0].ActiveRestores) } -func Test_DecrementTasksInProgress_WhenNegative_ResetsToZero(t *testing.T) { +func Test_DecrementRestoresInProgress_WhenNegative_ResetsToZero(t *testing.T) { 
cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() registry := createTestRegistry() - node := createTestTaskNode() + node := createTestRestoreNode() defer cleanupTestNode(registry, node) err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node) assert.NoError(t, err) - err = registry.DecrementTasksInProgress(node.ID.String()) + err = registry.DecrementRestoresInProgress(node.ID) assert.NoError(t, err) - stats, err := registry.GetNodesStats() + stats, err := registry.GetRestoreNodesStats() assert.NoError(t, err) assert.Len(t, stats, 1) - assert.Equal(t, 0, stats[0].ActiveTasks) + assert.Equal(t, 0, stats[0].ActiveRestores) } -func Test_GetTaskNodesStats_ReturnsStatsForAllNodes(t *testing.T) { +func Test_GetRestoreNodesStats_ReturnsStatsForAllNodes(t *testing.T) { cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() registry := createTestRegistry() - node1 := createTestTaskNode() - node2 := createTestTaskNode() - node3 := createTestTaskNode() + node1 := createTestRestoreNode() + node2 := createTestRestoreNode() + node3 := createTestRestoreNode() defer cleanupTestNode(registry, node1) defer cleanupTestNode(registry, node2) defer cleanupTestNode(registry, node3) @@ -198,28 +190,28 @@ func Test_GetTaskNodesStats_ReturnsStatsForAllNodes(t *testing.T) { err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3) assert.NoError(t, err) - err = registry.IncrementTasksInProgress(node1.ID.String()) + err = registry.IncrementRestoresInProgress(node1.ID) assert.NoError(t, err) - err = registry.IncrementTasksInProgress(node2.ID.String()) + err = registry.IncrementRestoresInProgress(node2.ID) assert.NoError(t, err) - err = registry.IncrementTasksInProgress(node2.ID.String()) + err = registry.IncrementRestoresInProgress(node2.ID) assert.NoError(t, err) - err = registry.IncrementTasksInProgress(node3.ID.String()) + err = registry.IncrementRestoresInProgress(node3.ID) assert.NoError(t, err) - err = 
registry.IncrementTasksInProgress(node3.ID.String()) + err = registry.IncrementRestoresInProgress(node3.ID) assert.NoError(t, err) - err = registry.IncrementTasksInProgress(node3.ID.String()) + err = registry.IncrementRestoresInProgress(node3.ID) assert.NoError(t, err) - stats, err := registry.GetNodesStats() + stats, err := registry.GetRestoreNodesStats() assert.NoError(t, err) assert.Len(t, stats, 3) statsMap := make(map[uuid.UUID]int) for _, stat := range stats { - statsMap[stat.ID] = stat.ActiveTasks + statsMap[stat.ID] = stat.ActiveRestores } assert.Equal(t, 1, statsMap[node1.ID]) @@ -227,12 +219,11 @@ func Test_GetTaskNodesStats_ReturnsStatsForAllNodes(t *testing.T) { assert.Equal(t, 3, statsMap[node3.ID]) } -func Test_GetTaskNodesStats_WhenNoStats_ReturnsEmptySlice(t *testing.T) { +func Test_GetRestoreNodesStats_WhenNoStats_ReturnsEmptySlice(t *testing.T) { cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() registry := createTestRegistry() - stats, err := registry.GetNodesStats() + stats, err := registry.GetRestoreNodesStats() assert.NoError(t, err) assert.NotNil(t, stats) assert.Empty(t, stats) @@ -240,13 +231,12 @@ func Test_GetTaskNodesStats_WhenNoStats_ReturnsEmptySlice(t *testing.T) { func Test_MultipleNodes_RegisteredAndQueriedCorrectly(t *testing.T) { cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() registry := createTestRegistry() - node1 := createTestTaskNode() + node1 := createTestRestoreNode() node1.ThroughputMBs = 50 - node2 := createTestTaskNode() + node2 := createTestRestoreNode() node2.ThroughputMBs = 100 - node3 := createTestTaskNode() + node3 := createTestRestoreNode() node3.ThroughputMBs = 150 defer cleanupTestNode(registry, node1) defer cleanupTestNode(registry, node2) @@ -263,7 +253,7 @@ func Test_MultipleNodes_RegisteredAndQueriedCorrectly(t *testing.T) { assert.NoError(t, err) assert.Len(t, nodes, 3) - nodeMap := make(map[uuid.UUID]TaskNode) + nodeMap := make(map[uuid.UUID]RestoreNode) for _, node := range 
nodes { nodeMap[node.ID] = node } @@ -273,12 +263,11 @@ func Test_MultipleNodes_RegisteredAndQueriedCorrectly(t *testing.T) { assert.Equal(t, 150, nodeMap[node3.ID].ThroughputMBs) } -func Test_TaskCounters_TrackedSeparatelyPerNode(t *testing.T) { +func Test_RestoreCounters_TrackedSeparatelyPerNode(t *testing.T) { cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() registry := createTestRegistry() - node1 := createTestTaskNode() - node2 := createTestTaskNode() + node1 := createTestRestoreNode() + node2 := createTestRestoreNode() defer cleanupTestNode(registry, node1) defer cleanupTestNode(registry, node2) @@ -287,35 +276,35 @@ func Test_TaskCounters_TrackedSeparatelyPerNode(t *testing.T) { err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2) assert.NoError(t, err) - err = registry.IncrementTasksInProgress(node1.ID.String()) + err = registry.IncrementRestoresInProgress(node1.ID) assert.NoError(t, err) - err = registry.IncrementTasksInProgress(node1.ID.String()) + err = registry.IncrementRestoresInProgress(node1.ID) assert.NoError(t, err) - err = registry.IncrementTasksInProgress(node2.ID.String()) + err = registry.IncrementRestoresInProgress(node2.ID) assert.NoError(t, err) - stats, err := registry.GetNodesStats() + stats, err := registry.GetRestoreNodesStats() assert.NoError(t, err) assert.Len(t, stats, 2) statsMap := make(map[uuid.UUID]int) for _, stat := range stats { - statsMap[stat.ID] = stat.ActiveTasks + statsMap[stat.ID] = stat.ActiveRestores } assert.Equal(t, 2, statsMap[node1.ID]) assert.Equal(t, 1, statsMap[node2.ID]) - err = registry.DecrementTasksInProgress(node1.ID.String()) + err = registry.DecrementRestoresInProgress(node1.ID) assert.NoError(t, err) - stats, err = registry.GetNodesStats() + stats, err = registry.GetRestoreNodesStats() assert.NoError(t, err) statsMap = make(map[uuid.UUID]int) for _, stat := range stats { - statsMap[stat.ID] = stat.ActiveTasks + statsMap[stat.ID] = stat.ActiveRestores } assert.Equal(t, 1, 
statsMap[node1.ID]) @@ -324,9 +313,8 @@ func Test_TaskCounters_TrackedSeparatelyPerNode(t *testing.T) { func Test_GetAvailableNodes_SkipsInvalidJsonData(t *testing.T) { cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() registry := createTestRegistry() - node := createTestTaskNode() + node := createTestRestoreNode() defer cleanupTestNode(registry, node) err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node) @@ -354,7 +342,6 @@ func Test_GetAvailableNodes_SkipsInvalidJsonData(t *testing.T) { func Test_PipelineGetKeys_HandlesEmptyKeysList(t *testing.T) { cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() registry := createTestRegistry() keyDataMap, err := registry.pipelineGetKeys([]string{}) @@ -365,15 +352,13 @@ func Test_PipelineGetKeys_HandlesEmptyKeysList(t *testing.T) { func Test_HearthbeatNodeInRegistry_UpdatesLastHeartbeat(t *testing.T) { cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() registry := createTestRegistry() - node := createTestTaskNode() + node := createTestRestoreNode() originalHeartbeat := node.LastHeartbeat defer cleanupTestNode(registry, node) time.Sleep(10 * time.Millisecond) - node.LastHeartbeat = time.Now().UTC() err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node) assert.NoError(t, err) @@ -385,9 +370,8 @@ func Test_HearthbeatNodeInRegistry_UpdatesLastHeartbeat(t *testing.T) { func Test_HearthbeatNodeInRegistry_RejectsZeroTimestamp(t *testing.T) { cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() registry := createTestRegistry() - node := createTestTaskNode() + node := createTestRestoreNode() err := registry.HearthbeatNodeInRegistry(time.Time{}, node) assert.Error(t, err) @@ -398,571 +382,12 @@ func Test_HearthbeatNodeInRegistry_RejectsZeroTimestamp(t *testing.T) { assert.Len(t, nodes, 0) } -func createTestRegistry() *TaskNodesRegistry { - return &TaskNodesRegistry{ - cache_utils.GetValkeyClient(), - logger.GetLogger(), - cache_utils.DefaultCacheTimeout, - 
cache_utils.NewPubSubManager(), - cache_utils.NewPubSubManager(), - } -} - -func createTestTaskNode() TaskNode { - return TaskNode{ - ID: uuid.New(), - ThroughputMBs: 100, - LastHeartbeat: time.Now().UTC(), - } -} - -func cleanupTestNode(registry *TaskNodesRegistry, node TaskNode) { - registry.UnregisterNodeFromRegistry(node) -} - -func Test_AssignTaskToNode_PublishesJsonMessageToChannel(t *testing.T) { - cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() - registry := createTestRegistry() - node := createTestTaskNode() - taskID := uuid.New() - - err := registry.AssignTaskToNode(node.ID.String(), taskID, true) - assert.NoError(t, err) -} - -func Test_SubscribeNodeForTasksAssignment_ReceivesSubmittedTasksForMatchingNode(t *testing.T) { - cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() - registry := createTestRegistry() - node := createTestTaskNode() - taskID := uuid.New() - defer registry.UnsubscribeNodeForTasksAssignments() - - receivedTaskID := make(chan uuid.UUID, 1) - handler := func(id uuid.UUID, isCallNotifier bool) { - receivedTaskID <- id - } - - err := registry.SubscribeNodeForTasksAssignment(node.ID.String(), handler) - assert.NoError(t, err) - - time.Sleep(100 * time.Millisecond) - - err = registry.AssignTaskToNode(node.ID.String(), taskID, true) - assert.NoError(t, err) - - select { - case received := <-receivedTaskID: - assert.Equal(t, taskID, received) - case <-time.After(2 * time.Second): - t.Fatal("Timeout waiting for task message") - } -} - -func Test_SubscribeNodeForTasksAssignment_FiltersOutTasksForDifferentNode(t *testing.T) { - cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() - registry := createTestRegistry() - node1 := createTestTaskNode() - node2 := createTestTaskNode() - taskID := uuid.New() - defer registry.UnsubscribeNodeForTasksAssignments() - - receivedTaskID := make(chan uuid.UUID, 1) - handler := func(id uuid.UUID, isCallNotifier bool) { - receivedTaskID <- id - } - - err := 
registry.SubscribeNodeForTasksAssignment(node1.ID.String(), handler) - assert.NoError(t, err) - - time.Sleep(100 * time.Millisecond) - - err = registry.AssignTaskToNode(node2.ID.String(), taskID, false) - assert.NoError(t, err) - - select { - case <-receivedTaskID: - t.Fatal("Should not receive task for different node") - case <-time.After(500 * time.Millisecond): - } -} - -func Test_SubscribeNodeForTasksAssignment_ParsesJsonAndTaskIdCorrectly(t *testing.T) { - cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() - registry := createTestRegistry() - node := createTestTaskNode() - taskID1 := uuid.New() - taskID2 := uuid.New() - defer registry.UnsubscribeNodeForTasksAssignments() - - receivedTasks := make(chan uuid.UUID, 2) - handler := func(id uuid.UUID, isCallNotifier bool) { - receivedTasks <- id - } - - err := registry.SubscribeNodeForTasksAssignment(node.ID.String(), handler) - assert.NoError(t, err) - - time.Sleep(100 * time.Millisecond) - - err = registry.AssignTaskToNode(node.ID.String(), taskID1, true) - assert.NoError(t, err) - - err = registry.AssignTaskToNode(node.ID.String(), taskID2, false) - assert.NoError(t, err) - - received1 := <-receivedTasks - received2 := <-receivedTasks - - receivedIDs := []uuid.UUID{received1, received2} - assert.Contains(t, receivedIDs, taskID1) - assert.Contains(t, receivedIDs, taskID2) -} - -func Test_SubscribeNodeForTasksAssignment_HandlesInvalidJson(t *testing.T) { - cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() - registry := createTestRegistry() - node := createTestTaskNode() - defer registry.UnsubscribeNodeForTasksAssignments() - - receivedTaskID := make(chan uuid.UUID, 1) - handler := func(id uuid.UUID, isCallNotifier bool) { - receivedTaskID <- id - } - - err := registry.SubscribeNodeForTasksAssignment(node.ID.String(), handler) - assert.NoError(t, err) - - time.Sleep(100 * time.Millisecond) - - ctx := context.Background() - err = registry.pubsubTasks.Publish(ctx, "backup:submit", "invalid 
json") - assert.NoError(t, err) - - select { - case <-receivedTaskID: - t.Fatal("Should not receive task for invalid JSON") - case <-time.After(500 * time.Millisecond): - } -} - -func Test_UnsubscribeNodeForTasksAssignments_StopsReceivingMessages(t *testing.T) { - cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() - registry := createTestRegistry() - node := createTestTaskNode() - taskID1 := uuid.New() - taskID2 := uuid.New() - - receivedTaskID := make(chan uuid.UUID, 2) - handler := func(id uuid.UUID, isCallNotifier bool) { - receivedTaskID <- id - } - - err := registry.SubscribeNodeForTasksAssignment(node.ID.String(), handler) - assert.NoError(t, err) - - time.Sleep(100 * time.Millisecond) - - err = registry.AssignTaskToNode(node.ID.String(), taskID1, true) - assert.NoError(t, err) - - received := <-receivedTaskID - assert.Equal(t, taskID1, received) - - err = registry.UnsubscribeNodeForTasksAssignments() - assert.NoError(t, err) - - time.Sleep(100 * time.Millisecond) - - err = registry.AssignTaskToNode(node.ID.String(), taskID2, false) - assert.NoError(t, err) - - select { - case <-receivedTaskID: - t.Fatal("Should not receive task after unsubscribe") - case <-time.After(500 * time.Millisecond): - } -} - -func Test_SubscribeNodeForTasksAssignment_WhenAlreadySubscribed_ReturnsError(t *testing.T) { - cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() - registry := createTestRegistry() - node := createTestTaskNode() - defer registry.UnsubscribeNodeForTasksAssignments() - - handler := func(id uuid.UUID, isCallNotifier bool) {} - - err := registry.SubscribeNodeForTasksAssignment(node.ID.String(), handler) - assert.NoError(t, err) - - err = registry.SubscribeNodeForTasksAssignment(node.ID.String(), handler) - assert.Error(t, err) - assert.Contains(t, err.Error(), "already subscribed") -} - -func Test_MultipleNodes_EachReceivesOnlyTheirTasks(t *testing.T) { - cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() - registry1 := 
createTestRegistry() - registry2 := createTestRegistry() - registry3 := createTestRegistry() - - node1 := createTestTaskNode() - node2 := createTestTaskNode() - node3 := createTestTaskNode() - - taskID1 := uuid.New() - taskID2 := uuid.New() - taskID3 := uuid.New() - - defer registry1.UnsubscribeNodeForTasksAssignments() - defer registry2.UnsubscribeNodeForTasksAssignments() - defer registry3.UnsubscribeNodeForTasksAssignments() - defer cleanupTestNode(registry1, node1) - defer cleanupTestNode(registry1, node2) - defer cleanupTestNode(registry1, node3) - - receivedTasks1 := make(chan uuid.UUID, 3) - receivedTasks2 := make(chan uuid.UUID, 3) - receivedTasks3 := make(chan uuid.UUID, 3) - - handler1 := func(id uuid.UUID, isCallNotifier bool) { receivedTasks1 <- id } - handler2 := func(id uuid.UUID, isCallNotifier bool) { receivedTasks2 <- id } - handler3 := func(id uuid.UUID, isCallNotifier bool) { receivedTasks3 <- id } - - err := registry1.SubscribeNodeForTasksAssignment(node1.ID.String(), handler1) - assert.NoError(t, err) - - err = registry2.SubscribeNodeForTasksAssignment(node2.ID.String(), handler2) - assert.NoError(t, err) - - err = registry3.SubscribeNodeForTasksAssignment(node3.ID.String(), handler3) - assert.NoError(t, err) - - time.Sleep(100 * time.Millisecond) - - submitRegistry := createTestRegistry() - err = submitRegistry.AssignTaskToNode(node1.ID.String(), taskID1, true) - assert.NoError(t, err) - - err = submitRegistry.AssignTaskToNode(node2.ID.String(), taskID2, false) - assert.NoError(t, err) - - err = submitRegistry.AssignTaskToNode(node3.ID.String(), taskID3, true) - assert.NoError(t, err) - - select { - case received := <-receivedTasks1: - assert.Equal(t, taskID1, received) - case <-time.After(2 * time.Second): - t.Fatal("Node 1 timeout waiting for task message") - } - - select { - case received := <-receivedTasks2: - assert.Equal(t, taskID2, received) - case <-time.After(2 * time.Second): - t.Fatal("Node 2 timeout waiting for task message") - } - 
- select { - case received := <-receivedTasks3: - assert.Equal(t, taskID3, received) - case <-time.After(2 * time.Second): - t.Fatal("Node 3 timeout waiting for task message") - } - - select { - case <-receivedTasks1: - t.Fatal("Node 1 should not receive additional tasks") - case <-time.After(300 * time.Millisecond): - } - - select { - case <-receivedTasks2: - t.Fatal("Node 2 should not receive additional tasks") - case <-time.After(300 * time.Millisecond): - } - - select { - case <-receivedTasks3: - t.Fatal("Node 3 should not receive additional tasks") - case <-time.After(300 * time.Millisecond): - } -} - -func Test_PublishTaskCompletion_PublishesMessageToChannel(t *testing.T) { - cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() - registry := createTestRegistry() - node := createTestTaskNode() - taskID := uuid.New() - - err := registry.PublishTaskCompletion(node.ID.String(), taskID) - assert.NoError(t, err) -} - -func Test_SubscribeForTasksCompletions_ReceivesCompletedTasks(t *testing.T) { - cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() - registry := createTestRegistry() - node := createTestTaskNode() - taskID := uuid.New() - defer registry.UnsubscribeForTasksCompletions() - - receivedTaskID := make(chan uuid.UUID, 1) - receivedNodeID := make(chan string, 1) - handler := func(nodeID string, taskID uuid.UUID) { - receivedNodeID <- nodeID - receivedTaskID <- taskID - } - - err := registry.SubscribeForTasksCompletions(handler) - assert.NoError(t, err) - - time.Sleep(100 * time.Millisecond) - - err = registry.PublishTaskCompletion(node.ID.String(), taskID) - assert.NoError(t, err) - - select { - case receivedNode := <-receivedNodeID: - assert.Equal(t, node.ID.String(), receivedNode) - case <-time.After(2 * time.Second): - t.Fatal("Timeout waiting for node ID") - } - - select { - case received := <-receivedTaskID: - assert.Equal(t, taskID, received) - case <-time.After(2 * time.Second): - t.Fatal("Timeout waiting for task completion 
message") - } -} - -func Test_SubscribeForTasksCompletions_ParsesJsonCorrectly(t *testing.T) { - cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() - registry := createTestRegistry() - node := createTestTaskNode() - taskID1 := uuid.New() - taskID2 := uuid.New() - defer registry.UnsubscribeForTasksCompletions() - - receivedTasks := make(chan uuid.UUID, 2) - handler := func(nodeID string, taskID uuid.UUID) { - receivedTasks <- taskID - } - - err := registry.SubscribeForTasksCompletions(handler) - assert.NoError(t, err) - - time.Sleep(100 * time.Millisecond) - - err = registry.PublishTaskCompletion(node.ID.String(), taskID1) - assert.NoError(t, err) - - err = registry.PublishTaskCompletion(node.ID.String(), taskID2) - assert.NoError(t, err) - - received1 := <-receivedTasks - received2 := <-receivedTasks - - receivedIDs := []uuid.UUID{received1, received2} - assert.Contains(t, receivedIDs, taskID1) - assert.Contains(t, receivedIDs, taskID2) -} - -func Test_SubscribeForTasksCompletions_HandlesInvalidJson(t *testing.T) { - cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() - registry := createTestRegistry() - defer registry.UnsubscribeForTasksCompletions() - - receivedTaskID := make(chan uuid.UUID, 1) - handler := func(nodeID string, taskID uuid.UUID) { - receivedTaskID <- taskID - } - - err := registry.SubscribeForTasksCompletions(handler) - assert.NoError(t, err) - - time.Sleep(100 * time.Millisecond) - - ctx := context.Background() - err = registry.pubsubCompletions.Publish(ctx, "backup:completion", "invalid json") - assert.NoError(t, err) - - select { - case <-receivedTaskID: - t.Fatal("Should not receive task for invalid JSON") - case <-time.After(500 * time.Millisecond): - } -} - -func Test_UnsubscribeForTasksCompletions_StopsReceivingMessages(t *testing.T) { - cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() - registry := createTestRegistry() - node := createTestTaskNode() - taskID1 := uuid.New() - taskID2 := uuid.New() - - 
receivedTaskID := make(chan uuid.UUID, 2) - handler := func(nodeID string, taskID uuid.UUID) { - receivedTaskID <- taskID - } - - err := registry.SubscribeForTasksCompletions(handler) - assert.NoError(t, err) - - time.Sleep(100 * time.Millisecond) - - err = registry.PublishTaskCompletion(node.ID.String(), taskID1) - assert.NoError(t, err) - - received := <-receivedTaskID - assert.Equal(t, taskID1, received) - - err = registry.UnsubscribeForTasksCompletions() - assert.NoError(t, err) - - time.Sleep(100 * time.Millisecond) - - err = registry.PublishTaskCompletion(node.ID.String(), taskID2) - assert.NoError(t, err) - - select { - case <-receivedTaskID: - t.Fatal("Should not receive task after unsubscribe") - case <-time.After(500 * time.Millisecond): - } -} - -func Test_SubscribeForTasksCompletions_WhenAlreadySubscribed_ReturnsError(t *testing.T) { - cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() - registry := createTestRegistry() - defer registry.UnsubscribeForTasksCompletions() - - handler := func(nodeID string, taskID uuid.UUID) {} - - err := registry.SubscribeForTasksCompletions(handler) - assert.NoError(t, err) - - err = registry.SubscribeForTasksCompletions(handler) - assert.Error(t, err) - assert.Contains(t, err.Error(), "already subscribed") -} - -func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) { - cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() - registry1 := createTestRegistry() - registry2 := createTestRegistry() - registry3 := createTestRegistry() - - node1 := createTestTaskNode() - node2 := createTestTaskNode() - node3 := createTestTaskNode() - - taskID1 := uuid.New() - taskID2 := uuid.New() - taskID3 := uuid.New() - - defer registry1.UnsubscribeForTasksCompletions() - defer registry2.UnsubscribeForTasksCompletions() - defer registry3.UnsubscribeForTasksCompletions() - defer cleanupTestNode(registry1, node1) - defer cleanupTestNode(registry1, node2) - defer cleanupTestNode(registry1, node3) - - 
receivedTasks1 := make(chan uuid.UUID, 3) - receivedTasks2 := make(chan uuid.UUID, 3) - receivedTasks3 := make(chan uuid.UUID, 3) - - handler1 := func(nodeID string, taskID uuid.UUID) { receivedTasks1 <- taskID } - handler2 := func(nodeID string, taskID uuid.UUID) { receivedTasks2 <- taskID } - handler3 := func(nodeID string, taskID uuid.UUID) { receivedTasks3 <- taskID } - - err := registry1.SubscribeForTasksCompletions(handler1) - assert.NoError(t, err) - - err = registry2.SubscribeForTasksCompletions(handler2) - assert.NoError(t, err) - - err = registry3.SubscribeForTasksCompletions(handler3) - assert.NoError(t, err) - - time.Sleep(100 * time.Millisecond) - - publishRegistry := createTestRegistry() - err = publishRegistry.PublishTaskCompletion(node1.ID.String(), taskID1) - assert.NoError(t, err) - - err = publishRegistry.PublishTaskCompletion(node2.ID.String(), taskID2) - assert.NoError(t, err) - - err = publishRegistry.PublishTaskCompletion(node3.ID.String(), taskID3) - assert.NoError(t, err) - - receivedAll1 := []uuid.UUID{} - receivedAll2 := []uuid.UUID{} - receivedAll3 := []uuid.UUID{} - - for i := 0; i < 3; i++ { - select { - case received := <-receivedTasks1: - receivedAll1 = append(receivedAll1, received) - case <-time.After(2 * time.Second): - t.Fatal("Subscriber 1 timeout waiting for completion message") - } - } - - for i := 0; i < 3; i++ { - select { - case received := <-receivedTasks2: - receivedAll2 = append(receivedAll2, received) - case <-time.After(2 * time.Second): - t.Fatal("Subscriber 2 timeout waiting for completion message") - } - } - - for i := 0; i < 3; i++ { - select { - case received := <-receivedTasks3: - receivedAll3 = append(receivedAll3, received) - case <-time.After(2 * time.Second): - t.Fatal("Subscriber 3 timeout waiting for completion message") - } - } - - assert.Contains(t, receivedAll1, taskID1) - assert.Contains(t, receivedAll1, taskID2) - assert.Contains(t, receivedAll1, taskID3) - - assert.Contains(t, receivedAll2, taskID1) - 
assert.Contains(t, receivedAll2, taskID2) - assert.Contains(t, receivedAll2, taskID3) - - assert.Contains(t, receivedAll3, taskID1) - assert.Contains(t, receivedAll3, taskID2) - assert.Contains(t, receivedAll3, taskID3) -} - func Test_GetAvailableNodes_ExcludesStaleNodesFromCache(t *testing.T) { cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() registry := createTestRegistry() - node1 := createTestTaskNode() - node2 := createTestTaskNode() - node3 := createTestTaskNode() + node1 := createTestRestoreNode() + node2 := createTestRestoreNode() + node3 := createTestRestoreNode() defer cleanupTestNode(registry, node1) defer cleanupTestNode(registry, node2) defer cleanupTestNode(registry, node3) @@ -984,7 +409,7 @@ func Test_GetAvailableNodes_ExcludesStaleNodesFromCache(t *testing.T) { data, err := result.AsBytes() assert.NoError(t, err) - var node TaskNode + var node RestoreNode err = json.Unmarshal(data, &node) assert.NoError(t, err) @@ -1013,13 +438,12 @@ func Test_GetAvailableNodes_ExcludesStaleNodesFromCache(t *testing.T) { assert.True(t, nodeIDs[node3.ID]) } -func Test_GetNodesStats_ExcludesStaleNodesFromCache(t *testing.T) { +func Test_GetRestoreNodesStats_ExcludesStaleNodesFromCache(t *testing.T) { cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() registry := createTestRegistry() - node1 := createTestTaskNode() - node2 := createTestTaskNode() - node3 := createTestTaskNode() + node1 := createTestRestoreNode() + node2 := createTestRestoreNode() + node3 := createTestRestoreNode() defer cleanupTestNode(registry, node1) defer cleanupTestNode(registry, node2) defer cleanupTestNode(registry, node3) @@ -1031,11 +455,11 @@ func Test_GetNodesStats_ExcludesStaleNodesFromCache(t *testing.T) { err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3) assert.NoError(t, err) - err = registry.IncrementTasksInProgress(node1.ID.String()) + err = registry.IncrementRestoresInProgress(node1.ID) assert.NoError(t, err) - err = 
registry.IncrementTasksInProgress(node2.ID.String()) + err = registry.IncrementRestoresInProgress(node2.ID) assert.NoError(t, err) - err = registry.IncrementTasksInProgress(node3.ID.String()) + err = registry.IncrementRestoresInProgress(node3.ID) assert.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), registry.timeout) @@ -1048,7 +472,7 @@ func Test_GetNodesStats_ExcludesStaleNodesFromCache(t *testing.T) { data, err := result.AsBytes() assert.NoError(t, err) - var node TaskNode + var node RestoreNode err = json.Unmarshal(data, &node) assert.NoError(t, err) @@ -1064,13 +488,13 @@ func Test_GetNodesStats_ExcludesStaleNodesFromCache(t *testing.T) { ) assert.NoError(t, setResult.Error()) - stats, err := registry.GetNodesStats() + stats, err := registry.GetRestoreNodesStats() assert.NoError(t, err) assert.Len(t, stats, 2) statsMap := make(map[uuid.UUID]int) for _, stat := range stats { - statsMap[stat.ID] = stat.ActiveTasks + statsMap[stat.ID] = stat.ActiveRestores } assert.Equal(t, 1, statsMap[node1.ID]) @@ -1081,11 +505,10 @@ func Test_GetNodesStats_ExcludesStaleNodesFromCache(t *testing.T) { func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) { cache_utils.ClearAllCache() - defer cache_utils.ClearAllCache() registry := createTestRegistry() - node1 := createTestTaskNode() - node2 := createTestTaskNode() + node1 := createTestRestoreNode() + node2 := createTestRestoreNode() defer cleanupTestNode(registry, node1) defer cleanupTestNode(registry, node2) @@ -1094,9 +517,9 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) { err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2) assert.NoError(t, err) - err = registry.IncrementTasksInProgress(node1.ID.String()) + err = registry.IncrementRestoresInProgress(node1.ID) assert.NoError(t, err) - err = registry.IncrementTasksInProgress(node2.ID.String()) + err = registry.IncrementRestoresInProgress(node2.ID) assert.NoError(t, err) ctx, cancel := 
context.WithTimeout(context.Background(), registry.timeout) @@ -1109,7 +532,7 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) { data, err := result.AsBytes() assert.NoError(t, err) - var node TaskNode + var node RestoreNode err = json.Unmarshal(data, &node) assert.NoError(t, err) @@ -1137,9 +560,9 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) { counterKey := fmt.Sprintf( "%s%s%s", - nodeActiveTasksPrefix, + nodeActiveRestoresPrefix, node2.ID.String(), - nodeActiveTasksSuffix, + nodeActiveRestoresSuffix, ) counterCtx, counterCancel := context.WithTimeout(context.Background(), registry.timeout) defer counterCancel() @@ -1163,8 +586,547 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) { assert.Len(t, nodes, 1) assert.Equal(t, node1.ID, nodes[0].ID) - stats, err := registry.GetNodesStats() + stats, err := registry.GetRestoreNodesStats() assert.NoError(t, err) assert.Len(t, stats, 1) assert.Equal(t, node1.ID, stats[0].ID) } + +func createTestRegistry() *RestoreNodesRegistry { + return &RestoreNodesRegistry{ + cache_utils.GetValkeyClient(), + logger.GetLogger(), + cache_utils.DefaultCacheTimeout, + cache_utils.NewPubSubManager(), + cache_utils.NewPubSubManager(), + } +} + +func createTestRestoreNode() RestoreNode { + return RestoreNode{ + ID: uuid.New(), + ThroughputMBs: 100, + LastHeartbeat: time.Now().UTC(), + } +} + +func cleanupTestNode(registry *RestoreNodesRegistry, node RestoreNode) { + registry.UnregisterNodeFromRegistry(node) +} + +func Test_AssignRestoreToNode_PublishesJsonMessageToChannel(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node := createTestRestoreNode() + restoreID := uuid.New() + + err := registry.AssignRestoreToNode(node.ID, restoreID, true) + assert.NoError(t, err) +} + +func Test_SubscribeNodeForRestoresAssignment_ReceivesSubmittedRestoresForMatchingNode( + t *testing.T, +) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + 
node := createTestRestoreNode() + restoreID := uuid.New() + defer registry.UnsubscribeNodeForRestoresAssignments() + + receivedRestoreID := make(chan uuid.UUID, 1) + handler := func(id uuid.UUID, isCallNotifier bool) { + receivedRestoreID <- id + } + + err := registry.SubscribeNodeForRestoresAssignment(node.ID, handler) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + err = registry.AssignRestoreToNode(node.ID, restoreID, true) + assert.NoError(t, err) + + select { + case received := <-receivedRestoreID: + assert.Equal(t, restoreID, received) + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for restore message") + } +} + +func Test_SubscribeNodeForRestoresAssignment_FiltersOutRestoresForDifferentNode(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node1 := createTestRestoreNode() + node2 := createTestRestoreNode() + restoreID := uuid.New() + defer registry.UnsubscribeNodeForRestoresAssignments() + + receivedRestoreID := make(chan uuid.UUID, 1) + handler := func(id uuid.UUID, isCallNotifier bool) { + receivedRestoreID <- id + } + + err := registry.SubscribeNodeForRestoresAssignment(node1.ID, handler) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + err = registry.AssignRestoreToNode(node2.ID, restoreID, false) + assert.NoError(t, err) + + select { + case <-receivedRestoreID: + t.Fatal("Should not receive restore for different node") + case <-time.After(500 * time.Millisecond): + } +} + +func Test_SubscribeNodeForRestoresAssignment_ParsesJsonAndRestoreIdCorrectly(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node := createTestRestoreNode() + restoreID1 := uuid.New() + restoreID2 := uuid.New() + defer registry.UnsubscribeNodeForRestoresAssignments() + + receivedRestores := make(chan uuid.UUID, 2) + handler := func(id uuid.UUID, isCallNotifier bool) { + receivedRestores <- id + } + + err := registry.SubscribeNodeForRestoresAssignment(node.ID, 
handler) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + err = registry.AssignRestoreToNode(node.ID, restoreID1, true) + assert.NoError(t, err) + + err = registry.AssignRestoreToNode(node.ID, restoreID2, false) + assert.NoError(t, err) + + received1 := <-receivedRestores + received2 := <-receivedRestores + + receivedIDs := []uuid.UUID{received1, received2} + assert.Contains(t, receivedIDs, restoreID1) + assert.Contains(t, receivedIDs, restoreID2) +} + +func Test_SubscribeNodeForRestoresAssignment_HandlesInvalidJson(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node := createTestRestoreNode() + defer registry.UnsubscribeNodeForRestoresAssignments() + + receivedRestoreID := make(chan uuid.UUID, 1) + handler := func(id uuid.UUID, isCallNotifier bool) { + receivedRestoreID <- id + } + + err := registry.SubscribeNodeForRestoresAssignment(node.ID, handler) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + ctx := context.Background() + err = registry.pubsubRestores.Publish(ctx, "restore:submit", "invalid json") + assert.NoError(t, err) + + select { + case <-receivedRestoreID: + t.Fatal("Should not receive restore for invalid JSON") + case <-time.After(500 * time.Millisecond): + } +} + +func Test_UnsubscribeNodeForRestoresAssignments_StopsReceivingMessages(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node := createTestRestoreNode() + restoreID1 := uuid.New() + restoreID2 := uuid.New() + + receivedRestoreID := make(chan uuid.UUID, 2) + handler := func(id uuid.UUID, isCallNotifier bool) { + receivedRestoreID <- id + } + + err := registry.SubscribeNodeForRestoresAssignment(node.ID, handler) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + err = registry.AssignRestoreToNode(node.ID, restoreID1, true) + assert.NoError(t, err) + + received := <-receivedRestoreID + assert.Equal(t, restoreID1, received) + + err = 
registry.UnsubscribeNodeForRestoresAssignments() + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + err = registry.AssignRestoreToNode(node.ID, restoreID2, false) + assert.NoError(t, err) + + select { + case <-receivedRestoreID: + t.Fatal("Should not receive restore after unsubscribe") + case <-time.After(500 * time.Millisecond): + } +} + +func Test_SubscribeNodeForRestoresAssignment_WhenAlreadySubscribed_ReturnsError(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node := createTestRestoreNode() + defer registry.UnsubscribeNodeForRestoresAssignments() + + handler := func(id uuid.UUID, isCallNotifier bool) {} + + err := registry.SubscribeNodeForRestoresAssignment(node.ID, handler) + assert.NoError(t, err) + + err = registry.SubscribeNodeForRestoresAssignment(node.ID, handler) + assert.Error(t, err) + assert.Contains(t, err.Error(), "already subscribed") +} + +func Test_MultipleNodes_EachReceivesOnlyTheirRestores(t *testing.T) { + cache_utils.ClearAllCache() + registry1 := createTestRegistry() + registry2 := createTestRegistry() + registry3 := createTestRegistry() + + node1 := createTestRestoreNode() + node2 := createTestRestoreNode() + node3 := createTestRestoreNode() + + restoreID1 := uuid.New() + restoreID2 := uuid.New() + restoreID3 := uuid.New() + + defer registry1.UnsubscribeNodeForRestoresAssignments() + defer registry2.UnsubscribeNodeForRestoresAssignments() + defer registry3.UnsubscribeNodeForRestoresAssignments() + + receivedRestores1 := make(chan uuid.UUID, 3) + receivedRestores2 := make(chan uuid.UUID, 3) + receivedRestores3 := make(chan uuid.UUID, 3) + + handler1 := func(id uuid.UUID, isCallNotifier bool) { receivedRestores1 <- id } + handler2 := func(id uuid.UUID, isCallNotifier bool) { receivedRestores2 <- id } + handler3 := func(id uuid.UUID, isCallNotifier bool) { receivedRestores3 <- id } + + err := registry1.SubscribeNodeForRestoresAssignment(node1.ID, handler1) + assert.NoError(t, err) + + err = 
registry2.SubscribeNodeForRestoresAssignment(node2.ID, handler2) + assert.NoError(t, err) + + err = registry3.SubscribeNodeForRestoresAssignment(node3.ID, handler3) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + submitRegistry := createTestRegistry() + err = submitRegistry.AssignRestoreToNode(node1.ID, restoreID1, true) + assert.NoError(t, err) + + err = submitRegistry.AssignRestoreToNode(node2.ID, restoreID2, false) + assert.NoError(t, err) + + err = submitRegistry.AssignRestoreToNode(node3.ID, restoreID3, true) + assert.NoError(t, err) + + select { + case received := <-receivedRestores1: + assert.Equal(t, restoreID1, received) + case <-time.After(2 * time.Second): + t.Fatal("Node 1 timeout waiting for restore message") + } + + select { + case received := <-receivedRestores2: + assert.Equal(t, restoreID2, received) + case <-time.After(2 * time.Second): + t.Fatal("Node 2 timeout waiting for restore message") + } + + select { + case received := <-receivedRestores3: + assert.Equal(t, restoreID3, received) + case <-time.After(2 * time.Second): + t.Fatal("Node 3 timeout waiting for restore message") + } + + select { + case <-receivedRestores1: + t.Fatal("Node 1 should not receive additional restores") + case <-time.After(300 * time.Millisecond): + } + + select { + case <-receivedRestores2: + t.Fatal("Node 2 should not receive additional restores") + case <-time.After(300 * time.Millisecond): + } + + select { + case <-receivedRestores3: + t.Fatal("Node 3 should not receive additional restores") + case <-time.After(300 * time.Millisecond): + } +} + +func Test_PublishRestoreCompletion_PublishesMessageToChannel(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node := createTestRestoreNode() + restoreID := uuid.New() + + err := registry.PublishRestoreCompletion(node.ID, restoreID) + assert.NoError(t, err) +} + +func Test_SubscribeForRestoresCompletions_ReceivesCompletedRestores(t *testing.T) { + 
cache_utils.ClearAllCache() + registry := createTestRegistry() + node := createTestRestoreNode() + restoreID := uuid.New() + defer registry.UnsubscribeForRestoresCompletions() + + receivedRestoreID := make(chan uuid.UUID, 1) + receivedNodeID := make(chan uuid.UUID, 1) + handler := func(nodeID uuid.UUID, restoreID uuid.UUID) { + receivedNodeID <- nodeID + receivedRestoreID <- restoreID + } + + err := registry.SubscribeForRestoresCompletions(handler) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + err = registry.PublishRestoreCompletion(node.ID, restoreID) + assert.NoError(t, err) + + select { + case receivedNode := <-receivedNodeID: + assert.Equal(t, node.ID, receivedNode) + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for node ID") + } + + select { + case received := <-receivedRestoreID: + assert.Equal(t, restoreID, received) + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for restore completion message") + } +} + +func Test_SubscribeForRestoresCompletions_ParsesJsonCorrectly(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node := createTestRestoreNode() + restoreID1 := uuid.New() + restoreID2 := uuid.New() + defer registry.UnsubscribeForRestoresCompletions() + + receivedRestores := make(chan uuid.UUID, 2) + handler := func(nodeID uuid.UUID, restoreID uuid.UUID) { + receivedRestores <- restoreID + } + + err := registry.SubscribeForRestoresCompletions(handler) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + err = registry.PublishRestoreCompletion(node.ID, restoreID1) + assert.NoError(t, err) + + err = registry.PublishRestoreCompletion(node.ID, restoreID2) + assert.NoError(t, err) + + received1 := <-receivedRestores + received2 := <-receivedRestores + + receivedIDs := []uuid.UUID{received1, received2} + assert.Contains(t, receivedIDs, restoreID1) + assert.Contains(t, receivedIDs, restoreID2) +} + +func Test_SubscribeForRestoresCompletions_HandlesInvalidJson(t 
*testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + defer registry.UnsubscribeForRestoresCompletions() + + receivedRestoreID := make(chan uuid.UUID, 1) + handler := func(nodeID uuid.UUID, restoreID uuid.UUID) { + receivedRestoreID <- restoreID + } + + err := registry.SubscribeForRestoresCompletions(handler) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + ctx := context.Background() + err = registry.pubsubCompletions.Publish(ctx, "restore:completion", "invalid json") + assert.NoError(t, err) + + select { + case <-receivedRestoreID: + t.Fatal("Should not receive restore for invalid JSON") + case <-time.After(500 * time.Millisecond): + } +} + +func Test_UnsubscribeForRestoresCompletions_StopsReceivingMessages(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + node := createTestRestoreNode() + restoreID1 := uuid.New() + restoreID2 := uuid.New() + + receivedRestoreID := make(chan uuid.UUID, 2) + handler := func(nodeID uuid.UUID, restoreID uuid.UUID) { + receivedRestoreID <- restoreID + } + + err := registry.SubscribeForRestoresCompletions(handler) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + err = registry.PublishRestoreCompletion(node.ID, restoreID1) + assert.NoError(t, err) + + received := <-receivedRestoreID + assert.Equal(t, restoreID1, received) + + err = registry.UnsubscribeForRestoresCompletions() + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + err = registry.PublishRestoreCompletion(node.ID, restoreID2) + assert.NoError(t, err) + + select { + case <-receivedRestoreID: + t.Fatal("Should not receive restore after unsubscribe") + case <-time.After(500 * time.Millisecond): + } +} + +func Test_SubscribeForRestoresCompletions_WhenAlreadySubscribed_ReturnsError(t *testing.T) { + cache_utils.ClearAllCache() + registry := createTestRegistry() + defer registry.UnsubscribeForRestoresCompletions() + + handler := func(nodeID uuid.UUID, restoreID 
uuid.UUID) {} + + err := registry.SubscribeForRestoresCompletions(handler) + assert.NoError(t, err) + + err = registry.SubscribeForRestoresCompletions(handler) + assert.Error(t, err) + assert.Contains(t, err.Error(), "already subscribed") +} + +func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) { + cache_utils.ClearAllCache() + registry1 := createTestRegistry() + registry2 := createTestRegistry() + registry3 := createTestRegistry() + + node1 := createTestRestoreNode() + node2 := createTestRestoreNode() + node3 := createTestRestoreNode() + + restoreID1 := uuid.New() + restoreID2 := uuid.New() + restoreID3 := uuid.New() + + defer registry1.UnsubscribeForRestoresCompletions() + defer registry2.UnsubscribeForRestoresCompletions() + defer registry3.UnsubscribeForRestoresCompletions() + + receivedRestores1 := make(chan uuid.UUID, 3) + receivedRestores2 := make(chan uuid.UUID, 3) + receivedRestores3 := make(chan uuid.UUID, 3) + + handler1 := func(nodeID uuid.UUID, restoreID uuid.UUID) { receivedRestores1 <- restoreID } + handler2 := func(nodeID uuid.UUID, restoreID uuid.UUID) { receivedRestores2 <- restoreID } + handler3 := func(nodeID uuid.UUID, restoreID uuid.UUID) { receivedRestores3 <- restoreID } + + err := registry1.SubscribeForRestoresCompletions(handler1) + assert.NoError(t, err) + + err = registry2.SubscribeForRestoresCompletions(handler2) + assert.NoError(t, err) + + err = registry3.SubscribeForRestoresCompletions(handler3) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + publishRegistry := createTestRegistry() + err = publishRegistry.PublishRestoreCompletion(node1.ID, restoreID1) + assert.NoError(t, err) + + err = publishRegistry.PublishRestoreCompletion(node2.ID, restoreID2) + assert.NoError(t, err) + + err = publishRegistry.PublishRestoreCompletion(node3.ID, restoreID3) + assert.NoError(t, err) + + receivedAll1 := []uuid.UUID{} + receivedAll2 := []uuid.UUID{} + receivedAll3 := []uuid.UUID{} + + for i := 0; i < 3; i++ 
{ + select { + case received := <-receivedRestores1: + receivedAll1 = append(receivedAll1, received) + case <-time.After(2 * time.Second): + t.Fatal("Subscriber 1 timeout waiting for completion message") + } + } + + for i := 0; i < 3; i++ { + select { + case received := <-receivedRestores2: + receivedAll2 = append(receivedAll2, received) + case <-time.After(2 * time.Second): + t.Fatal("Subscriber 2 timeout waiting for completion message") + } + } + + for i := 0; i < 3; i++ { + select { + case received := <-receivedRestores3: + receivedAll3 = append(receivedAll3, received) + case <-time.After(2 * time.Second): + t.Fatal("Subscriber 3 timeout waiting for completion message") + } + } + + assert.Contains(t, receivedAll1, restoreID1) + assert.Contains(t, receivedAll1, restoreID2) + assert.Contains(t, receivedAll1, restoreID3) + + assert.Contains(t, receivedAll2, restoreID1) + assert.Contains(t, receivedAll2, restoreID2) + assert.Contains(t, receivedAll2, restoreID3) + + assert.Contains(t, receivedAll3, restoreID1) + assert.Contains(t, receivedAll3, restoreID2) + assert.Contains(t, receivedAll3, restoreID3) +} diff --git a/backend/internal/features/restores/restoring/restorer.go b/backend/internal/features/restores/restoring/restorer.go new file mode 100644 index 0000000..ba9a059 --- /dev/null +++ b/backend/internal/features/restores/restoring/restorer.go @@ -0,0 +1,262 @@ +package restoring + +import ( + "context" + "fmt" + "log/slog" + "time" + + "github.com/google/uuid" + + "databasus-backend/internal/config" + "databasus-backend/internal/features/backups/backups" + backups_config "databasus-backend/internal/features/backups/config" + "databasus-backend/internal/features/databases" + restores_core "databasus-backend/internal/features/restores/core" + "databasus-backend/internal/features/storages" + cache_utils "databasus-backend/internal/util/cache" + util_encryption "databasus-backend/internal/util/encryption" +) + +const ( + heartbeatTickerInterval = 15 * 
time.Second + restorerHealthcheckThreshold = 5 * time.Minute +) + +type RestorerNode struct { + nodeID uuid.UUID + + databaseService *databases.DatabaseService + backupService *backups.BackupService + fieldEncryptor util_encryption.FieldEncryptor + restoreRepository *restores_core.RestoreRepository + backupConfigService *backups_config.BackupConfigService + storageService *storages.StorageService + restoreNodesRegistry *RestoreNodesRegistry + logger *slog.Logger + restoreBackupUsecase restores_core.RestoreBackupUsecase + cacheUtil *cache_utils.CacheUtil[RestoreDatabaseCache] + + lastHeartbeat time.Time +} + +func (n *RestorerNode) Run(ctx context.Context) { + n.lastHeartbeat = time.Now().UTC() + + throughputMBs := config.GetEnv().NodeNetworkThroughputMBs + + restoreNode := RestoreNode{ + ID: n.nodeID, + ThroughputMBs: throughputMBs, + } + + if err := n.restoreNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), restoreNode); err != nil { + n.logger.Error("Failed to register node in registry", "error", err) + panic(err) + } + + restoreHandler := func(restoreID uuid.UUID, isCallNotifier bool) { + n.MakeRestore(restoreID) + if err := n.restoreNodesRegistry.PublishRestoreCompletion(n.nodeID, restoreID); err != nil { + n.logger.Error( + "Failed to publish restore completion", + "error", + err, + "restoreID", + restoreID, + ) + } + } + + err := n.restoreNodesRegistry.SubscribeNodeForRestoresAssignment( + n.nodeID, + restoreHandler, + ) + if err != nil { + n.logger.Error("Failed to subscribe to restore assignments", "error", err) + panic(err) + } + defer func() { + if err := n.restoreNodesRegistry.UnsubscribeNodeForRestoresAssignments(); err != nil { + n.logger.Error("Failed to unsubscribe from restore assignments", "error", err) + } + }() + + ticker := time.NewTicker(heartbeatTickerInterval) + defer ticker.Stop() + + n.logger.Info("Restore node started", "nodeID", n.nodeID, "throughput", throughputMBs) + + for { + select { + case <-ctx.Done(): + 
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID) + + if err := n.restoreNodesRegistry.UnregisterNodeFromRegistry(restoreNode); err != nil { + n.logger.Error("Failed to unregister node from registry", "error", err) + } + + return + case <-ticker.C: + n.sendHeartbeat(&restoreNode) + } + } +} + +func (n *RestorerNode) IsRestorerRunning() bool { + return n.lastHeartbeat.After(time.Now().UTC().Add(-restorerHealthcheckThreshold)) +} + +func (n *RestorerNode) MakeRestore(restoreID uuid.UUID) { + // Get and delete cached DB credentials atomically + dbCache := n.cacheUtil.GetAndDelete(restoreID.String()) + + if dbCache == nil { + // Cache miss - fail immediately + restore, err := n.restoreRepository.FindByID(restoreID) + if err != nil { + n.logger.Error( + "Failed to get restore by ID after cache miss", + "restoreId", + restoreID, + "error", + err, + ) + return + } + + errMsg := "Database credentials expired or missing from cache (most likely due to instance restart)" + restore.FailMessage = &errMsg + restore.Status = restores_core.RestoreStatusFailed + + if err := n.restoreRepository.Save(restore); err != nil { + n.logger.Error("Failed to save restore after cache miss", "error", err) + } + + n.logger.Error("Restore failed: cache miss", "restoreId", restoreID) + return + } + + restore, err := n.restoreRepository.FindByID(restoreID) + if err != nil { + n.logger.Error("Failed to get restore by ID", "restoreId", restoreID, "error", err) + return + } + + backup, err := n.backupService.GetBackup(restore.BackupID) + if err != nil { + n.logger.Error("Failed to get backup by ID", "backupId", restore.BackupID, "error", err) + return + } + + databaseID := backup.DatabaseID + + database, err := n.databaseService.GetDatabaseByID(databaseID) + if err != nil { + n.logger.Error("Failed to get database by ID", "databaseId", databaseID, "error", err) + return + } + + backupConfig, err := n.backupConfigService.GetBackupConfigByDbId(databaseID) + if err != nil 
{ + n.logger.Error("Failed to get backup config by database ID", "error", err) + return + } + + if backupConfig.StorageID == nil { + n.logger.Error("Backup config storage ID is not defined") + return + } + + storage, err := n.storageService.GetStorageByID(*backupConfig.StorageID) + if err != nil { + n.logger.Error("Failed to get storage by ID", "error", err) + return + } + + start := time.Now().UTC() + + // Create restoring database from cached credentials + restoringToDB := &databases.Database{ + Type: database.Type, + Postgresql: dbCache.PostgresqlDatabase, + Mysql: dbCache.MysqlDatabase, + Mariadb: dbCache.MariadbDatabase, + Mongodb: dbCache.MongodbDatabase, + } + + if err := restoringToDB.PopulateDbData(n.logger, n.fieldEncryptor); err != nil { + errMsg := fmt.Sprintf("failed to auto-detect database data: %v", err) + restore.FailMessage = &errMsg + restore.Status = restores_core.RestoreStatusFailed + restore.RestoreDurationMs = time.Since(start).Milliseconds() + + if err := n.restoreRepository.Save(restore); err != nil { + n.logger.Error("Failed to save restore", "error", err) + } + + return + } + + isExcludeExtensions := false + if dbCache.PostgresqlDatabase != nil { + isExcludeExtensions = dbCache.PostgresqlDatabase.IsExcludeExtensions + } + + err = n.restoreBackupUsecase.Execute( + backupConfig, + *restore, + database, + restoringToDB, + backup, + storage, + isExcludeExtensions, + ) + + if err != nil { + errMsg := err.Error() + + n.logger.Error("Restore execution failed", + "restoreId", restore.ID, + "backupId", backup.ID, + "databaseId", databaseID, + "databaseType", database.Type, + "storageId", storage.ID, + "storageType", storage.Type, + "error", err, + "errorMessage", errMsg, + ) + + restore.FailMessage = &errMsg + restore.Status = restores_core.RestoreStatusFailed + restore.RestoreDurationMs = time.Since(start).Milliseconds() + + if err := n.restoreRepository.Save(restore); err != nil { + n.logger.Error("Failed to save restore", "error", err) + } + + 
return + } + + restore.Status = restores_core.RestoreStatusCompleted + restore.RestoreDurationMs = time.Since(start).Milliseconds() + + if err := n.restoreRepository.Save(restore); err != nil { + n.logger.Error("Failed to save restore", "error", err) + return + } + + n.logger.Info( + "Restore completed successfully", + "restoreId", restore.ID, + "backupId", backup.ID, + "durationMs", restore.RestoreDurationMs, + ) +} + +func (n *RestorerNode) sendHeartbeat(restoreNode *RestoreNode) { + n.lastHeartbeat = time.Now().UTC() + if err := n.restoreNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *restoreNode); err != nil { + n.logger.Error("Failed to send heartbeat", "error", err) + } +} diff --git a/backend/internal/features/restores/restoring/restorer_test.go b/backend/internal/features/restores/restoring/restorer_test.go new file mode 100644 index 0000000..2024565 --- /dev/null +++ b/backend/internal/features/restores/restoring/restorer_test.go @@ -0,0 +1,163 @@ +package restoring + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "databasus-backend/internal/features/backups/backups" + backups_core "databasus-backend/internal/features/backups/backups/core" + backups_config "databasus-backend/internal/features/backups/config" + "databasus-backend/internal/features/databases" + "databasus-backend/internal/features/databases/databases/postgresql" + "databasus-backend/internal/features/notifiers" + restores_core "databasus-backend/internal/features/restores/core" + "databasus-backend/internal/features/storages" + users_enums "databasus-backend/internal/features/users/enums" + users_testing "databasus-backend/internal/features/users/testing" + workspaces_testing "databasus-backend/internal/features/workspaces/testing" + cache_utils "databasus-backend/internal/util/cache" +) + +func Test_MakeRestore_WhenCacheMissed_RestoreFails(t *testing.T) { + cache_utils.ClearAllCache() + + user := users_testing.CreateTestUser(users_enums.UserRoleAdmin) 
+ router := CreateTestRouter() + workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router) + storage := storages.CreateTestStorage(workspace.ID) + notifier := notifiers.CreateTestNotifier(workspace.ID) + database := databases.CreateTestDatabase(workspace.ID, storage, notifier) + backups_config.EnableBackupsForTestDatabase(database.ID, storage) + + defer func() { + backupRepo := backups_core.BackupRepository{} + backupsList, _ := backupRepo.FindByDatabaseID(database.ID) + for _, backup := range backupsList { + backupRepo.DeleteByID(backup.ID) + } + + restoreRepo := restores_core.RestoreRepository{} + restoresInProgress, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress) + for _, restore := range restoresInProgress { + restoreRepo.DeleteByID(restore.ID) + } + restoresFailed, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusFailed) + for _, restore := range restoresFailed { + restoreRepo.DeleteByID(restore.ID) + } + + databases.RemoveTestDatabase(database) + time.Sleep(50 * time.Millisecond) + notifiers.RemoveTestNotifier(notifier) + storages.RemoveTestStorage(storage.ID) + workspaces_testing.RemoveTestWorkspace(workspace, router) + + cache_utils.ClearAllCache() + }() + + backup := backups.CreateTestBackup(database.ID, storage.ID) + + // Create restore but DON'T cache DB credentials + // Also don't set embedded DB fields to avoid schema issues + restore := &restores_core.Restore{ + BackupID: backup.ID, + Status: restores_core.RestoreStatusInProgress, + } + err := restoreRepository.Save(restore) + assert.NoError(t, err) + + // Create restorer and execute restore (should fail due to cache miss) + restorerNode := CreateTestRestorerNode() + restorerNode.MakeRestore(restore.ID) + + // Verify restore failed with appropriate error message + updatedRestore, err := restoreRepository.FindByID(restore.ID) + assert.NoError(t, err) + assert.Equal(t, restores_core.RestoreStatusFailed, updatedRestore.Status) + assert.NotNil(t, 
updatedRestore.FailMessage) + assert.Contains( + t, + *updatedRestore.FailMessage, + "Database credentials expired or missing from cache", + ) +} + +func Test_MakeRestore_WhenTaskStarts_CacheDeletedImmediately(t *testing.T) { + cache_utils.ClearAllCache() + + user := users_testing.CreateTestUser(users_enums.UserRoleAdmin) + router := CreateTestRouter() + workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router) + storage := storages.CreateTestStorage(workspace.ID) + notifier := notifiers.CreateTestNotifier(workspace.ID) + database := databases.CreateTestDatabase(workspace.ID, storage, notifier) + backups_config.EnableBackupsForTestDatabase(database.ID, storage) + + defer func() { + backupRepo := backups_core.BackupRepository{} + backupsList, _ := backupRepo.FindByDatabaseID(database.ID) + for _, backup := range backupsList { + backupRepo.DeleteByID(backup.ID) + } + + restoreRepo := restores_core.RestoreRepository{} + restoresInProgress, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress) + for _, restore := range restoresInProgress { + restoreRepo.DeleteByID(restore.ID) + } + restoresFailed, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusFailed) + for _, restore := range restoresFailed { + restoreRepo.DeleteByID(restore.ID) + } + restoresCompleted, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusCompleted) + for _, restore := range restoresCompleted { + restoreRepo.DeleteByID(restore.ID) + } + + databases.RemoveTestDatabase(database) + time.Sleep(50 * time.Millisecond) + notifiers.RemoveTestNotifier(notifier) + storages.RemoveTestStorage(storage.ID) + workspaces_testing.RemoveTestWorkspace(workspace, router) + + cache_utils.ClearAllCache() + }() + + backup := backups.CreateTestBackup(database.ID, storage.ID) + + // Create restore with cached DB credentials + // Don't set embedded DB fields in the restore model itself + restore := &restores_core.Restore{ + BackupID: backup.ID, + Status: 
restores_core.RestoreStatusInProgress, + } + err := restoreRepository.Save(restore) + assert.NoError(t, err) + + // Cache DB credentials separately + dbCache := &RestoreDatabaseCache{ + PostgresqlDatabase: &postgresql.PostgresqlDatabase{ + Host: "localhost", + Port: 5432, + Username: "test", + Password: "test", + Database: stringPtr("testdb"), + Version: "16", + }, + } + restoreDatabaseCache.SetWithExpiration(restore.ID.String(), dbCache, 1*time.Hour) + + // Verify cache exists before restore starts + cachedDB := restoreDatabaseCache.Get(restore.ID.String()) + assert.NotNil(t, cachedDB, "Cache should exist before restore starts") + + // Start restore (this will call GetAndDelete) + restorerNode := CreateTestRestorerNode() + restorerNode.MakeRestore(restore.ID) + + // Verify cache was deleted immediately + cachedDBAfter := restoreDatabaseCache.Get(restore.ID.String()) + assert.Nil(t, cachedDBAfter, "Cache should be deleted immediately when task starts") +} diff --git a/backend/internal/features/restores/restoring/scheduler.go b/backend/internal/features/restores/restoring/scheduler.go new file mode 100644 index 0000000..93668ba --- /dev/null +++ b/backend/internal/features/restores/restoring/scheduler.go @@ -0,0 +1,395 @@ +package restoring + +import ( + "context" + "fmt" + "log/slog" + "time" + + "github.com/google/uuid" + + "databasus-backend/internal/config" + "databasus-backend/internal/features/backups/backups" + backups_config "databasus-backend/internal/features/backups/config" + restores_core "databasus-backend/internal/features/restores/core" + "databasus-backend/internal/features/storages" + cache_utils "databasus-backend/internal/util/cache" +) + +const ( + schedulerStartupDelay = 1 * time.Minute + schedulerTickerInterval = 1 * time.Minute + schedulerHealthcheckThreshold = 5 * time.Minute +) + +type RestoresScheduler struct { + restoreRepository *restores_core.RestoreRepository + backupService *backups.BackupService + storageService 
*storages.StorageService + backupConfigService *backups_config.BackupConfigService + restoreNodesRegistry *RestoreNodesRegistry + lastCheckTime time.Time + logger *slog.Logger + restoreToNodeRelations map[uuid.UUID]RestoreToNodeRelation + restorerNode *RestorerNode + cacheUtil *cache_utils.CacheUtil[RestoreDatabaseCache] + completionSubscriptionID uuid.UUID +} + +func (s *RestoresScheduler) Run(ctx context.Context) { + s.lastCheckTime = time.Now().UTC() + + if config.GetEnv().IsManyNodesMode { + // wait other nodes to start + time.Sleep(schedulerStartupDelay) + } + + if err := s.failRestoresInProgress(); err != nil { + s.logger.Error("Failed to fail restores in progress", "error", err) + panic(err) + } + + err := s.restoreNodesRegistry.SubscribeForRestoresCompletions(s.onRestoreCompleted) + if err != nil { + s.logger.Error("Failed to subscribe to restore completions", "error", err) + panic(err) + } + + defer func() { + if err := s.restoreNodesRegistry.UnsubscribeForRestoresCompletions(); err != nil { + s.logger.Error("Failed to unsubscribe from restore completions", "error", err) + } + }() + + if ctx.Err() != nil { + return + } + + ticker := time.NewTicker(schedulerTickerInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := s.checkDeadNodesAndFailRestores(); err != nil { + s.logger.Error("Failed to check dead nodes and fail restores", "error", err) + } + + s.lastCheckTime = time.Now().UTC() + } + } +} + +func (s *RestoresScheduler) IsSchedulerRunning() bool { + return s.lastCheckTime.After(time.Now().UTC().Add(-schedulerHealthcheckThreshold)) +} + +func (s *RestoresScheduler) failRestoresInProgress() error { + restoresInProgress, err := s.restoreRepository.FindByStatus( + restores_core.RestoreStatusInProgress, + ) + if err != nil { + return err + } + + for _, restore := range restoresInProgress { + failMessage := "Restore failed due to application restart" + restore.FailMessage = &failMessage + 
restore.Status = restores_core.RestoreStatusFailed + + if err := s.restoreRepository.Save(restore); err != nil { + return err + } + } + + return nil +} + +func (s *RestoresScheduler) StartRestore(restoreID uuid.UUID, dbCache *RestoreDatabaseCache) error { + // If dbCache not provided, try to fetch from DB (for backward compatibility/testing) + if dbCache == nil { + restore, err := s.restoreRepository.FindByID(restoreID) + if err != nil { + s.logger.Error( + "Failed to find restore by ID", + "restoreId", + restoreID, + "error", + err, + ) + return err + } + + // Create cache DTO from restore (may be nil if not in DB) + dbCache = &RestoreDatabaseCache{ + PostgresqlDatabase: restore.PostgresqlDatabase, + MysqlDatabase: restore.MysqlDatabase, + MariadbDatabase: restore.MariadbDatabase, + MongodbDatabase: restore.MongodbDatabase, + } + } + + // Cache database credentials with 1-hour expiration + s.cacheUtil.SetWithExpiration(restoreID.String(), dbCache, 1*time.Hour) + + leastBusyNodeID, err := s.calculateLeastBusyNode() + if err != nil { + s.logger.Error( + "Failed to calculate least busy node", + "restoreId", + restoreID, + "error", + err, + ) + return err + } + + if err := s.restoreNodesRegistry.IncrementRestoresInProgress(*leastBusyNodeID); err != nil { + s.logger.Error( + "Failed to increment restores in progress", + "nodeId", + leastBusyNodeID, + "restoreId", + restoreID, + "error", + err, + ) + return err + } + + if err := s.restoreNodesRegistry.AssignRestoreToNode(*leastBusyNodeID, restoreID, false); err != nil { + s.logger.Error( + "Failed to submit restore", + "nodeId", + leastBusyNodeID, + "restoreId", + restoreID, + "error", + err, + ) + if decrementErr := s.restoreNodesRegistry.DecrementRestoresInProgress(*leastBusyNodeID); decrementErr != nil { + s.logger.Error( + "Failed to decrement restores in progress after submit failure", + "nodeId", + leastBusyNodeID, + "error", + decrementErr, + ) + } + return err + } + + if relation, exists := 
s.restoreToNodeRelations[*leastBusyNodeID]; exists { + relation.RestoreIDs = append(relation.RestoreIDs, restoreID) + s.restoreToNodeRelations[*leastBusyNodeID] = relation + } else { + s.restoreToNodeRelations[*leastBusyNodeID] = RestoreToNodeRelation{ + NodeID: *leastBusyNodeID, + RestoreIDs: []uuid.UUID{restoreID}, + } + } + + s.logger.Info( + "Successfully triggered restore", + "restoreId", + restoreID, + "nodeId", + leastBusyNodeID, + ) + + return nil +} + +func (s *RestoresScheduler) calculateLeastBusyNode() (*uuid.UUID, error) { + nodes, err := s.restoreNodesRegistry.GetAvailableNodes() + if err != nil { + return nil, fmt.Errorf("failed to get available nodes: %w", err) + } + + if len(nodes) == 0 { + return nil, fmt.Errorf("no nodes available") + } + + stats, err := s.restoreNodesRegistry.GetRestoreNodesStats() + if err != nil { + return nil, fmt.Errorf("failed to get restore nodes stats: %w", err) + } + + statsMap := make(map[uuid.UUID]int) + for _, stat := range stats { + statsMap[stat.ID] = stat.ActiveRestores + } + + var bestNode *RestoreNode + var bestScore float64 = -1 + + for i := range nodes { + node := &nodes[i] + + activeRestores := statsMap[node.ID] + + var score float64 + if node.ThroughputMBs > 0 { + score = float64(activeRestores) / float64(node.ThroughputMBs) + } else { + score = float64(activeRestores) * 1000 + } + + if bestNode == nil || score < bestScore { + bestNode = node + bestScore = score + } + } + + if bestNode == nil { + return nil, fmt.Errorf("no suitable nodes available") + } + + return &bestNode.ID, nil +} + +func (s *RestoresScheduler) onRestoreCompleted(nodeID uuid.UUID, restoreID uuid.UUID) { + // Verify this task is actually a restore (registry contains multiple task types) + _, err := s.restoreRepository.FindByID(restoreID) + if err != nil { + // Not a restore task, ignore it + return + } + + relation, exists := s.restoreToNodeRelations[nodeID] + if !exists { + s.logger.Warn( + "Received completion for unknown node", + 
"nodeId", + nodeID, + "restoreId", + restoreID, + ) + return + } + + newRestoreIDs := make([]uuid.UUID, 0) + found := false + for _, id := range relation.RestoreIDs { + if id == restoreID { + found = true + continue + } + newRestoreIDs = append(newRestoreIDs, id) + } + + if !found { + s.logger.Warn( + "Restore not found in node's restore list", + "nodeId", + nodeID, + "restoreId", + restoreID, + ) + return + } + + if len(newRestoreIDs) == 0 { + delete(s.restoreToNodeRelations, nodeID) + } else { + relation.RestoreIDs = newRestoreIDs + s.restoreToNodeRelations[nodeID] = relation + } + + if err := s.restoreNodesRegistry.DecrementRestoresInProgress(nodeID); err != nil { + s.logger.Error( + "Failed to decrement restores in progress", + "nodeId", + nodeID, + "restoreId", + restoreID, + "error", + err, + ) + } +} + +func (s *RestoresScheduler) checkDeadNodesAndFailRestores() error { + nodes, err := s.restoreNodesRegistry.GetAvailableNodes() + if err != nil { + return fmt.Errorf("failed to get available nodes: %w", err) + } + + aliveNodeIDs := make(map[uuid.UUID]bool) + for _, node := range nodes { + aliveNodeIDs[node.ID] = true + } + + for nodeID, relation := range s.restoreToNodeRelations { + if aliveNodeIDs[nodeID] { + continue + } + + s.logger.Warn( + "Node is dead, failing its restores", + "nodeId", + nodeID, + "restoreCount", + len(relation.RestoreIDs), + ) + + for _, restoreID := range relation.RestoreIDs { + restore, err := s.restoreRepository.FindByID(restoreID) + if err != nil { + s.logger.Error( + "Failed to find restore for dead node", + "nodeId", + nodeID, + "restoreId", + restoreID, + "error", + err, + ) + continue + } + + failMessage := "Restore failed due to node unavailability" + restore.FailMessage = &failMessage + restore.Status = restores_core.RestoreStatusFailed + + if err := s.restoreRepository.Save(restore); err != nil { + s.logger.Error( + "Failed to save failed restore for dead node", + "nodeId", + nodeID, + "restoreId", + restoreID, + "error", + 
err, + ) + continue + } + + if err := s.restoreNodesRegistry.DecrementRestoresInProgress(nodeID); err != nil { + s.logger.Error( + "Failed to decrement restores in progress for dead node", + "nodeId", + nodeID, + "restoreId", + restoreID, + "error", + err, + ) + } + + s.logger.Info( + "Failed restore due to dead node", + "nodeId", + nodeID, + "restoreId", + restoreID, + ) + } + + delete(s.restoreToNodeRelations, nodeID) + } + + return nil +} diff --git a/backend/internal/features/restores/restoring/scheduler_test.go b/backend/internal/features/restores/restoring/scheduler_test.go new file mode 100644 index 0000000..3f39a3f --- /dev/null +++ b/backend/internal/features/restores/restoring/scheduler_test.go @@ -0,0 +1,852 @@ +package restoring + +import ( + "testing" + "time" + + "databasus-backend/internal/features/backups/backups" + backups_core "databasus-backend/internal/features/backups/backups/core" + backups_config "databasus-backend/internal/features/backups/config" + "databasus-backend/internal/features/databases" + "databasus-backend/internal/features/databases/databases/postgresql" + "databasus-backend/internal/features/notifiers" + restores_core "databasus-backend/internal/features/restores/core" + "databasus-backend/internal/features/storages" + users_enums "databasus-backend/internal/features/users/enums" + users_testing "databasus-backend/internal/features/users/testing" + workspaces_testing "databasus-backend/internal/features/workspaces/testing" + cache_utils "databasus-backend/internal/util/cache" + "databasus-backend/internal/util/encryption" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func Test_CheckDeadNodesAndFailRestores_NodeDies_FailsRestoreAndCleansUpRegistry(t *testing.T) { + cache_utils.ClearAllCache() + + user := users_testing.CreateTestUser(users_enums.UserRoleAdmin) + router := CreateTestRouter() + workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router) + storage := 
storages.CreateTestStorage(workspace.ID) + notifier := notifiers.CreateTestNotifier(workspace.ID) + database := databases.CreateTestDatabase(workspace.ID, storage, notifier) + + var mockNodeID uuid.UUID + + defer func() { + backupRepo := backups_core.BackupRepository{} + backups, _ := backupRepo.FindByDatabaseID(database.ID) + for _, backup := range backups { + backupRepo.DeleteByID(backup.ID) + } + + restoreRepo := restores_core.RestoreRepository{} + restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress) + for _, restore := range restores { + restoreRepo.DeleteByID(restore.ID) + } + restores, _ = restoreRepo.FindByStatus(restores_core.RestoreStatusFailed) + for _, restore := range restores { + restoreRepo.DeleteByID(restore.ID) + } + + databases.RemoveTestDatabase(database) + time.Sleep(50 * time.Millisecond) + storages.RemoveTestStorage(storage.ID) + notifiers.RemoveTestNotifier(notifier) + workspaces_testing.RemoveTestWorkspace(workspace, router) + + // Clean up mock node + if mockNodeID != uuid.Nil { + restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: mockNodeID}) + } + cache_utils.ClearAllCache() + }() + + backups_config.EnableBackupsForTestDatabase(database.ID, storage) + + // Create a test backup + backup := backups.CreateTestBackup(database.ID, storage.ID) + + var err error + // Register mock node without subscribing to restores (simulates node crash after registration) + mockNodeID = uuid.New() + err = CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC()) + assert.NoError(t, err) + + // Create restore and assign to mock node + restore := &restores_core.Restore{ + BackupID: backup.ID, + Status: restores_core.RestoreStatusInProgress, + } + err = restoreRepository.Save(restore) + assert.NoError(t, err) + + // Scheduler assigns restore to mock node + err = GetRestoresScheduler().StartRestore(restore.ID, nil) + assert.NoError(t, err) + time.Sleep(100 * time.Millisecond) + + // Verify Valkey counter was incremented 
when restore was assigned + stats, err := restoreNodesRegistry.GetRestoreNodesStats() + assert.NoError(t, err) + foundStat := false + for _, stat := range stats { + if stat.ID == mockNodeID { + assert.Equal(t, 1, stat.ActiveRestores) + foundStat = true + break + } + } + assert.True(t, foundStat, "Node stats should be present") + + // Simulate node death by setting heartbeat older than 2-minute threshold + oldHeartbeat := time.Now().UTC().Add(-3 * time.Minute) + err = UpdateNodeHeartbeatDirectly(mockNodeID, 100, oldHeartbeat) + assert.NoError(t, err) + + // Trigger dead node detection + err = GetRestoresScheduler().checkDeadNodesAndFailRestores() + assert.NoError(t, err) + + // Verify restore was failed with appropriate error message + failedRestore, err := restoreRepository.FindByID(restore.ID) + assert.NoError(t, err) + assert.Equal(t, restores_core.RestoreStatusFailed, failedRestore.Status) + assert.NotNil(t, failedRestore.FailMessage) + assert.Contains(t, *failedRestore.FailMessage, "node unavailability") + + // Verify Valkey counter was decremented after restore failed + stats, err = restoreNodesRegistry.GetRestoreNodesStats() + assert.NoError(t, err) + for _, stat := range stats { + if stat.ID == mockNodeID { + assert.Equal(t, 0, stat.ActiveRestores) + } + } + + time.Sleep(200 * time.Millisecond) +} + +func Test_OnRestoreCompleted_TaskIsNotRestore_SkipsProcessing(t *testing.T) { + cache_utils.ClearAllCache() + + user := users_testing.CreateTestUser(users_enums.UserRoleAdmin) + router := CreateTestRouter() + workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router) + storage := storages.CreateTestStorage(workspace.ID) + notifier := notifiers.CreateTestNotifier(workspace.ID) + database := databases.CreateTestDatabase(workspace.ID, storage, notifier) + + var mockNodeID uuid.UUID + + defer func() { + backupRepo := backups_core.BackupRepository{} + backups, _ := backupRepo.FindByDatabaseID(database.ID) + for _, backup := range backups { + 
backupRepo.DeleteByID(backup.ID) + } + + restoreRepo := restores_core.RestoreRepository{} + restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress) + for _, restore := range restores { + restoreRepo.DeleteByID(restore.ID) + } + + databases.RemoveTestDatabase(database) + time.Sleep(50 * time.Millisecond) + storages.RemoveTestStorage(storage.ID) + notifiers.RemoveTestNotifier(notifier) + workspaces_testing.RemoveTestWorkspace(workspace, router) + + // Clean up mock node + if mockNodeID != uuid.Nil { + restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: mockNodeID}) + } + cache_utils.ClearAllCache() + }() + + backups_config.EnableBackupsForTestDatabase(database.ID, storage) + + // Create a test backup + backup := backups.CreateTestBackup(database.ID, storage.ID) + + // Register mock node + mockNodeID = uuid.New() + err := CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC()) + assert.NoError(t, err) + + // Create restore and assign to the node + restore := &restores_core.Restore{ + BackupID: backup.ID, + Status: restores_core.RestoreStatusInProgress, + } + err = restoreRepository.Save(restore) + assert.NoError(t, err) + + err = GetRestoresScheduler().StartRestore(restore.ID, nil) + assert.NoError(t, err) + time.Sleep(100 * time.Millisecond) + + // Get initial state of the registry + initialStats, err := restoreNodesRegistry.GetRestoreNodesStats() + assert.NoError(t, err) + var initialActiveTasks int + for _, stat := range initialStats { + if stat.ID == mockNodeID { + initialActiveTasks = stat.ActiveRestores + break + } + } + assert.Equal(t, 1, initialActiveTasks, "Should have 1 active task") + + // Call onRestoreCompleted with a random UUID (not a restore ID) + nonRestoreTaskID := uuid.New() + GetRestoresScheduler().onRestoreCompleted(mockNodeID, nonRestoreTaskID) + + time.Sleep(100 * time.Millisecond) + + // Verify: Active tasks counter should remain the same (not decremented) + stats, err := 
restoreNodesRegistry.GetRestoreNodesStats() + assert.NoError(t, err) + for _, stat := range stats { + if stat.ID == mockNodeID { + assert.Equal(t, initialActiveTasks, stat.ActiveRestores, + "Active tasks should not change for non-restore task") + } + } + + // Verify: restore should still be in progress (not modified) + unchangedRestore, err := restoreRepository.FindByID(restore.ID) + assert.NoError(t, err) + assert.Equal(t, restores_core.RestoreStatusInProgress, unchangedRestore.Status, + "Restore status should not change for non-restore task completion") + + // Verify: restoreToNodeRelations should still contain the node + scheduler := GetRestoresScheduler() + _, exists := scheduler.restoreToNodeRelations[mockNodeID] + assert.True(t, exists, "Node should still be in restoreToNodeRelations") + + time.Sleep(200 * time.Millisecond) +} + +func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) { + t.Run("Nodes with same throughput", func(t *testing.T) { + cache_utils.ClearAllCache() + + node1ID := uuid.New() + node2ID := uuid.New() + node3ID := uuid.New() + now := time.Now().UTC() + + defer func() { + // Clean up all mock nodes + restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node1ID}) + restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node2ID}) + restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node3ID}) + cache_utils.ClearAllCache() + }() + + err := CreateMockNodeInRegistry(node1ID, 100, now) + assert.NoError(t, err) + err = CreateMockNodeInRegistry(node2ID, 100, now) + assert.NoError(t, err) + err = CreateMockNodeInRegistry(node3ID, 100, now) + assert.NoError(t, err) + + for range 5 { + err = restoreNodesRegistry.IncrementRestoresInProgress(node1ID) + assert.NoError(t, err) + } + + for range 2 { + err = restoreNodesRegistry.IncrementRestoresInProgress(node2ID) + assert.NoError(t, err) + } + + for range 8 { + err = restoreNodesRegistry.IncrementRestoresInProgress(node3ID) + assert.NoError(t, err) + } + + 
leastBusyNodeID, err := GetRestoresScheduler().calculateLeastBusyNode() + assert.NoError(t, err) + assert.NotNil(t, leastBusyNodeID) + assert.Equal(t, node2ID, *leastBusyNodeID) + }) + + t.Run("Nodes with different throughput", func(t *testing.T) { + cache_utils.ClearAllCache() + + node100MBsID := uuid.New() + node50MBsID := uuid.New() + now := time.Now().UTC() + + defer func() { + // Clean up all mock nodes + restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node100MBsID}) + restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node50MBsID}) + cache_utils.ClearAllCache() + }() + + err := CreateMockNodeInRegistry(node100MBsID, 100, now) + assert.NoError(t, err) + err = CreateMockNodeInRegistry(node50MBsID, 50, now) + assert.NoError(t, err) + + for range 10 { + err = restoreNodesRegistry.IncrementRestoresInProgress(node100MBsID) + assert.NoError(t, err) + } + + err = restoreNodesRegistry.IncrementRestoresInProgress(node50MBsID) + assert.NoError(t, err) + + leastBusyNodeID, err := GetRestoresScheduler().calculateLeastBusyNode() + assert.NoError(t, err) + assert.NotNil(t, leastBusyNodeID) + assert.Equal(t, node50MBsID, *leastBusyNodeID) + }) +} + +func Test_FailRestoresInProgress_SchedulerStarts_UpdatesStatus(t *testing.T) { + cache_utils.ClearAllCache() + + user := users_testing.CreateTestUser(users_enums.UserRoleAdmin) + router := CreateTestRouter() + workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router) + storage := storages.CreateTestStorage(workspace.ID) + notifier := notifiers.CreateTestNotifier(workspace.ID) + database := databases.CreateTestDatabase(workspace.ID, storage, notifier) + + defer func() { + backupRepo := backups_core.BackupRepository{} + backups, _ := backupRepo.FindByDatabaseID(database.ID) + for _, backup := range backups { + backupRepo.DeleteByID(backup.ID) + } + + restoreRepo := restores_core.RestoreRepository{} + + restores, _ := 
restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress) + for _, restore := range restores { + restoreRepo.DeleteByID(restore.ID) + } + + restores, _ = restoreRepo.FindByStatus(restores_core.RestoreStatusFailed) + for _, restore := range restores { + restoreRepo.DeleteByID(restore.ID) + } + + restores, _ = restoreRepo.FindByStatus(restores_core.RestoreStatusCompleted) + for _, restore := range restores { + restoreRepo.DeleteByID(restore.ID) + } + + databases.RemoveTestDatabase(database) + time.Sleep(50 * time.Millisecond) + storages.RemoveTestStorage(storage.ID) + notifiers.RemoveTestNotifier(notifier) + workspaces_testing.RemoveTestWorkspace(workspace, router) + + cache_utils.ClearAllCache() + }() + + backups_config.EnableBackupsForTestDatabase(database.ID, storage) + + // Create a test backup + backup := backups.CreateTestBackup(database.ID, storage.ID) + + // Create two in-progress restores that should be failed on scheduler restart + restore1 := &restores_core.Restore{ + BackupID: backup.ID, + Status: restores_core.RestoreStatusInProgress, + CreatedAt: time.Now().UTC().Add(-30 * time.Minute), + } + err := restoreRepository.Save(restore1) + assert.NoError(t, err) + + restore2 := &restores_core.Restore{ + BackupID: backup.ID, + Status: restores_core.RestoreStatusInProgress, + CreatedAt: time.Now().UTC().Add(-15 * time.Minute), + } + err = restoreRepository.Save(restore2) + assert.NoError(t, err) + + // Create a completed restore to verify it's not affected by failRestoresInProgress + completedRestore := &restores_core.Restore{ + BackupID: backup.ID, + Status: restores_core.RestoreStatusCompleted, + CreatedAt: time.Now().UTC().Add(-1 * time.Hour), + } + err = restoreRepository.Save(completedRestore) + assert.NoError(t, err) + + // Trigger the scheduler's failRestoresInProgress logic + // This should mark in-progress restores as failed + err = GetRestoresScheduler().failRestoresInProgress() + assert.NoError(t, err) + + // Verify all restores exist and were 
processed correctly + allRestores1, err := restoreRepository.FindByID(restore1.ID) + assert.NoError(t, err) + allRestores2, err := restoreRepository.FindByID(restore2.ID) + assert.NoError(t, err) + allRestores3, err := restoreRepository.FindByID(completedRestore.ID) + assert.NoError(t, err) + + var failedCount int + var completedCount int + + restoresToCheck := []*restores_core.Restore{allRestores1, allRestores2, allRestores3} + for _, restore := range restoresToCheck { + switch restore.Status { + case restores_core.RestoreStatusFailed: + failedCount++ + // Verify fail message indicates application restart + assert.NotNil(t, restore.FailMessage) + assert.Equal(t, "Restore failed due to application restart", *restore.FailMessage) + case restores_core.RestoreStatusCompleted: + completedCount++ + } + } + + // Verify correct number of restores in each state + assert.Equal(t, 2, failedCount, "Should have 2 failed restores (originally in progress)") + assert.Equal(t, 1, completedCount, "Should have 1 completed restore (unchanged)") + + time.Sleep(200 * time.Millisecond) +} + +func Test_StartRestore_RestoreCompletes_DecrementsActiveTaskCount(t *testing.T) { + cache_utils.ClearAllCache() + + // Start scheduler so it can handle task completions + schedulerCancel := StartSchedulerForTest(t) + defer schedulerCancel() + + restorerNode := CreateTestRestorerNode() + restorerNode.restoreBackupUsecase = &MockSuccessRestoreUsecase{} + + cancel := StartRestorerNodeForTest(t, restorerNode) + defer StopRestorerNodeForTest(t, cancel, restorerNode) + + user := users_testing.CreateTestUser(users_enums.UserRoleAdmin) + router := CreateTestRouter() + workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router) + storage := storages.CreateTestStorage(workspace.ID) + notifier := notifiers.CreateTestNotifier(workspace.ID) + database := databases.CreateTestDatabase(workspace.ID, storage, notifier) + + defer func() { + backupRepo := backups_core.BackupRepository{} + 
backups, _ := backupRepo.FindByDatabaseID(database.ID) + for _, backup := range backups { + backupRepo.DeleteByID(backup.ID) + } + + restoreRepo := restores_core.RestoreRepository{} + restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusCompleted) + for _, restore := range restores { + restoreRepo.DeleteByID(restore.ID) + } + + databases.RemoveTestDatabase(database) + time.Sleep(50 * time.Millisecond) + storages.RemoveTestStorage(storage.ID) + notifiers.RemoveTestNotifier(notifier) + workspaces_testing.RemoveTestWorkspace(workspace, router) + }() + + backups_config.EnableBackupsForTestDatabase(database.ID, storage) + + // Create a test backup + backup := backups.CreateTestBackup(database.ID, storage.ID) + + // Get initial active task count + stats, err := restoreNodesRegistry.GetRestoreNodesStats() + assert.NoError(t, err) + var initialActiveTasks int + for _, stat := range stats { + if stat.ID == restorerNode.nodeID { + initialActiveTasks = stat.ActiveRestores + break + } + } + t.Logf("Initial active tasks: %d", initialActiveTasks) + + // Create and start restore + restore := &restores_core.Restore{ + BackupID: backup.ID, + Status: restores_core.RestoreStatusInProgress, + } + err = restoreRepository.Save(restore) + assert.NoError(t, err) + + err = GetRestoresScheduler().StartRestore(restore.ID, nil) + assert.NoError(t, err) + + // Wait for restore to complete + WaitForRestoreCompletion(t, restore.ID, 10*time.Second) + + // Verify restore was completed + completedRestore, err := restoreRepository.FindByID(restore.ID) + assert.NoError(t, err) + assert.Equal(t, restores_core.RestoreStatusCompleted, completedRestore.Status) + + // Wait for active task count to decrease + decreased := WaitForActiveTasksDecrease( + t, + restorerNode.nodeID, + initialActiveTasks+1, + 10*time.Second, + ) + assert.True(t, decreased, "Active task count should have decreased after restore completion") + + // Verify final active task count equals initial count + finalStats, err 
:= restoreNodesRegistry.GetRestoreNodesStats() + assert.NoError(t, err) + for _, stat := range finalStats { + if stat.ID == restorerNode.nodeID { + t.Logf("Final active tasks: %d", stat.ActiveRestores) + assert.Equal(t, initialActiveTasks, stat.ActiveRestores, + "Active task count should return to initial value after restore completion") + break + } + } + + time.Sleep(200 * time.Millisecond) +} + +func Test_StartRestore_RestoreFails_DecrementsActiveTaskCount(t *testing.T) { + cache_utils.ClearAllCache() + + // Start scheduler so it can handle task completions + schedulerCancel := StartSchedulerForTest(t) + defer schedulerCancel() + + restorerNode := CreateTestRestorerNode() + restorerNode.restoreBackupUsecase = &MockFailedRestoreUsecase{} + + cancel := StartRestorerNodeForTest(t, restorerNode) + defer StopRestorerNodeForTest(t, cancel, restorerNode) + + user := users_testing.CreateTestUser(users_enums.UserRoleAdmin) + router := CreateTestRouter() + workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router) + storage := storages.CreateTestStorage(workspace.ID) + notifier := notifiers.CreateTestNotifier(workspace.ID) + database := databases.CreateTestDatabase(workspace.ID, storage, notifier) + + defer func() { + backupRepo := backups_core.BackupRepository{} + backups, _ := backupRepo.FindByDatabaseID(database.ID) + for _, backup := range backups { + backupRepo.DeleteByID(backup.ID) + } + + restoreRepo := restores_core.RestoreRepository{} + restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusFailed) + for _, restore := range restores { + restoreRepo.DeleteByID(restore.ID) + } + + databases.RemoveTestDatabase(database) + time.Sleep(50 * time.Millisecond) + storages.RemoveTestStorage(storage.ID) + notifiers.RemoveTestNotifier(notifier) + workspaces_testing.RemoveTestWorkspace(workspace, router) + }() + + backups_config.EnableBackupsForTestDatabase(database.ID, storage) + + // Create a test backup + backup := 
backups.CreateTestBackup(database.ID, storage.ID) + + // Get initial active task count + stats, err := restoreNodesRegistry.GetRestoreNodesStats() + assert.NoError(t, err) + var initialActiveTasks int + for _, stat := range stats { + if stat.ID == restorerNode.nodeID { + initialActiveTasks = stat.ActiveRestores + break + } + } + t.Logf("Initial active tasks: %d", initialActiveTasks) + + // Create and start restore + restore := &restores_core.Restore{ + BackupID: backup.ID, + Status: restores_core.RestoreStatusInProgress, + } + err = restoreRepository.Save(restore) + assert.NoError(t, err) + + err = GetRestoresScheduler().StartRestore(restore.ID, nil) + assert.NoError(t, err) + + // Wait for restore to fail + WaitForRestoreCompletion(t, restore.ID, 10*time.Second) + + // Verify restore failed + failedRestore, err := restoreRepository.FindByID(restore.ID) + assert.NoError(t, err) + assert.Equal(t, restores_core.RestoreStatusFailed, failedRestore.Status) + + // Wait for active task count to decrease + decreased := WaitForActiveTasksDecrease( + t, + restorerNode.nodeID, + initialActiveTasks+1, + 10*time.Second, + ) + assert.True(t, decreased, "Active task count should have decreased after restore failure") + + // Verify final active task count equals initial count + finalStats, err := restoreNodesRegistry.GetRestoreNodesStats() + assert.NoError(t, err) + for _, stat := range finalStats { + if stat.ID == restorerNode.nodeID { + t.Logf("Final active tasks: %d", stat.ActiveRestores) + assert.Equal(t, initialActiveTasks, stat.ActiveRestores, + "Active task count should return to initial value after restore failure") + break + } + } + + time.Sleep(200 * time.Millisecond) +} + +func Test_StartRestore_CredentialsStoredEncryptedInCache(t *testing.T) { + cache_utils.ClearAllCache() + + user := users_testing.CreateTestUser(users_enums.UserRoleAdmin) + router := CreateTestRouter() + workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router) + storage := 
storages.CreateTestStorage(workspace.ID) + notifier := notifiers.CreateTestNotifier(workspace.ID) + database := databases.CreateTestDatabase(workspace.ID, storage, notifier) + + var mockNodeID uuid.UUID + + defer func() { + backupRepo := backups_core.BackupRepository{} + backups, _ := backupRepo.FindByDatabaseID(database.ID) + for _, backup := range backups { + backupRepo.DeleteByID(backup.ID) + } + + restoreRepo := restores_core.RestoreRepository{} + restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress) + for _, restore := range restores { + restoreRepo.DeleteByID(restore.ID) + } + + databases.RemoveTestDatabase(database) + time.Sleep(50 * time.Millisecond) + storages.RemoveTestStorage(storage.ID) + notifiers.RemoveTestNotifier(notifier) + workspaces_testing.RemoveTestWorkspace(workspace, router) + + // Clean up mock node + if mockNodeID != uuid.Nil { + restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: mockNodeID}) + } + cache_utils.ClearAllCache() + }() + + backups_config.EnableBackupsForTestDatabase(database.ID, storage) + + // Create a test backup + backup := backups.CreateTestBackup(database.ID, storage.ID) + + // Register mock node so scheduler can assign restore to it + mockNodeID = uuid.New() + err := CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC()) + assert.NoError(t, err) + + // Create restore with plaintext credentials + plaintextPassword := "test_password_123" + restore := &restores_core.Restore{ + BackupID: backup.ID, + Status: restores_core.RestoreStatusInProgress, + } + err = restoreRepository.Save(restore) + assert.NoError(t, err) + + // Create PostgreSQL database credentials with plaintext password + postgresDB := &postgresql.PostgresqlDatabase{ + Host: "localhost", + Port: 5432, + Username: "testuser", + Password: plaintextPassword, + Database: stringPtr("testdb"), + Version: "16", + } + + // Encrypt password using FieldEncryptor (same as production flow) + encryptor := 
encryption.GetFieldEncryptor() + err = postgresDB.EncryptSensitiveFields(database.ID, encryptor) + assert.NoError(t, err) + + // Verify password was encrypted (different from plaintext) + assert.NotEqual(t, plaintextPassword, postgresDB.Password, + "Password should be encrypted, not plaintext") + + // Create cache with encrypted credentials + dbCache := &RestoreDatabaseCache{ + PostgresqlDatabase: postgresDB, + } + + // Call StartRestore to cache credentials (do NOT start restore node) + err = GetRestoresScheduler().StartRestore(restore.ID, dbCache) + assert.NoError(t, err) + + // Directly read from cache + cachedData := restoreDatabaseCache.Get(restore.ID.String()) + assert.NotNil(t, cachedData, "Cache entry should exist") + assert.NotNil(t, cachedData.PostgresqlDatabase, "PostgreSQL credentials should be cached") + + // Verify password in cache is encrypted (not plaintext) + assert.NotEqual(t, plaintextPassword, cachedData.PostgresqlDatabase.Password, + "Cached password should be encrypted, not plaintext") + assert.Equal(t, postgresDB.Password, cachedData.PostgresqlDatabase.Password, + "Cached password should match the encrypted version") + + // Verify other fields are present + assert.Equal(t, "localhost", cachedData.PostgresqlDatabase.Host) + assert.Equal(t, 5432, cachedData.PostgresqlDatabase.Port) + assert.Equal(t, "testuser", cachedData.PostgresqlDatabase.Username) + assert.Equal(t, "testdb", *cachedData.PostgresqlDatabase.Database) + + time.Sleep(200 * time.Millisecond) +} + +func Test_StartRestore_CredentialsRemovedAfterRestoreStarts(t *testing.T) { + cache_utils.ClearAllCache() + + // Start scheduler so it can handle task assignments + schedulerCancel := StartSchedulerForTest(t) + defer schedulerCancel() + + // Create mock restorer node with credential capture usecase + restorerNode := CreateTestRestorerNode() + calledChan := make(chan *databases.Database, 1) + restorerNode.restoreBackupUsecase = &MockCaptureCredentialsRestoreUsecase{ + CalledChan: 
calledChan, + ShouldSucceed: true, + } + + cancel := StartRestorerNodeForTest(t, restorerNode) + defer StopRestorerNodeForTest(t, cancel, restorerNode) + + user := users_testing.CreateTestUser(users_enums.UserRoleAdmin) + router := CreateTestRouter() + workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router) + storage := storages.CreateTestStorage(workspace.ID) + notifier := notifiers.CreateTestNotifier(workspace.ID) + database := databases.CreateTestDatabase(workspace.ID, storage, notifier) + + defer func() { + backupRepo := backups_core.BackupRepository{} + backups, _ := backupRepo.FindByDatabaseID(database.ID) + for _, backup := range backups { + backupRepo.DeleteByID(backup.ID) + } + + restoreRepo := restores_core.RestoreRepository{} + restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusCompleted) + for _, restore := range restores { + restoreRepo.DeleteByID(restore.ID) + } + + databases.RemoveTestDatabase(database) + time.Sleep(50 * time.Millisecond) + storages.RemoveTestStorage(storage.ID) + notifiers.RemoveTestNotifier(notifier) + workspaces_testing.RemoveTestWorkspace(workspace, router) + + cache_utils.ClearAllCache() + }() + + backups_config.EnableBackupsForTestDatabase(database.ID, storage) + + // Create a test backup + backup := backups.CreateTestBackup(database.ID, storage.ID) + + // Create restore with credentials + plaintextPassword := "test_password_456" + restore := &restores_core.Restore{ + BackupID: backup.ID, + Status: restores_core.RestoreStatusInProgress, + } + err := restoreRepository.Save(restore) + assert.NoError(t, err) + + // Create PostgreSQL database credentials + // Database field is nil to avoid PopulateDbData trying to connect + postgresDB := &postgresql.PostgresqlDatabase{ + Host: "localhost", + Port: 5432, + Username: "testuser", + Password: plaintextPassword, + Database: nil, + Version: "16", + } + + // Encrypt password (same as production flow) + encryptor := encryption.GetFieldEncryptor() 
+ err = postgresDB.EncryptSensitiveFields(database.ID, encryptor) + assert.NoError(t, err) + + encryptedPassword := postgresDB.Password + + // Create cache with encrypted credentials + dbCache := &RestoreDatabaseCache{ + PostgresqlDatabase: postgresDB, + } + + // Call StartRestore to cache credentials and trigger restore + err = GetRestoresScheduler().StartRestore(restore.ID, dbCache) + assert.NoError(t, err) + + // Wait for mock usecase to be called (with timeout) + var capturedDB *databases.Database + select { + case capturedDB = <-calledChan: + t.Log("Mock usecase was called, credentials captured") + case <-time.After(10 * time.Second): + t.Fatal("Timeout waiting for mock usecase to be called") + } + + // Verify cache is empty after restore starts (credentials were deleted) + cacheAfterExecution := restoreDatabaseCache.Get(restore.ID.String()) + assert.Nil(t, cacheAfterExecution, "Cache should be empty after restore execution starts") + + // Verify mock received valid credentials + assert.NotNil(t, capturedDB, "Captured database should not be nil") + assert.NotNil(t, capturedDB.Postgresql, "PostgreSQL credentials should be provided to usecase") + assert.Equal(t, "localhost", capturedDB.Postgresql.Host) + assert.Equal(t, 5432, capturedDB.Postgresql.Port) + assert.Equal(t, "testuser", capturedDB.Postgresql.Username) + assert.NotEmpty(t, capturedDB.Postgresql.Password, "Password should be provided to usecase") + + // Note: Password at this point may still be encrypted because PopulateDbData + // is called after the mock captures it. The important thing is that credentials + // were provided to the usecase despite cache being deleted. 
+ t.Logf("Encrypted password in cache: %s", encryptedPassword) + t.Logf("Password received by usecase: %s", capturedDB.Postgresql.Password) + + // Wait for restore to complete + WaitForRestoreCompletion(t, restore.ID, 10*time.Second) + + // Verify restore was completed + completedRestore, err := restoreRepository.FindByID(restore.ID) + assert.NoError(t, err) + assert.Equal(t, restores_core.RestoreStatusCompleted, completedRestore.Status) + + time.Sleep(200 * time.Millisecond) +} diff --git a/backend/internal/features/restores/restoring/testing.go b/backend/internal/features/restores/restoring/testing.go new file mode 100644 index 0000000..e202295 --- /dev/null +++ b/backend/internal/features/restores/restoring/testing.go @@ -0,0 +1,297 @@ +package restoring + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + + "databasus-backend/internal/features/backups/backups" + backups_core "databasus-backend/internal/features/backups/backups/core" + backups_config "databasus-backend/internal/features/backups/config" + "databasus-backend/internal/features/databases" + "databasus-backend/internal/features/databases/databases/postgresql" + restores_core "databasus-backend/internal/features/restores/core" + "databasus-backend/internal/features/restores/usecases" + "databasus-backend/internal/features/storages" + workspaces_controllers "databasus-backend/internal/features/workspaces/controllers" + workspaces_testing "databasus-backend/internal/features/workspaces/testing" + "databasus-backend/internal/util/encryption" + "databasus-backend/internal/util/logger" +) + +func CreateTestRouter() *gin.Engine { + router := workspaces_testing.CreateTestRouter( + workspaces_controllers.GetWorkspaceController(), + workspaces_controllers.GetMembershipController(), + databases.GetDatabaseController(), + backups_config.GetBackupConfigController(), + ) + + return router +} + +func CreateTestRestorerNode() *RestorerNode { + return 
&RestorerNode{ + uuid.New(), + databases.GetDatabaseService(), + backups.GetBackupService(), + encryption.GetFieldEncryptor(), + restoreRepository, + backups_config.GetBackupConfigService(), + storages.GetStorageService(), + restoreNodesRegistry, + logger.GetLogger(), + usecases.GetRestoreBackupUsecase(), + restoreDatabaseCache, + time.Time{}, + } +} + +// WaitForRestoreCompletion waits for a restore to be completed (or failed) +func WaitForRestoreCompletion( + t *testing.T, + restoreID uuid.UUID, + timeout time.Duration, +) { + deadline := time.Now().UTC().Add(timeout) + + for time.Now().UTC().Before(deadline) { + restore, err := restoreRepository.FindByID(restoreID) + if err != nil { + t.Logf("WaitForRestoreCompletion: error finding restore: %v", err) + time.Sleep(50 * time.Millisecond) + continue + } + + t.Logf("WaitForRestoreCompletion: restore status: %s", restore.Status) + + if restore.Status == restores_core.RestoreStatusCompleted || + restore.Status == restores_core.RestoreStatusFailed { + t.Logf( + "WaitForRestoreCompletion: restore finished with status %s", + restore.Status, + ) + return + } + + time.Sleep(50 * time.Millisecond) + } + + t.Logf("WaitForRestoreCompletion: timeout waiting for restore to complete") +} + +// StartRestorerNodeForTest starts a RestorerNode in a goroutine for testing. +// The node registers itself in the registry and subscribes to restore assignments. +// Returns a context cancel function that should be deferred to stop the node. 
+func StartRestorerNodeForTest(t *testing.T, restorerNode *RestorerNode) context.CancelFunc { + ctx, cancel := context.WithCancel(context.Background()) + + done := make(chan struct{}) + + go func() { + restorerNode.Run(ctx) + close(done) + }() + + // Poll registry for node presence instead of fixed sleep + deadline := time.Now().UTC().Add(5 * time.Second) + for time.Now().UTC().Before(deadline) { + nodes, err := restoreNodesRegistry.GetAvailableNodes() + if err == nil { + for _, node := range nodes { + if node.ID == restorerNode.nodeID { + t.Logf("RestorerNode registered in registry: %s", restorerNode.nodeID) + + return func() { + cancel() + select { + case <-done: + t.Log("RestorerNode stopped gracefully") + case <-time.After(2 * time.Second): + t.Log("RestorerNode stop timeout") + } + } + } + } + } + time.Sleep(50 * time.Millisecond) + } + + t.Fatalf("RestorerNode failed to register in registry within timeout") + return nil +} + +// StartSchedulerForTest starts the RestoresScheduler in a goroutine for testing. +// The scheduler subscribes to task completions and manages restore lifecycle. +// Returns a context cancel function that should be deferred to stop the scheduler. +func StartSchedulerForTest(t *testing.T) context.CancelFunc { + ctx, cancel := context.WithCancel(context.Background()) + + done := make(chan struct{}) + + go func() { + GetRestoresScheduler().Run(ctx) + close(done) + }() + + // Give scheduler time to subscribe to completions + time.Sleep(100 * time.Millisecond) + t.Log("RestoresScheduler started") + + return func() { + cancel() + select { + case <-done: + t.Log("RestoresScheduler stopped gracefully") + case <-time.After(2 * time.Second): + t.Log("RestoresScheduler stop timeout") + } + } +} + +// StopRestorerNodeForTest stops the RestorerNode by canceling its context. +// It waits for the node to unregister from the registry. 
+func StopRestorerNodeForTest(t *testing.T, cancel context.CancelFunc, restorerNode *RestorerNode) { + cancel() + + // Wait for node to unregister from registry + deadline := time.Now().UTC().Add(2 * time.Second) + for time.Now().UTC().Before(deadline) { + nodes, err := restoreNodesRegistry.GetAvailableNodes() + if err == nil { + found := false + for _, node := range nodes { + if node.ID == restorerNode.nodeID { + found = true + break + } + } + if !found { + t.Logf("RestorerNode unregistered from registry: %s", restorerNode.nodeID) + return + } + } + time.Sleep(50 * time.Millisecond) + } + + t.Logf("RestorerNode stop completed for %s", restorerNode.nodeID) +} + +func CreateMockNodeInRegistry(nodeID uuid.UUID, throughputMBs int, lastHeartbeat time.Time) error { + restoreNode := RestoreNode{ + ID: nodeID, + ThroughputMBs: throughputMBs, + LastHeartbeat: lastHeartbeat, + } + + return restoreNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, restoreNode) +} + +func UpdateNodeHeartbeatDirectly( + nodeID uuid.UUID, + throughputMBs int, + lastHeartbeat time.Time, +) error { + restoreNode := RestoreNode{ + ID: nodeID, + ThroughputMBs: throughputMBs, + LastHeartbeat: lastHeartbeat, + } + + return restoreNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, restoreNode) +} + +func GetNodeFromRegistry(nodeID uuid.UUID) (*RestoreNode, error) { + nodes, err := restoreNodesRegistry.GetAvailableNodes() + if err != nil { + return nil, err + } + + for _, node := range nodes { + if node.ID == nodeID { + return &node, nil + } + } + + return nil, fmt.Errorf("node not found") +} + +// WaitForActiveTasksDecrease waits for the active task count to decrease below the initial count. +// It polls the registry every 500ms until the count decreases or the timeout is reached. +// Returns true if the count decreased, false if timeout was reached. 
+func WaitForActiveTasksDecrease( + t *testing.T, + nodeID uuid.UUID, + initialCount int, + timeout time.Duration, +) bool { + deadline := time.Now().UTC().Add(timeout) + + for time.Now().UTC().Before(deadline) { + stats, err := restoreNodesRegistry.GetRestoreNodesStats() + if err != nil { + t.Logf("WaitForActiveTasksDecrease: error getting node stats: %v", err) + time.Sleep(500 * time.Millisecond) + continue + } + + for _, stat := range stats { + if stat.ID == nodeID { + t.Logf( + "WaitForActiveTasksDecrease: current active tasks = %d (initial = %d)", + stat.ActiveRestores, + initialCount, + ) + if stat.ActiveRestores < initialCount { + t.Logf( + "WaitForActiveTasksDecrease: active tasks decreased from %d to %d", + initialCount, + stat.ActiveRestores, + ) + return true + } + break + } + } + + time.Sleep(500 * time.Millisecond) + } + + t.Logf("WaitForActiveTasksDecrease: timeout waiting for active tasks to decrease") + return false +} + +// CreateTestRestore creates a test restore with the given backup and status +func CreateTestRestore( + t *testing.T, + backup *backups_core.Backup, + status restores_core.RestoreStatus, +) *restores_core.Restore { + restore := &restores_core.Restore{ + BackupID: backup.ID, + Status: status, + PostgresqlDatabase: &postgresql.PostgresqlDatabase{ + Host: "localhost", + Port: 5432, + Username: "test", + Password: "test", + Database: stringPtr("testdb"), + Version: "16", + }, + } + + err := restoreRepository.Save(restore) + if err != nil { + t.Fatalf("Failed to create test restore: %v", err) + } + + return restore +} + +func stringPtr(s string) *string { + return &s +} diff --git a/backend/internal/features/restores/service.go b/backend/internal/features/restores/service.go index 5c0cb98..564b9f0 100644 --- a/backend/internal/features/restores/service.go +++ b/backend/internal/features/restores/service.go @@ -7,8 +7,8 @@ import ( backups_config "databasus-backend/internal/features/backups/config" 
"databasus-backend/internal/features/databases" "databasus-backend/internal/features/disk" - "databasus-backend/internal/features/restores/enums" - "databasus-backend/internal/features/restores/models" + restores_core "databasus-backend/internal/features/restores/core" + "databasus-backend/internal/features/restores/restoring" "databasus-backend/internal/features/restores/usecases" "databasus-backend/internal/features/storages" users_models "databasus-backend/internal/features/users/models" @@ -25,7 +25,7 @@ import ( type RestoreService struct { backupService *backups.BackupService - restoreRepository *RestoreRepository + restoreRepository *restores_core.RestoreRepository storageService *storages.StorageService backupConfigService *backups_config.BackupConfigService restoreBackupUsecase *usecases.RestoreBackupUsecase @@ -44,7 +44,7 @@ func (s *RestoreService) OnBeforeBackupRemove(backup *backups_core.Backup) error } for _, restore := range restores { - if restore.Status == enums.RestoreStatusInProgress { + if restore.Status == restores_core.RestoreStatusInProgress { return errors.New("restore is in progress, backup cannot be removed") } } @@ -61,7 +61,7 @@ func (s *RestoreService) OnBeforeBackupRemove(backup *backups_core.Backup) error func (s *RestoreService) GetRestores( user *users_models.User, backupID uuid.UUID, -) ([]*models.Restore, error) { +) ([]*restores_core.Restore, error) { backup, err := s.backupService.GetBackup(backupID) if err != nil { return nil, err @@ -93,7 +93,7 @@ func (s *RestoreService) GetRestores( func (s *RestoreService) RestoreBackupWithAuth( user *users_models.User, backupID uuid.UUID, - requestDTO RestoreBackupRequest, + requestDTO restores_core.RestoreBackupRequest, ) error { backup, err := s.backupService.GetBackup(backupID) if err != nil { @@ -134,11 +134,45 @@ func (s *RestoreService) RestoreBackupWithAuth( return err } - go func() { - if err := s.RestoreBackup(backup, requestDTO); err != nil { - s.logger.Error("Failed to restore 
backup", "error", err) + // Create restore record with the request configuration + restore := restores_core.Restore{ + ID: uuid.New(), + Status: restores_core.RestoreStatusInProgress, + BackupID: backup.ID, + Backup: backup, + CreatedAt: time.Now().UTC(), + RestoreDurationMs: 0, + FailMessage: nil, + PostgresqlDatabase: requestDTO.PostgresqlDatabase, + MysqlDatabase: requestDTO.MysqlDatabase, + MariadbDatabase: requestDTO.MariadbDatabase, + MongodbDatabase: requestDTO.MongodbDatabase, + } + + if err := s.restoreRepository.Save(&restore); err != nil { + return err + } + + // Prepare database cache with credentials from the request + dbCache := &restoring.RestoreDatabaseCache{ + PostgresqlDatabase: requestDTO.PostgresqlDatabase, + MysqlDatabase: requestDTO.MysqlDatabase, + MariadbDatabase: requestDTO.MariadbDatabase, + MongodbDatabase: requestDTO.MongodbDatabase, + } + + // Trigger restore via scheduler + scheduler := restoring.GetRestoresScheduler() + if err := scheduler.StartRestore(restore.ID, dbCache); err != nil { + // Mark restore as failed if we can't schedule it + failMsg := fmt.Sprintf("Failed to schedule restore: %v", err) + restore.FailMessage = &failMsg + restore.Status = restores_core.RestoreStatusFailed + if saveErr := s.restoreRepository.Save(&restore); saveErr != nil { + s.logger.Error("Failed to save restore after scheduling error", "error", saveErr) } - }() + return err + } s.auditLogService.WriteAuditLog( fmt.Sprintf( @@ -153,127 +187,9 @@ func (s *RestoreService) RestoreBackupWithAuth( return nil } -func (s *RestoreService) RestoreBackup( - backup *backups_core.Backup, - requestDTO RestoreBackupRequest, -) error { - if backup.Status != backups_core.BackupStatusCompleted { - return errors.New("backup is not completed") - } - - database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID) - if err != nil { - return err - } - - switch database.Type { - case databases.DatabaseTypePostgres: - if requestDTO.PostgresqlDatabase == nil { - return 
errors.New("postgresql database is required") - } - case databases.DatabaseTypeMysql: - if requestDTO.MysqlDatabase == nil { - return errors.New("mysql database is required") - } - case databases.DatabaseTypeMariadb: - if requestDTO.MariadbDatabase == nil { - return errors.New("mariadb database is required") - } - case databases.DatabaseTypeMongodb: - if requestDTO.MongodbDatabase == nil { - return errors.New("mongodb database is required") - } - } - - restore := models.Restore{ - ID: uuid.New(), - Status: enums.RestoreStatusInProgress, - - BackupID: backup.ID, - Backup: backup, - - CreatedAt: time.Now().UTC(), - RestoreDurationMs: 0, - - FailMessage: nil, - } - - // Save the restore first - if err := s.restoreRepository.Save(&restore); err != nil { - return err - } - - // Save the restore again to include the postgresql database - if err := s.restoreRepository.Save(&restore); err != nil { - return err - } - - storage, err := s.storageService.GetStorageByID(backup.StorageID) - if err != nil { - return err - } - - backupConfig, err := s.backupConfigService.GetBackupConfigByDbId( - database.ID, - ) - if err != nil { - return err - } - - start := time.Now().UTC() - - restoringToDB := &databases.Database{ - Type: database.Type, - Postgresql: requestDTO.PostgresqlDatabase, - Mysql: requestDTO.MysqlDatabase, - Mariadb: requestDTO.MariadbDatabase, - Mongodb: requestDTO.MongodbDatabase, - } - - if err := restoringToDB.PopulateDbData(s.logger, s.fieldEncryptor); err != nil { - return fmt.Errorf("failed to auto-detect database data: %w", err) - } - - isExcludeExtensions := false - if requestDTO.PostgresqlDatabase != nil { - isExcludeExtensions = requestDTO.PostgresqlDatabase.IsExcludeExtensions - } - - err = s.restoreBackupUsecase.Execute( - backupConfig, - restore, - database, - restoringToDB, - backup, - storage, - isExcludeExtensions, - ) - if err != nil { - errMsg := err.Error() - restore.FailMessage = &errMsg - restore.Status = enums.RestoreStatusFailed - 
restore.RestoreDurationMs = time.Since(start).Milliseconds() - - if err := s.restoreRepository.Save(&restore); err != nil { - return err - } - - return err - } - - restore.Status = enums.RestoreStatusCompleted - restore.RestoreDurationMs = time.Since(start).Milliseconds() - - if err := s.restoreRepository.Save(&restore); err != nil { - return err - } - - return nil -} - func (s *RestoreService) validateVersionCompatibility( backupDatabase *databases.Database, - requestDTO RestoreBackupRequest, + requestDTO restores_core.RestoreBackupRequest, ) error { // populate version if requestDTO.MariadbDatabase != nil { @@ -372,7 +288,7 @@ func (s *RestoreService) validateVersionCompatibility( func (s *RestoreService) validateDiskSpace( backup *backups_core.Backup, - requestDTO RestoreBackupRequest, + requestDTO restores_core.RestoreBackupRequest, ) error { // Only validate disk space for PostgreSQL when file-based restore is needed: // - CPU > 1 (parallel jobs require file) diff --git a/backend/internal/features/restores/testing.go b/backend/internal/features/restores/testing.go new file mode 100644 index 0000000..d962dce --- /dev/null +++ b/backend/internal/features/restores/testing.go @@ -0,0 +1,51 @@ +package restores + +import ( + "context" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + + "databasus-backend/internal/features/backups/backups" + backups_config "databasus-backend/internal/features/backups/config" + "databasus-backend/internal/features/databases" + "databasus-backend/internal/features/restores/restoring" + workspaces_controllers "databasus-backend/internal/features/workspaces/controllers" + workspaces_testing "databasus-backend/internal/features/workspaces/testing" +) + +func CreateTestRouter() *gin.Engine { + router := workspaces_testing.CreateTestRouter( + workspaces_controllers.GetWorkspaceController(), + workspaces_controllers.GetMembershipController(), + databases.GetDatabaseController(), + 
backups_config.GetBackupConfigController(), + backups.GetBackupController(), + GetRestoreController(), + ) + + v1 := router.Group("/api/v1") + backups.GetBackupController().RegisterPublicRoutes(v1) + + return router +} + +func SetupMockRestoreNode(t *testing.T) (uuid.UUID, context.CancelFunc) { + nodeID := uuid.New() + err := restoring.CreateMockNodeInRegistry( + nodeID, + 100, + time.Now().UTC(), + ) + if err != nil { + t.Fatalf("Failed to create mock node: %v", err) + } + + cleanup := func() { + // Node will expire naturally from registry + } + + return nodeID, cleanup +} diff --git a/backend/internal/features/restores/usecases/mariadb/restore_backup_uc.go b/backend/internal/features/restores/usecases/mariadb/restore_backup_uc.go index 21c6d99..a7feb42 100644 --- a/backend/internal/features/restores/usecases/mariadb/restore_backup_uc.go +++ b/backend/internal/features/restores/usecases/mariadb/restore_backup_uc.go @@ -24,7 +24,7 @@ import ( "databasus-backend/internal/features/databases" mariadbtypes "databasus-backend/internal/features/databases/databases/mariadb" encryption_secrets "databasus-backend/internal/features/encryption/secrets" - "databasus-backend/internal/features/restores/models" + restores_core "databasus-backend/internal/features/restores/core" "databasus-backend/internal/features/storages" util_encryption "databasus-backend/internal/util/encryption" "databasus-backend/internal/util/tools" @@ -39,7 +39,7 @@ func (uc *RestoreMariadbBackupUsecase) Execute( originalDB *databases.Database, restoringToDB *databases.Database, backupConfig *backups_config.BackupConfig, - restore models.Restore, + restore restores_core.Restore, backup *backups_core.Backup, storage *storages.Storage, ) error { diff --git a/backend/internal/features/restores/usecases/mongodb/restore_backup_uc.go b/backend/internal/features/restores/usecases/mongodb/restore_backup_uc.go index 3266d80..c4551c9 100644 --- 
a/backend/internal/features/restores/usecases/mongodb/restore_backup_uc.go +++ b/backend/internal/features/restores/usecases/mongodb/restore_backup_uc.go @@ -20,7 +20,7 @@ import ( "databasus-backend/internal/features/databases" mongodbtypes "databasus-backend/internal/features/databases/databases/mongodb" encryption_secrets "databasus-backend/internal/features/encryption/secrets" - "databasus-backend/internal/features/restores/models" + restores_core "databasus-backend/internal/features/restores/core" "databasus-backend/internal/features/storages" util_encryption "databasus-backend/internal/util/encryption" "databasus-backend/internal/util/tools" @@ -39,7 +39,7 @@ func (uc *RestoreMongodbBackupUsecase) Execute( originalDB *databases.Database, restoringToDB *databases.Database, backupConfig *backups_config.BackupConfig, - restore models.Restore, + restore restores_core.Restore, backup *backups_core.Backup, storage *storages.Storage, ) error { diff --git a/backend/internal/features/restores/usecases/mysql/restore_backup_uc.go b/backend/internal/features/restores/usecases/mysql/restore_backup_uc.go index ec60596..ab182d3 100644 --- a/backend/internal/features/restores/usecases/mysql/restore_backup_uc.go +++ b/backend/internal/features/restores/usecases/mysql/restore_backup_uc.go @@ -24,7 +24,7 @@ import ( "databasus-backend/internal/features/databases" mysqltypes "databasus-backend/internal/features/databases/databases/mysql" encryption_secrets "databasus-backend/internal/features/encryption/secrets" - "databasus-backend/internal/features/restores/models" + restores_core "databasus-backend/internal/features/restores/core" "databasus-backend/internal/features/storages" util_encryption "databasus-backend/internal/util/encryption" "databasus-backend/internal/util/tools" @@ -39,7 +39,7 @@ func (uc *RestoreMysqlBackupUsecase) Execute( originalDB *databases.Database, restoringToDB *databases.Database, backupConfig *backups_config.BackupConfig, - restore models.Restore, + 
restore restores_core.Restore, backup *backups_core.Backup, storage *storages.Storage, ) error { diff --git a/backend/internal/features/restores/usecases/postgresql/restore_backup_uc.go b/backend/internal/features/restores/usecases/postgresql/restore_backup_uc.go index b84f155..b873db2 100644 --- a/backend/internal/features/restores/usecases/postgresql/restore_backup_uc.go +++ b/backend/internal/features/restores/usecases/postgresql/restore_backup_uc.go @@ -21,7 +21,7 @@ import ( "databasus-backend/internal/features/databases" pgtypes "databasus-backend/internal/features/databases/databases/postgresql" encryption_secrets "databasus-backend/internal/features/encryption/secrets" - "databasus-backend/internal/features/restores/models" + restores_core "databasus-backend/internal/features/restores/core" "databasus-backend/internal/features/storages" util_encryption "databasus-backend/internal/util/encryption" "databasus-backend/internal/util/tools" @@ -38,7 +38,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute( originalDB *databases.Database, restoringToDB *databases.Database, backupConfig *backups_config.BackupConfig, - restore models.Restore, + restore restores_core.Restore, backup *backups_core.Backup, storage *storages.Storage, isExcludeExtensions bool, diff --git a/backend/internal/features/restores/usecases/restore_backup_uc.go b/backend/internal/features/restores/usecases/restore_backup_uc.go index 2e746bf..13814d0 100644 --- a/backend/internal/features/restores/usecases/restore_backup_uc.go +++ b/backend/internal/features/restores/usecases/restore_backup_uc.go @@ -6,7 +6,7 @@ import ( backups_core "databasus-backend/internal/features/backups/backups/core" backups_config "databasus-backend/internal/features/backups/config" "databasus-backend/internal/features/databases" - "databasus-backend/internal/features/restores/models" + restores_core "databasus-backend/internal/features/restores/core" usecases_mariadb 
"databasus-backend/internal/features/restores/usecases/mariadb" usecases_mongodb "databasus-backend/internal/features/restores/usecases/mongodb" usecases_mysql "databasus-backend/internal/features/restores/usecases/mysql" @@ -23,7 +23,7 @@ type RestoreBackupUsecase struct { func (uc *RestoreBackupUsecase) Execute( backupConfig *backups_config.BackupConfig, - restore models.Restore, + restore restores_core.Restore, originalDB *databases.Database, restoringToDB *databases.Database, backup *backups_core.Backup, diff --git a/backend/internal/features/system/healthcheck/service.go b/backend/internal/features/system/healthcheck/service.go index 3a87396..cfc8558 100644 --- a/backend/internal/features/system/healthcheck/service.go +++ b/backend/internal/features/system/healthcheck/service.go @@ -37,7 +37,7 @@ func (s *HealthcheckService) IsHealthy() error { } } - if config.GetEnv().IsBackupNode { + if config.GetEnv().IsProcessingNode { if !s.backuperNode.IsBackuperRunning() { return errors.New("backuper node is not running for more than 5 minutes") } diff --git a/backend/internal/features/tasks/registry/di.go b/backend/internal/features/tasks/registry/di.go deleted file mode 100644 index 660ab71..0000000 --- a/backend/internal/features/tasks/registry/di.go +++ /dev/null @@ -1,18 +0,0 @@ -package task_registry - -import ( - cache_utils "databasus-backend/internal/util/cache" - "databasus-backend/internal/util/logger" -) - -var taskNodesRegistry = &TaskNodesRegistry{ - cache_utils.GetValkeyClient(), - logger.GetLogger(), - cache_utils.DefaultCacheTimeout, - cache_utils.NewPubSubManager(), - cache_utils.NewPubSubManager(), -} - -func GetTaskNodesRegistry() *TaskNodesRegistry { - return taskNodesRegistry -} diff --git a/backend/internal/features/tasks/registry/dto.go b/backend/internal/features/tasks/registry/dto.go deleted file mode 100644 index 59ce7cb..0000000 --- a/backend/internal/features/tasks/registry/dto.go +++ /dev/null @@ -1,29 +0,0 @@ -package task_registry - 
-import ( - "time" - - "github.com/google/uuid" -) - -type TaskNode struct { - ID uuid.UUID `json:"id"` - ThroughputMBs int `json:"throughputMBs"` - LastHeartbeat time.Time `json:"lastHeartbeat"` -} - -type TaskNodeStats struct { - ID uuid.UUID `json:"id"` - ActiveTasks int `json:"activeTasks"` -} - -type TaskSubmitMessage struct { - NodeID string `json:"nodeId"` - TaskID string `json:"taskId"` - IsCallNotifier bool `json:"isCallNotifier"` -} - -type TaskCompletionMessage struct { - NodeID string `json:"nodeId"` - TaskID string `json:"taskId"` -} diff --git a/backend/internal/features/tests/mariadb_backup_restore_test.go b/backend/internal/features/tests/mariadb_backup_restore_test.go index ec17585..eaf6b5f 100644 --- a/backend/internal/features/tests/mariadb_backup_restore_test.go +++ b/backend/internal/features/tests/mariadb_backup_restore_test.go @@ -21,9 +21,7 @@ import ( backups_config "databasus-backend/internal/features/backups/config" "databasus-backend/internal/features/databases" mariadbtypes "databasus-backend/internal/features/databases/databases/mariadb" - "databasus-backend/internal/features/restores" - restores_enums "databasus-backend/internal/features/restores/enums" - restores_models "databasus-backend/internal/features/restores/models" + restores_core "databasus-backend/internal/features/restores/core" "databasus-backend/internal/features/storages" users_enums "databasus-backend/internal/features/users/enums" users_testing "databasus-backend/internal/features/users/testing" @@ -213,7 +211,7 @@ func testMariadbBackupRestoreForVersion( ) restore := waitForMariadbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute) - assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status) + assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status) var tableExists int err = newDB.Get( @@ -311,7 +309,7 @@ func testMariadbBackupRestoreWithEncryptionForVersion( ) restore := waitForMariadbRestoreCompletion(t, router, backup.ID, 
user.Token, 5*time.Minute) - assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status) + assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status) var tableExists int err = newDB.Get( @@ -418,7 +416,7 @@ func testMariadbBackupRestoreWithReadOnlyUserForVersion( ) restore := waitForMariadbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute) - assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status) + assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status) var tableExists int err = newDB.Get( @@ -506,7 +504,7 @@ func createMariadbRestoreViaAPI( version tools.MariadbVersion, token string, ) { - request := restores.RestoreBackupRequest{ + request := restores_core.RestoreBackupRequest{ MariadbDatabase: &mariadbtypes.MariadbDatabase{ Host: host, Port: port, @@ -533,7 +531,7 @@ func waitForMariadbRestoreCompletion( backupID uuid.UUID, token string, timeout time.Duration, -) *restores_models.Restore { +) *restores_core.Restore { startTime := time.Now() pollInterval := 500 * time.Millisecond @@ -542,7 +540,7 @@ func waitForMariadbRestoreCompletion( t.Fatalf("Timeout waiting for MariaDB restore completion after %v", timeout) } - var restoresList []*restores_models.Restore + var restoresList []*restores_core.Restore test_utils.MakeGetRequestAndUnmarshal( t, router, @@ -553,10 +551,10 @@ func waitForMariadbRestoreCompletion( ) for _, restore := range restoresList { - if restore.Status == restores_enums.RestoreStatusCompleted { + if restore.Status == restores_core.RestoreStatusCompleted { return restore } - if restore.Status == restores_enums.RestoreStatusFailed { + if restore.Status == restores_core.RestoreStatusFailed { failMsg := "unknown error" if restore.FailMessage != nil { failMsg = *restore.FailMessage diff --git a/backend/internal/features/tests/mongodb_backup_restore_test.go b/backend/internal/features/tests/mongodb_backup_restore_test.go index 9c779b6..2abd505 100644 --- 
a/backend/internal/features/tests/mongodb_backup_restore_test.go +++ b/backend/internal/features/tests/mongodb_backup_restore_test.go @@ -23,9 +23,7 @@ import ( backups_config "databasus-backend/internal/features/backups/config" "databasus-backend/internal/features/databases" mongodbtypes "databasus-backend/internal/features/databases/databases/mongodb" - "databasus-backend/internal/features/restores" - restores_enums "databasus-backend/internal/features/restores/enums" - restores_models "databasus-backend/internal/features/restores/models" + restores_core "databasus-backend/internal/features/restores/core" "databasus-backend/internal/features/storages" users_enums "databasus-backend/internal/features/users/enums" users_testing "databasus-backend/internal/features/users/testing" @@ -175,7 +173,7 @@ func testMongodbBackupRestoreForVersion( ) restore := waitForMongodbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute) - assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status) + assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status) verifyMongodbDataIntegrity(t, container, newDBName) @@ -254,7 +252,7 @@ func testMongodbBackupRestoreWithEncryptionForVersion( ) restore := waitForMongodbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute) - assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status) + assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status) verifyMongodbDataIntegrity(t, container, newDBName) @@ -342,7 +340,7 @@ func testMongodbBackupRestoreWithReadOnlyUserForVersion( ) restore := waitForMongodbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute) - assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status) + assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status) verifyMongodbDataIntegrity(t, container, newDBName) @@ -431,7 +429,7 @@ func createMongodbRestoreViaAPI( version tools.MongodbVersion, token string, ) { - request := 
restores.RestoreBackupRequest{ + request := restores_core.RestoreBackupRequest{ MongodbDatabase: &mongodbtypes.MongodbDatabase{ Host: host, Port: port, @@ -461,7 +459,7 @@ func waitForMongodbRestoreCompletion( backupID uuid.UUID, token string, timeout time.Duration, -) *restores_models.Restore { +) *restores_core.Restore { startTime := time.Now() pollInterval := 500 * time.Millisecond @@ -470,7 +468,7 @@ func waitForMongodbRestoreCompletion( t.Fatalf("Timeout waiting for MongoDB restore completion after %v", timeout) } - var restoresList []*restores_models.Restore + var restoresList []*restores_core.Restore test_utils.MakeGetRequestAndUnmarshal( t, router, @@ -481,10 +479,10 @@ func waitForMongodbRestoreCompletion( ) for _, restore := range restoresList { - if restore.Status == restores_enums.RestoreStatusCompleted { + if restore.Status == restores_core.RestoreStatusCompleted { return restore } - if restore.Status == restores_enums.RestoreStatusFailed { + if restore.Status == restores_core.RestoreStatusFailed { failMsg := "unknown error" if restore.FailMessage != nil { failMsg = *restore.FailMessage diff --git a/backend/internal/features/tests/mysql_backup_restore_test.go b/backend/internal/features/tests/mysql_backup_restore_test.go index 5c302ed..27b8089 100644 --- a/backend/internal/features/tests/mysql_backup_restore_test.go +++ b/backend/internal/features/tests/mysql_backup_restore_test.go @@ -21,9 +21,7 @@ import ( backups_config "databasus-backend/internal/features/backups/config" "databasus-backend/internal/features/databases" mysqltypes "databasus-backend/internal/features/databases/databases/mysql" - "databasus-backend/internal/features/restores" - restores_enums "databasus-backend/internal/features/restores/enums" - restores_models "databasus-backend/internal/features/restores/models" + restores_core "databasus-backend/internal/features/restores/core" "databasus-backend/internal/features/storages" users_enums 
"databasus-backend/internal/features/users/enums" users_testing "databasus-backend/internal/features/users/testing" @@ -188,7 +186,7 @@ func testMysqlBackupRestoreForVersion(t *testing.T, mysqlVersion tools.MysqlVers ) restore := waitForMysqlRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute) - assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status) + assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status) var tableExists int err = newDB.Get( @@ -286,7 +284,7 @@ func testMysqlBackupRestoreWithEncryptionForVersion( ) restore := waitForMysqlRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute) - assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status) + assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status) var tableExists int err = newDB.Get( @@ -393,7 +391,7 @@ func testMysqlBackupRestoreWithReadOnlyUserForVersion( ) restore := waitForMysqlRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute) - assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status) + assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status) var tableExists int err = newDB.Get( @@ -481,7 +479,7 @@ func createMysqlRestoreViaAPI( version tools.MysqlVersion, token string, ) { - request := restores.RestoreBackupRequest{ + request := restores_core.RestoreBackupRequest{ MysqlDatabase: &mysqltypes.MysqlDatabase{ Host: host, Port: port, @@ -508,7 +506,7 @@ func waitForMysqlRestoreCompletion( backupID uuid.UUID, token string, timeout time.Duration, -) *restores_models.Restore { +) *restores_core.Restore { startTime := time.Now() pollInterval := 500 * time.Millisecond @@ -517,7 +515,7 @@ func waitForMysqlRestoreCompletion( t.Fatalf("Timeout waiting for MySQL restore completion after %v", timeout) } - var restoresList []*restores_models.Restore + var restoresList []*restores_core.Restore test_utils.MakeGetRequestAndUnmarshal( t, router, @@ -528,10 +526,10 @@ func 
waitForMysqlRestoreCompletion( ) for _, restore := range restoresList { - if restore.Status == restores_enums.RestoreStatusCompleted { + if restore.Status == restores_core.RestoreStatusCompleted { return restore } - if restore.Status == restores_enums.RestoreStatusFailed { + if restore.Status == restores_core.RestoreStatusFailed { failMsg := "unknown error" if restore.FailMessage != nil { failMsg = *restore.FailMessage diff --git a/backend/internal/features/tests/postgresql_backup_restore_test.go b/backend/internal/features/tests/postgresql_backup_restore_test.go index dd69317..dbc19ac 100644 --- a/backend/internal/features/tests/postgresql_backup_restore_test.go +++ b/backend/internal/features/tests/postgresql_backup_restore_test.go @@ -23,8 +23,7 @@ import ( "databasus-backend/internal/features/databases" pgtypes "databasus-backend/internal/features/databases/databases/postgresql" "databasus-backend/internal/features/restores" - restores_enums "databasus-backend/internal/features/restores/enums" - restores_models "databasus-backend/internal/features/restores/models" + restores_core "databasus-backend/internal/features/restores/core" "databasus-backend/internal/features/storages" users_enums "databasus-backend/internal/features/users/enums" users_testing "databasus-backend/internal/features/users/testing" @@ -212,7 +211,7 @@ func Test_BackupAndRestoreSupabase_PublicSchemaOnly_RestoreIsSuccessful(t *testi ) restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute) - assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status) + assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status) var countAfterRestore int err = supabaseDB.Get( @@ -439,7 +438,7 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string, cp ) restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute) - assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status) + assert.Equal(t, 
restores_core.RestoreStatusCompleted, restore.Status) var tableExists bool err = newDB.Get( @@ -555,7 +554,7 @@ func testSchemaSelectionAllSchemasForVersion(t *testing.T, pgVersion string, por ) restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute) - assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status) + assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status) var publicTableExists bool err = newDB.Get(&publicTableExists, ` @@ -689,7 +688,7 @@ func testBackupRestoreWithExcludeExtensionsForVersion(t *testing.T, pgVersion st ) restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute) - assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status) + assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status) // Verify the table was restored var tableExists bool @@ -829,7 +828,7 @@ func testBackupRestoreWithoutExcludeExtensionsForVersion( ) restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute) - assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status) + assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status) // Verify the extension was recovered var extensionExists bool @@ -956,7 +955,7 @@ func testBackupRestoreWithReadOnlyUserForVersion(t *testing.T, pgVersion string, ) restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute) - assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status) + assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status) var tableExists bool err = newDB.Get( @@ -1076,7 +1075,7 @@ func testSchemaSelectionOnlySpecifiedSchemasForVersion( ) restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute) - assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status) + assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status) var publicTableExists bool err = newDB.Get(&publicTableExists, ` @@ 
-1190,7 +1189,7 @@ func testBackupRestoreWithEncryptionForVersion(t *testing.T, pgVersion string, p ) restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute) - assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status) + assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status) var tableExists bool err = newDB.Get( @@ -1286,7 +1285,7 @@ func waitForRestoreCompletion( backupID uuid.UUID, token string, timeout time.Duration, -) *restores_models.Restore { +) *restores_core.Restore { startTime := time.Now() pollInterval := 500 * time.Millisecond @@ -1295,7 +1294,7 @@ func waitForRestoreCompletion( t.Fatalf("Timeout waiting for restore completion after %v", timeout) } - var restores []*restores_models.Restore + var restores []*restores_core.Restore test_utils.MakeGetRequestAndUnmarshal( t, router, @@ -1306,10 +1305,10 @@ func waitForRestoreCompletion( ) for _, restore := range restores { - if restore.Status == restores_enums.RestoreStatusCompleted { + if restore.Status == restores_core.RestoreStatusCompleted { return restore } - if restore.Status == restores_enums.RestoreStatusFailed { + if restore.Status == restores_core.RestoreStatusFailed { failMsg := "unknown error" if restore.FailMessage != nil { failMsg = *restore.FailMessage @@ -1476,7 +1475,7 @@ func createRestoreWithCpuCountViaAPI( cpuCount int, token string, ) { - request := restores.RestoreBackupRequest{ + request := restores_core.RestoreBackupRequest{ PostgresqlDatabase: &pgtypes.PostgresqlDatabase{ Host: host, Port: port, @@ -1509,7 +1508,7 @@ func createRestoreWithOptionsViaAPI( isExcludeExtensions bool, token string, ) { - request := restores.RestoreBackupRequest{ + request := restores_core.RestoreBackupRequest{ PostgresqlDatabase: &pgtypes.PostgresqlDatabase{ Host: host, Port: port, @@ -1647,7 +1646,7 @@ func createSupabaseRestoreViaAPI( database string, token string, ) { - request := restores.RestoreBackupRequest{ + request := 
restores_core.RestoreBackupRequest{ PostgresqlDatabase: &pgtypes.PostgresqlDatabase{ Host: host, Port: port, diff --git a/backend/internal/features/tests/setup_test.go b/backend/internal/features/tests/setup_test.go index 9ee5dbf..37c8420 100644 --- a/backend/internal/features/tests/setup_test.go +++ b/backend/internal/features/tests/setup_test.go @@ -5,6 +5,7 @@ import ( "testing" "databasus-backend/internal/features/backups/backups/backuping" + "databasus-backend/internal/features/restores/restoring" cache_utils "databasus-backend/internal/util/cache" ) @@ -12,11 +13,15 @@ func TestMain(m *testing.M) { cache_utils.ClearAllCache() backuperNode := backuping.CreateTestBackuperNode() - cancel := backuping.StartBackuperNodeForTest(&testing.T{}, backuperNode) + cancelBackup := backuping.StartBackuperNodeForTest(&testing.T{}, backuperNode) + + restorerNode := restoring.CreateTestRestorerNode() + cancelRestore := restoring.StartRestorerNodeForTest(&testing.T{}, restorerNode) exitCode := m.Run() - backuping.StopBackuperNodeForTest(&testing.T{}, cancel, backuperNode) + backuping.StopBackuperNodeForTest(&testing.T{}, cancelBackup, backuperNode) + restoring.StopRestorerNodeForTest(&testing.T{}, cancelRestore, restorerNode) os.Exit(exitCode) } diff --git a/backend/internal/util/cache/cache_test.go b/backend/internal/util/cache/cache_test.go index 4075a05..ccbe041 100644 --- a/backend/internal/util/cache/cache_test.go +++ b/backend/internal/util/cache/cache_test.go @@ -1,7 +1,9 @@ package cache_utils import ( + "context" "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -49,3 +51,43 @@ func Test_ClearAllCache_AfterClear_CacheIsEmpty(t *testing.T) { assert.Nil(t, retrieved, "Key %s should be deleted after clearing", tk.prefix+tk.key) } } + +func Test_SetWithExpiration_SetsCorrectTTL(t *testing.T) { + client := getCache() + + // Create a cache utility + testPrefix := "test:ttl:" + cacheUtil := NewCacheUtil[string](client, testPrefix) + + // Set a value with 1-hour 
expiration + testKey := "key1" + testValue := "test value" + oneHour := 1 * time.Hour + + cacheUtil.SetWithExpiration(testKey, &testValue, oneHour) + + // Verify the value was set + retrieved := cacheUtil.Get(testKey) + assert.NotNil(t, retrieved, "Value should be stored") + assert.Equal(t, testValue, *retrieved, "Retrieved value should match") + + // Check the TTL using Valkey TTL command + ctx, cancel := context.WithTimeout(context.Background(), DefaultCacheTimeout) + defer cancel() + + fullKey := testPrefix + testKey + ttlResult := client.Do(ctx, client.B().Ttl().Key(fullKey).Build()) + assert.NoError(t, ttlResult.Error(), "TTL command should not error") + + ttlSeconds, err := ttlResult.AsInt64() + assert.NoError(t, err, "TTL should be retrievable as int64") + + // TTL should be approximately 1 hour (3600 seconds) + // Allow for a small margin (within 10 seconds of 3600) + expectedTTL := int64(3600) + assert.GreaterOrEqual(t, ttlSeconds, expectedTTL-10, "TTL should be close to 1 hour") + assert.LessOrEqual(t, ttlSeconds, expectedTTL, "TTL should not exceed 1 hour") + + // Clean up + cacheUtil.Invalidate(testKey) +} diff --git a/backend/internal/util/cache/utils.go b/backend/internal/util/cache/utils.go index b4b0ca9..a16aaaa 100644 --- a/backend/internal/util/cache/utils.go +++ b/backend/internal/util/cache/utils.go @@ -67,6 +67,43 @@ func (c *CacheUtil[T]) Set(key string, item *T) { c.client.Do(ctx, c.client.B().Set().Key(fullKey).Value(string(data)).Ex(c.expiry).Build()) } +func (c *CacheUtil[T]) SetWithExpiration(key string, item *T, expiry time.Duration) { + ctx, cancel := context.WithTimeout(context.Background(), c.timeout) + defer cancel() + + data, err := json.Marshal(item) + if err != nil { + return + } + + fullKey := c.prefix + key + c.client.Do(ctx, c.client.B().Set().Key(fullKey).Value(string(data)).Ex(expiry).Build()) +} + +func (c *CacheUtil[T]) GetAndDelete(key string) *T { + ctx, cancel := context.WithTimeout(context.Background(), c.timeout) + 
defer cancel() + + fullKey := c.prefix + key + result := c.client.Do(ctx, c.client.B().Getdel().Key(fullKey).Build()) + + if result.Error() != nil { + return nil + } + + data, err := result.AsBytes() + if err != nil { + return nil + } + + var item T + if err := json.Unmarshal(data, &item); err != nil { + return nil + } + + return &item +} + func (c *CacheUtil[T]) Invalidate(key string) { ctx, cancel := context.WithTimeout(context.Background(), c.timeout) defer cancel()