FEATURE (restores): Add support for multiple restore nodes

This commit is contained in:
Rostislav Dugin
2026-01-17 13:59:06 +03:00
parent d98baa0656
commit c39bd34d5e
49 changed files with 5195 additions and 1253 deletions

View File

@@ -7,6 +7,7 @@ This is NOT a strict set of rules, but a set of recommendations to help you writ
## Table of Contents
- [Engineering Philosophy](#engineering-philosophy)
- [Backend Guidelines](#backend-guidelines)
- [Code Style](#code-style)
- [Comments](#comments)
@@ -22,6 +23,67 @@ This is NOT a strict set of rules, but a set of recommendations to help you writ
---
## Engineering Philosophy
**Think like a skeptical senior engineer and code reviewer. Don't just do what was asked—also think about what should have been asked.**
⚠️ **Balance vigilance with pragmatism:** Catch real issues, not theoretical ones. Don't let perfect be the enemy of good.
### Task Context Assessment:
**First, assess the task scope:**
- **Trivial** (typos, formatting, simple field adds): Apply directly with minimal analysis
- **Standard** (CRUD, typical features): Brief assumption check, proceed
- **Complex** (architecture, security, performance-critical): Full analysis required
- **Unclear** (ambiguous requirements): Always clarify assumptions first
### For Non-Trivial Tasks:
1. **Restate the objective and list assumptions** (explicit + implicit)
- If any assumption is shaky, call it out clearly
- Distinguish between what's specified and what you're inferring
2. **Propose appropriate solutions:**
- For complex tasks: 2–3 viable approaches (including a simpler baseline)
- Recommend one with clear tradeoffs
- Consider: complexity, maintainability, performance, future extensibility
3. **Identify risks proactively:**
- Edge cases and boundary conditions
- Security/privacy pitfalls
- Performance risks and scalability concerns
- Operational concerns (deployment, observability, rollback, monitoring)
4. **Handle ambiguity:**
- If requirements are ambiguous, make a reasonable default and proceed
- Clearly label your assumptions
- Document what would change under alternative assumptions
5. **Deliver quality:**
- Provide a solution that is correct, testable, and maintainable
- Include minimal tests or validation steps
- Follow project testing philosophy: prefer controller tests over unit tests
- Follow all project guidelines from this document
6. **Self-review before finalizing:**
- Ask: "What could go wrong?"
- Patch the answer accordingly
- Verify edge cases are handled
### Application Guidelines:
**Scale your response to the task:**
- **Trivial changes:** Steps 5-6 only (deliver quality + self-review)
- **Standard features:** Steps 1, 5-6 (restate + deliver + review)
- **Complex/risky changes:** All steps 1-6
- **Ambiguous requests:** Steps 1, 4 mandatory
**Be proportionally thorough—brief for simple tasks, comprehensive for risky ones. Avoid analysis paralysis.**
---
## Backend Guidelines
### Code Style

View File

@@ -25,10 +25,10 @@ import (
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/restores"
"databasus-backend/internal/features/restores/restoring"
"databasus-backend/internal/features/storages"
system_healthcheck "databasus-backend/internal/features/system/healthcheck"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
task_registry "databasus-backend/internal/features/tasks/registry"
users_controllers "databasus-backend/internal/features/users/controllers"
users_middleware "databasus-backend/internal/features/users/middleware"
users_services "databasus-backend/internal/features/users/services"
@@ -273,7 +273,7 @@ func runBackgroundTasks(log *slog.Logger) {
})
go runWithPanicLogging(log, "restore background service", func() {
restores.GetRestoreBackgroundService().Run(ctx)
restoring.GetRestoresScheduler().Run(ctx)
})
go runWithPanicLogging(log, "healthcheck attempt background service", func() {
@@ -288,21 +288,29 @@ func runBackgroundTasks(log *slog.Logger) {
backups_download.GetDownloadTokenBackgroundService().Run(ctx)
})
go runWithPanicLogging(log, "task nodes registry background service", func() {
task_registry.GetTaskNodesRegistry().Run(ctx)
go runWithPanicLogging(log, "backup nodes registry background service", func() {
backuping.GetBackupNodesRegistry().Run(ctx)
})
go runWithPanicLogging(log, "restore nodes registry background service", func() {
restoring.GetRestoreNodesRegistry().Run(ctx)
})
} else {
log.Info("Skipping primary node tasks as not primary node")
}
if config.GetEnv().IsBackupNode {
if config.GetEnv().IsProcessingNode {
log.Info("Starting backup node background tasks...")
go runWithPanicLogging(log, "backup node", func() {
backuping.GetBackuperNode().Run(ctx)
})
go runWithPanicLogging(log, "restore node", func() {
restoring.GetRestorerNode().Run(ctx)
})
} else {
log.Info("Skipping backup node tasks as not backup node")
log.Info("Skipping backup/restore node tasks as not backup node")
}
}

View File

@@ -9,7 +9,6 @@ import (
"strings"
"sync"
"github.com/google/uuid"
"github.com/ilyakaznacheev/cleanenv"
"github.com/joho/godotenv"
)
@@ -32,10 +31,9 @@ type EnvVariables struct {
ShowDbInstallationVerificationLogs bool `env:"SHOW_DB_INSTALLATION_VERIFICATION_LOGS"`
NodeID string
IsManyNodesMode bool `env:"IS_MANY_NODES_MODE"`
IsPrimaryNode bool `env:"IS_PRIMARY_NODE"`
IsBackupNode bool `env:"IS_BACKUP_NODE"`
IsProcessingNode bool `env:"IS_PROCESSING_NODE"`
NodeNetworkThroughputMBs int `env:"NODE_NETWORK_THROUGHPUT_MBPS"`
DataFolder string
@@ -230,14 +228,13 @@ func loadEnvVariables() {
env.ShowDbInstallationVerificationLogs,
)
env.NodeID = uuid.New().String()
if env.NodeNetworkThroughputMBs == 0 {
env.NodeNetworkThroughputMBs = 125 // 1 Gbit/s
}
if !env.IsManyNodesMode {
env.IsPrimaryNode = true
env.IsBackupNode = true
env.IsProcessingNode = true
}
// Valkey

View File

@@ -8,7 +8,6 @@ import (
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/storages"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
task_registry "databasus-backend/internal/features/tasks/registry"
workspaces_services "databasus-backend/internal/features/workspaces/services"
util_encryption "databasus-backend/internal/util/encryption"
"errors"
@@ -35,7 +34,7 @@ type BackuperNode struct {
storageService *storages.StorageService
notificationSender backups_core.NotificationSender
backupCancelManager *tasks_cancellation.TaskCancelManager
tasksRegistry *task_registry.TaskNodesRegistry
backupNodesRegistry *BackupNodesRegistry
logger *slog.Logger
createBackupUseCase backups_core.CreateBackupUsecase
nodeID uuid.UUID
@@ -48,19 +47,20 @@ func (n *BackuperNode) Run(ctx context.Context) {
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
backupNode := task_registry.TaskNode{
backupNode := BackupNode{
ID: n.nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: time.Now().UTC(),
}
if err := n.tasksRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
n.logger.Error("Failed to register node in registry", "error", err)
panic(err)
}
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
n.MakeBackup(backupID, isCallNotifier)
if err := n.tasksRegistry.PublishTaskCompletion(n.nodeID.String(), backupID); err != nil {
if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil {
n.logger.Error(
"Failed to publish backup completion",
"error",
@@ -71,12 +71,13 @@ func (n *BackuperNode) Run(ctx context.Context) {
}
}
if err := n.tasksRegistry.SubscribeNodeForTasksAssignment(n.nodeID.String(), backupHandler); err != nil {
err := n.backupNodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID, backupHandler)
if err != nil {
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
panic(err)
}
defer func() {
if err := n.tasksRegistry.UnsubscribeNodeForTasksAssignments(); err != nil {
if err := n.backupNodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
}
}()
@@ -91,7 +92,7 @@ func (n *BackuperNode) Run(ctx context.Context) {
case <-ctx.Done():
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
if err := n.tasksRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
if err := n.backupNodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
n.logger.Error("Failed to unregister node from registry", "error", err)
}
@@ -357,9 +358,9 @@ func (n *BackuperNode) SendBackupNotification(
}
}
func (n *BackuperNode) sendHeartbeat(backupNode *task_registry.TaskNode) {
func (n *BackuperNode) sendHeartbeat(backupNode *BackupNode) {
n.lastHeartbeat = time.Now().UTC()
if err := n.tasksRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *backupNode); err != nil {
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *backupNode); err != nil {
n.logger.Error("Failed to send heartbeat", "error", err)
}
}

View File

@@ -1,7 +1,6 @@
package backuping
import (
"databasus-backend/internal/config"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/usecases"
backups_config "databasus-backend/internal/features/backups/config"
@@ -9,8 +8,8 @@ import (
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
task_registry "databasus-backend/internal/features/tasks/registry"
workspaces_services "databasus-backend/internal/features/workspaces/services"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
"time"
@@ -22,16 +21,16 @@ var backupRepository = &backups_core.BackupRepository{}
var taskCancelManager = tasks_cancellation.GetTaskCancelManager()
var nodesRegistry = task_registry.GetTaskNodesRegistry()
var backupNodesRegistry = &BackupNodesRegistry{
cache_utils.GetValkeyClient(),
logger.GetLogger(),
cache_utils.DefaultCacheTimeout,
cache_utils.NewPubSubManager(),
cache_utils.NewPubSubManager(),
}
func getNodeID() uuid.UUID {
nodeIDStr := config.GetEnv().NodeID
nodeID, err := uuid.Parse(nodeIDStr)
if err != nil {
logger.GetLogger().Error("Failed to parse node ID from config", "error", err)
panic(err)
}
return nodeID
return uuid.New()
}
var backuperNode = &BackuperNode{
@@ -43,7 +42,7 @@ var backuperNode = &BackuperNode{
storages.GetStorageService(),
notifiers.GetNotifierService(),
taskCancelManager,
nodesRegistry,
backupNodesRegistry,
logger.GetLogger(),
usecases.GetCreateBackupUsecase(),
getNodeID(),
@@ -51,15 +50,15 @@ var backuperNode = &BackuperNode{
}
var backupsScheduler = &BackupsScheduler{
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
taskCancelManager,
nodesRegistry,
time.Now().UTC(),
logger.GetLogger(),
make(map[uuid.UUID]BackupToNodeRelation),
backuperNode,
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
taskCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
lastBackupTime: time.Now().UTC(),
logger: logger.GetLogger(),
backupToNodeRelations: make(map[uuid.UUID]BackupToNodeRelation),
backuperNode: backuperNode,
}
func GetBackupsScheduler() *BackupsScheduler {
@@ -69,3 +68,7 @@ func GetBackupsScheduler() *BackupsScheduler {
func GetBackuperNode() *BackuperNode {
return backuperNode
}
func GetBackupNodesRegistry() *BackupNodesRegistry {
return backupNodesRegistry
}

View File

@@ -1,8 +1,34 @@
package backuping
import "github.com/google/uuid"
import (
"time"
"github.com/google/uuid"
)
type BackupToNodeRelation struct {
NodeID uuid.UUID `json:"nodeId"`
BackupsIDs []uuid.UUID `json:"backupsIds"`
}
type BackupNode struct {
ID uuid.UUID `json:"id"`
ThroughputMBs int `json:"throughputMBs"`
LastHeartbeat time.Time `json:"lastHeartbeat"`
}
type BackupNodeStats struct {
ID uuid.UUID `json:"id"`
ActiveBackups int `json:"activeBackups"`
}
type BackupSubmitMessage struct {
NodeID uuid.UUID `json:"nodeId"`
BackupID uuid.UUID `json:"backupId"`
IsCallNotifier bool `json:"isCallNotifier"`
}
type BackupCompletionMessage struct {
NodeID uuid.UUID `json:"nodeId"`
BackupID uuid.UUID `json:"backupId"`
}

View File

@@ -1,4 +1,4 @@
package task_registry
package backuping
import (
"context"
@@ -15,45 +15,41 @@ import (
)
const (
nodeInfoKeyPrefix = "node:"
nodeInfoKeySuffix = ":info"
nodeActiveTasksPrefix = "node:"
nodeActiveTasksSuffix = ":active_tasks"
taskSubmitChannel = "task:submit"
taskCompletionChannel = "task:completion"
nodeInfoKeyPrefix = "backup:node:"
nodeInfoKeySuffix = ":info"
nodeActiveBackupsPrefix = "backup:node:"
nodeActiveBackupsSuffix = ":active_backups"
backupSubmitChannel = "backup:submit"
backupCompletionChannel = "backup:completion"
deadNodeThreshold = 2 * time.Minute
cleanupTickerInterval = 1 * time.Second
)
// TaskNodesRegistry helps to sync tasks scheduler (backuping or restoring)
// and task nodes which are used for network-intensive tasks processing
// BackupNodesRegistry helps to sync backups scheduler and backup nodes.
//
// Features:
// - Track node availability and load level
// - Assign from scheduler to node tasks needed to be processed
// - Notify scheduler from node about task completion
// - Assign from scheduler to node backups needed to be processed
// - Notify scheduler from node about backup completion
//
// Important things to remember:
// - Node can contain different tasks types so when task is assigned
// or node's tasks cleaned - should be performed DB check in DB
// that task with this ID exists for this task type at all
// - Nodes without heathbeat for more than 2 minutes are not included
// - Nodes without heartbeat for more than 2 minutes are not included
// in available nodes list and stats
//
// Cleanup dead nodes performed on 2 levels:
// - List and stats functions do not return dead nodes
// - Periodically dead nodes are cleaned up in cache (to not
// accumulate too many dead nodes in cache)
type TaskNodesRegistry struct {
type BackupNodesRegistry struct {
client valkey.Client
logger *slog.Logger
timeout time.Duration
pubsubTasks *cache_utils.PubSubManager
pubsubBackups *cache_utils.PubSubManager
pubsubCompletions *cache_utils.PubSubManager
}
func (r *TaskNodesRegistry) Run(ctx context.Context) {
func (r *BackupNodesRegistry) Run(ctx context.Context) {
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
}
@@ -72,7 +68,7 @@ func (r *TaskNodesRegistry) Run(ctx context.Context) {
}
}
func (r *TaskNodesRegistry) GetAvailableNodes() ([]TaskNode, error) {
func (r *BackupNodesRegistry) GetAvailableNodes() ([]BackupNode, error) {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
@@ -104,7 +100,7 @@ func (r *TaskNodesRegistry) GetAvailableNodes() ([]TaskNode, error) {
}
if len(allKeys) == 0 {
return []TaskNode{}, nil
return []BackupNode{}, nil
}
keyDataMap, err := r.pipelineGetKeys(allKeys)
@@ -113,14 +109,15 @@ func (r *TaskNodesRegistry) GetAvailableNodes() ([]TaskNode, error) {
}
threshold := time.Now().UTC().Add(-deadNodeThreshold)
var nodes []TaskNode
var nodes []BackupNode
for key, data := range keyDataMap {
// Skip if the key doesn't exist (data is empty)
if len(data) == 0 {
continue
}
var node TaskNode
var node BackupNode
if err := json.Unmarshal(data, &node); err != nil {
r.logger.Warn("Failed to unmarshal node data", "key", key, "error", err)
continue
@@ -141,13 +138,13 @@ func (r *TaskNodesRegistry) GetAvailableNodes() ([]TaskNode, error) {
return nodes, nil
}
func (r *TaskNodesRegistry) GetNodesStats() ([]TaskNodeStats, error) {
func (r *BackupNodesRegistry) GetBackupNodesStats() ([]BackupNodeStats, error) {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
var allKeys []string
cursor := uint64(0)
pattern := nodeActiveTasksPrefix + "*" + nodeActiveTasksSuffix
pattern := nodeActiveBackupsPrefix + "*" + nodeActiveBackupsSuffix
for {
result := r.client.Do(
@@ -156,7 +153,7 @@ func (r *TaskNodesRegistry) GetNodesStats() ([]TaskNodeStats, error) {
)
if result.Error() != nil {
return nil, fmt.Errorf("failed to scan active tasks keys: %w", result.Error())
return nil, fmt.Errorf("failed to scan active backups keys: %w", result.Error())
}
scanResult, err := result.AsScanEntry()
@@ -173,18 +170,18 @@ func (r *TaskNodesRegistry) GetNodesStats() ([]TaskNodeStats, error) {
}
if len(allKeys) == 0 {
return []TaskNodeStats{}, nil
return []BackupNodeStats{}, nil
}
keyDataMap, err := r.pipelineGetKeys(allKeys)
if err != nil {
return nil, fmt.Errorf("failed to pipeline get active tasks keys: %w", err)
return nil, fmt.Errorf("failed to pipeline get active backups keys: %w", err)
}
var nodeInfoKeys []string
nodeIDToStatsKey := make(map[string]string)
for key := range keyDataMap {
nodeID := r.extractNodeIDFromKey(key, nodeActiveTasksPrefix, nodeActiveTasksSuffix)
nodeID := r.extractNodeIDFromKey(key, nodeActiveBackupsPrefix, nodeActiveBackupsSuffix)
nodeIDStr := nodeID.String()
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeIDStr, nodeInfoKeySuffix)
nodeInfoKeys = append(nodeInfoKeys, infoKey)
@@ -197,14 +194,14 @@ func (r *TaskNodesRegistry) GetNodesStats() ([]TaskNodeStats, error) {
}
threshold := time.Now().UTC().Add(-deadNodeThreshold)
var stats []TaskNodeStats
var stats []BackupNodeStats
for infoKey, nodeData := range nodeInfoMap {
// Skip if the info key doesn't exist (nodeData is empty)
if len(nodeData) == 0 {
continue
}
var node TaskNode
var node BackupNode
if err := json.Unmarshal(nodeData, &node); err != nil {
r.logger.Warn("Failed to unmarshal node data", "key", infoKey, "error", err)
continue
@@ -223,13 +220,13 @@ func (r *TaskNodesRegistry) GetNodesStats() ([]TaskNodeStats, error) {
tasksData := keyDataMap[statsKey]
count, err := r.parseIntFromBytes(tasksData)
if err != nil {
r.logger.Warn("Failed to parse active tasks count", "key", statsKey, "error", err)
r.logger.Warn("Failed to parse active backups count", "key", statsKey, "error", err)
continue
}
stat := TaskNodeStats{
ID: node.ID,
ActiveTasks: int(count),
stat := BackupNodeStats{
ID: node.ID,
ActiveBackups: int(count),
}
stats = append(stats, stat)
}
@@ -237,16 +234,16 @@ func (r *TaskNodesRegistry) GetNodesStats() ([]TaskNodeStats, error) {
return stats, nil
}
func (r *TaskNodesRegistry) IncrementTasksInProgress(nodeID string) error {
func (r *BackupNodesRegistry) IncrementBackupsInProgress(nodeID uuid.UUID) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
key := fmt.Sprintf("%s%s%s", nodeActiveTasksPrefix, nodeID, nodeActiveTasksSuffix)
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID.String(), nodeActiveBackupsSuffix)
result := r.client.Do(ctx, r.client.B().Incr().Key(key).Build())
if result.Error() != nil {
return fmt.Errorf(
"failed to increment tasks in progress for node %s: %w",
"failed to increment backups in progress for node %s: %w",
nodeID,
result.Error(),
)
@@ -255,16 +252,16 @@ func (r *TaskNodesRegistry) IncrementTasksInProgress(nodeID string) error {
return nil
}
func (r *TaskNodesRegistry) DecrementTasksInProgress(nodeID string) error {
func (r *BackupNodesRegistry) DecrementBackupsInProgress(nodeID uuid.UUID) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
key := fmt.Sprintf("%s%s%s", nodeActiveTasksPrefix, nodeID, nodeActiveTasksSuffix)
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID.String(), nodeActiveBackupsSuffix)
result := r.client.Do(ctx, r.client.B().Decr().Key(key).Build())
if result.Error() != nil {
return fmt.Errorf(
"failed to decrement tasks in progress for node %s: %w",
"failed to decrement backups in progress for node %s: %w",
nodeID,
result.Error(),
)
@@ -279,13 +276,13 @@ func (r *TaskNodesRegistry) DecrementTasksInProgress(nodeID string) error {
setCtx, setCancel := context.WithTimeout(context.Background(), r.timeout)
r.client.Do(setCtx, r.client.B().Set().Key(key).Value("0").Build())
setCancel()
r.logger.Warn("Active tasks counter went below 0, reset to 0", "nodeID", nodeID)
r.logger.Warn("Active backups counter went below 0, reset to 0", "nodeID", nodeID)
}
return nil
}
func (r *TaskNodesRegistry) HearthbeatNodeInRegistry(now time.Time, node TaskNode) error {
func (r *BackupNodesRegistry) HearthbeatNodeInRegistry(now time.Time, backupNode BackupNode) error {
if now.IsZero() {
return fmt.Errorf("cannot register node with zero heartbeat timestamp")
}
@@ -293,36 +290,36 @@ func (r *TaskNodesRegistry) HearthbeatNodeInRegistry(now time.Time, node TaskNod
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
node.LastHeartbeat = now
backupNode.LastHeartbeat = now
data, err := json.Marshal(node)
data, err := json.Marshal(backupNode)
if err != nil {
return fmt.Errorf("failed to marshal node: %w", err)
return fmt.Errorf("failed to marshal backup node: %w", err)
}
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node.ID.String(), nodeInfoKeySuffix)
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, backupNode.ID.String(), nodeInfoKeySuffix)
result := r.client.Do(
ctx,
r.client.B().Set().Key(key).Value(string(data)).Build(),
)
if result.Error() != nil {
return fmt.Errorf("failed to register node %s: %w", node.ID, result.Error())
return fmt.Errorf("failed to register node %s: %w", backupNode.ID, result.Error())
}
return nil
}
func (r *TaskNodesRegistry) UnregisterNodeFromRegistry(node TaskNode) error {
func (r *BackupNodesRegistry) UnregisterNodeFromRegistry(backupNode BackupNode) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node.ID.String(), nodeInfoKeySuffix)
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, backupNode.ID.String(), nodeInfoKeySuffix)
counterKey := fmt.Sprintf(
"%s%s%s",
nodeActiveTasksPrefix,
node.ID.String(),
nodeActiveTasksSuffix,
nodeActiveBackupsPrefix,
backupNode.ID.String(),
nodeActiveBackupsSuffix,
)
result := r.client.Do(
@@ -331,49 +328,49 @@ func (r *TaskNodesRegistry) UnregisterNodeFromRegistry(node TaskNode) error {
)
if result.Error() != nil {
return fmt.Errorf("failed to unregister node %s: %w", node.ID, result.Error())
return fmt.Errorf("failed to unregister node %s: %w", backupNode.ID, result.Error())
}
r.logger.Info("Unregistered node from registry", "nodeID", node.ID)
r.logger.Info("Unregistered node from registry", "nodeID", backupNode.ID)
return nil
}
func (r *TaskNodesRegistry) AssignTaskToNode(
targetNodeID string,
taskID uuid.UUID,
func (r *BackupNodesRegistry) AssignBackupToNode(
targetNodeID uuid.UUID,
backupID uuid.UUID,
isCallNotifier bool,
) error {
ctx := context.Background()
message := TaskSubmitMessage{
message := BackupSubmitMessage{
NodeID: targetNodeID,
TaskID: taskID.String(),
BackupID: backupID,
IsCallNotifier: isCallNotifier,
}
messageJSON, err := json.Marshal(message)
if err != nil {
return fmt.Errorf("failed to marshal task submit message: %w", err)
return fmt.Errorf("failed to marshal backup submit message: %w", err)
}
err = r.pubsubTasks.Publish(ctx, taskSubmitChannel, string(messageJSON))
err = r.pubsubBackups.Publish(ctx, backupSubmitChannel, string(messageJSON))
if err != nil {
return fmt.Errorf("failed to publish task submit message: %w", err)
return fmt.Errorf("failed to publish backup submit message: %w", err)
}
return nil
}
func (r *TaskNodesRegistry) SubscribeNodeForTasksAssignment(
nodeID string,
handler func(taskID uuid.UUID, isCallNotifier bool),
func (r *BackupNodesRegistry) SubscribeNodeForBackupsAssignment(
nodeID uuid.UUID,
handler func(backupID uuid.UUID, isCallNotifier bool),
) error {
ctx := context.Background()
wrappedHandler := func(message string) {
var msg TaskSubmitMessage
var msg BackupSubmitMessage
if err := json.Unmarshal([]byte(message), &msg); err != nil {
r.logger.Warn("Failed to unmarshal task submit message", "error", err)
r.logger.Warn("Failed to unmarshal backup submit message", "error", err)
return
}
@@ -381,108 +378,84 @@ func (r *TaskNodesRegistry) SubscribeNodeForTasksAssignment(
return
}
taskID, err := uuid.Parse(msg.TaskID)
if err != nil {
r.logger.Warn(
"Failed to parse task ID from message",
"taskId",
msg.TaskID,
"error",
err,
)
return
}
handler(taskID, msg.IsCallNotifier)
handler(msg.BackupID, msg.IsCallNotifier)
}
err := r.pubsubTasks.Subscribe(ctx, taskSubmitChannel, wrappedHandler)
err := r.pubsubBackups.Subscribe(ctx, backupSubmitChannel, wrappedHandler)
if err != nil {
return fmt.Errorf("failed to subscribe to task submit channel: %w", err)
return fmt.Errorf("failed to subscribe to backup submit channel: %w", err)
}
r.logger.Info("Subscribed to task submit channel", "nodeID", nodeID)
r.logger.Info("Subscribed to backup submit channel", "nodeID", nodeID)
return nil
}
func (r *TaskNodesRegistry) UnsubscribeNodeForTasksAssignments() error {
err := r.pubsubTasks.Close()
func (r *BackupNodesRegistry) UnsubscribeNodeForBackupsAssignments() error {
err := r.pubsubBackups.Close()
if err != nil {
return fmt.Errorf("failed to unsubscribe from task submit channel: %w", err)
return fmt.Errorf("failed to unsubscribe from backup submit channel: %w", err)
}
r.logger.Info("Unsubscribed from task submit channel")
r.logger.Info("Unsubscribed from backup submit channel")
return nil
}
func (r *TaskNodesRegistry) PublishTaskCompletion(nodeID string, taskID uuid.UUID) error {
func (r *BackupNodesRegistry) PublishBackupCompletion(nodeID uuid.UUID, backupID uuid.UUID) error {
ctx := context.Background()
message := TaskCompletionMessage{
NodeID: nodeID,
TaskID: taskID.String(),
message := BackupCompletionMessage{
NodeID: nodeID,
BackupID: backupID,
}
messageJSON, err := json.Marshal(message)
if err != nil {
return fmt.Errorf("failed to marshal task completion message: %w", err)
return fmt.Errorf("failed to marshal backup completion message: %w", err)
}
err = r.pubsubCompletions.Publish(ctx, taskCompletionChannel, string(messageJSON))
err = r.pubsubCompletions.Publish(ctx, backupCompletionChannel, string(messageJSON))
if err != nil {
return fmt.Errorf("failed to publish task completion message: %w", err)
return fmt.Errorf("failed to publish backup completion message: %w", err)
}
return nil
}
func (r *TaskNodesRegistry) SubscribeForTasksCompletions(
handler func(nodeID string, taskID uuid.UUID),
func (r *BackupNodesRegistry) SubscribeForBackupsCompletions(
handler func(nodeID uuid.UUID, backupID uuid.UUID),
) error {
ctx := context.Background()
wrappedHandler := func(message string) {
var msg TaskCompletionMessage
var msg BackupCompletionMessage
if err := json.Unmarshal([]byte(message), &msg); err != nil {
r.logger.Warn("Failed to unmarshal task completion message", "error", err)
r.logger.Warn("Failed to unmarshal backup completion message", "error", err)
return
}
taskID, err := uuid.Parse(msg.TaskID)
if err != nil {
r.logger.Warn(
"Failed to parse task ID from completion message",
"taskId",
msg.TaskID,
"error",
err,
)
return
}
handler(msg.NodeID, taskID)
handler(msg.NodeID, msg.BackupID)
}
err := r.pubsubCompletions.Subscribe(ctx, taskCompletionChannel, wrappedHandler)
err := r.pubsubCompletions.Subscribe(ctx, backupCompletionChannel, wrappedHandler)
if err != nil {
return fmt.Errorf("failed to subscribe to task completion channel: %w", err)
return fmt.Errorf("failed to subscribe to backup completion channel: %w", err)
}
r.logger.Info("Subscribed to task completion channel")
r.logger.Info("Subscribed to backup completion channel")
return nil
}
func (r *TaskNodesRegistry) UnsubscribeForTasksCompletions() error {
func (r *BackupNodesRegistry) UnsubscribeForBackupsCompletions() error {
err := r.pubsubCompletions.Close()
if err != nil {
return fmt.Errorf("failed to unsubscribe from task completion channel: %w", err)
return fmt.Errorf("failed to unsubscribe from backup completion channel: %w", err)
}
r.logger.Info("Unsubscribed from task completion channel")
r.logger.Info("Unsubscribed from backup completion channel")
return nil
}
func (r *TaskNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uuid.UUID {
func (r *BackupNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uuid.UUID {
nodeIDStr := strings.TrimPrefix(key, prefix)
nodeIDStr = strings.TrimSuffix(nodeIDStr, suffix)
@@ -495,7 +468,7 @@ func (r *TaskNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uui
return nodeID
}
func (r *TaskNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, error) {
func (r *BackupNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, error) {
if len(keys) == 0 {
return make(map[string][]byte), nil
}
@@ -529,7 +502,7 @@ func (r *TaskNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, e
return keyDataMap, nil
}
func (r *TaskNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
func (r *BackupNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
str := string(data)
var count int64
_, err := fmt.Sscanf(str, "%d", &count)
@@ -539,7 +512,7 @@ func (r *TaskNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
return count, nil
}
func (r *TaskNodesRegistry) cleanupDeadNodes() error {
func (r *BackupNodesRegistry) cleanupDeadNodes() error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
@@ -583,13 +556,12 @@ func (r *TaskNodesRegistry) cleanupDeadNodes() error {
var deadNodeKeys []string
for key, data := range keyDataMap {
// Skip if the key doesn't exist (data is empty)
if len(data) == 0 {
continue
}
var node TaskNode
var node BackupNode
if err := json.Unmarshal(data, &node); err != nil {
r.logger.Warn("Failed to unmarshal node data during cleanup", "key", key, "error", err)
continue
@@ -603,7 +575,12 @@ func (r *TaskNodesRegistry) cleanupDeadNodes() error {
if node.LastHeartbeat.Before(threshold) {
nodeID := node.ID.String()
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeID, nodeInfoKeySuffix)
statsKey := fmt.Sprintf("%s%s%s", nodeActiveTasksPrefix, nodeID, nodeActiveTasksSuffix)
statsKey := fmt.Sprintf(
"%s%s%s",
nodeActiveBackupsPrefix,
nodeID,
nodeActiveBackupsSuffix,
)
deadNodeKeys = append(deadNodeKeys, infoKey, statsKey)
r.logger.Info(

File diff suppressed because it is too large Load Diff

View File

@@ -7,7 +7,6 @@ import (
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/storages"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
task_registry "databasus-backend/internal/features/tasks/registry"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/period"
"fmt"
@@ -28,7 +27,7 @@ type BackupsScheduler struct {
backupConfigService *backups_config.BackupConfigService
storageService *storages.StorageService
taskCancelManager *task_cancellation.TaskCancelManager
tasksRegistry *task_registry.TaskNodesRegistry
backupNodesRegistry *BackupNodesRegistry
lastBackupTime time.Time
logger *slog.Logger
@@ -50,12 +49,14 @@ func (s *BackupsScheduler) Run(ctx context.Context) {
panic(err)
}
if err := s.tasksRegistry.SubscribeForTasksCompletions(s.onBackupCompleted); err != nil {
err := s.backupNodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted)
if err != nil {
s.logger.Error("Failed to subscribe to backup completions", "error", err)
panic(err)
}
defer func() {
if err := s.tasksRegistry.UnsubscribeForTasksCompletions(); err != nil {
if err := s.backupNodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
}
}()
@@ -180,7 +181,7 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
return
}
if err := s.tasksRegistry.IncrementTasksInProgress(leastBusyNodeID.String()); err != nil {
if err := s.backupNodesRegistry.IncrementBackupsInProgress(*leastBusyNodeID); err != nil {
s.logger.Error(
"Failed to increment backups in progress",
"nodeId",
@@ -193,7 +194,7 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
return
}
if err := s.tasksRegistry.AssignTaskToNode(leastBusyNodeID.String(), backup.ID, isCallNotifier); err != nil {
if err := s.backupNodesRegistry.AssignBackupToNode(*leastBusyNodeID, backup.ID, isCallNotifier); err != nil {
s.logger.Error(
"Failed to submit backup",
"nodeId",
@@ -203,7 +204,7 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
"error",
err,
)
if decrementErr := s.tasksRegistry.DecrementTasksInProgress(leastBusyNodeID.String()); decrementErr != nil {
if decrementErr := s.backupNodesRegistry.DecrementBackupsInProgress(*leastBusyNodeID); decrementErr != nil {
s.logger.Error(
"Failed to decrement backups in progress after submit failure",
"nodeId",
@@ -398,7 +399,7 @@ func (s *BackupsScheduler) runPendingBackups() error {
}
func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
nodes, err := s.tasksRegistry.GetAvailableNodes()
nodes, err := s.backupNodesRegistry.GetAvailableNodes()
if err != nil {
return nil, fmt.Errorf("failed to get available nodes: %w", err)
}
@@ -407,17 +408,17 @@ func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
return nil, fmt.Errorf("no nodes available")
}
stats, err := s.tasksRegistry.GetNodesStats()
stats, err := s.backupNodesRegistry.GetBackupNodesStats()
if err != nil {
return nil, fmt.Errorf("failed to get backup nodes stats: %w", err)
}
statsMap := make(map[uuid.UUID]int)
for _, stat := range stats {
statsMap[stat.ID] = stat.ActiveTasks
statsMap[stat.ID] = stat.ActiveBackups
}
var bestNode *task_registry.TaskNode
var bestNode *BackupNode
var bestScore float64 = -1
for i := range nodes {
@@ -445,21 +446,9 @@ func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
return &bestNode.ID, nil
}
func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUID) {
nodeID, err := uuid.Parse(nodeIDStr)
if err != nil {
s.logger.Error(
"Failed to parse node ID from completion message",
"nodeId",
nodeIDStr,
"error",
err,
)
return
}
func (s *BackupsScheduler) onBackupCompleted(nodeID uuid.UUID, backupID uuid.UUID) {
// Verify this task is actually a backup (registry contains multiple task types)
_, err = s.backupRepository.FindByID(backupID)
_, err := s.backupRepository.FindByID(backupID)
if err != nil {
// Not a backup task, ignore it
return
@@ -505,7 +494,7 @@ func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUI
s.backupToNodeRelations[nodeID] = relation
}
if err := s.tasksRegistry.DecrementTasksInProgress(nodeIDStr); err != nil {
if err := s.backupNodesRegistry.DecrementBackupsInProgress(nodeID); err != nil {
s.logger.Error(
"Failed to decrement backups in progress",
"nodeId",
@@ -519,7 +508,7 @@ func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUI
}
func (s *BackupsScheduler) checkDeadNodesAndFailBackups() error {
nodes, err := s.tasksRegistry.GetAvailableNodes()
nodes, err := s.backupNodesRegistry.GetAvailableNodes()
if err != nil {
return fmt.Errorf("failed to get available nodes: %w", err)
}
@@ -575,7 +564,7 @@ func (s *BackupsScheduler) checkDeadNodesAndFailBackups() error {
continue
}
if err := s.tasksRegistry.DecrementTasksInProgress(nodeID.String()); err != nil {
if err := s.backupNodesRegistry.DecrementBackupsInProgress(nodeID); err != nil {
s.logger.Error(
"Failed to decrement backups in progress for dead node",
"nodeId",

View File

@@ -7,7 +7,6 @@ import (
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
task_registry "databasus-backend/internal/features/tasks/registry"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
@@ -466,7 +465,7 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
// Clean up mock node
if mockNodeID != uuid.Nil {
nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: mockNodeID})
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: mockNodeID})
}
cache_utils.ClearAllCache()
}()
@@ -502,12 +501,12 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status)
// Verify Valkey counter was incremented when backup was assigned
stats, err := nodesRegistry.GetNodesStats()
stats, err := backupNodesRegistry.GetBackupNodesStats()
assert.NoError(t, err)
foundStat := false
for _, stat := range stats {
if stat.ID == mockNodeID {
assert.Equal(t, 1, stat.ActiveTasks)
assert.Equal(t, 1, stat.ActiveBackups)
foundStat = true
break
}
@@ -532,11 +531,11 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
assert.Contains(t, *backups[0].FailMessage, "node unavailability")
// Verify Valkey counter was decremented after backup failed
stats, err = nodesRegistry.GetNodesStats()
stats, err = backupNodesRegistry.GetBackupNodesStats()
assert.NoError(t, err)
for _, stat := range stats {
if stat.ID == mockNodeID {
assert.Equal(t, 0, stat.ActiveTasks)
assert.Equal(t, 0, stat.ActiveBackups)
}
}
@@ -569,7 +568,7 @@ func Test_OnBackupCompleted_WhenTaskIsNotBackup_SkipsProcessing(t *testing.T) {
// Clean up mock node
if mockNodeID != uuid.Nil {
nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: mockNodeID})
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: mockNodeID})
}
cache_utils.ClearAllCache()
}()
@@ -605,12 +604,12 @@ func Test_OnBackupCompleted_WhenTaskIsNotBackup_SkipsProcessing(t *testing.T) {
assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status)
// Get initial state of the registry
initialStats, err := nodesRegistry.GetNodesStats()
initialStats, err := backupNodesRegistry.GetBackupNodesStats()
assert.NoError(t, err)
var initialActiveTasks int
for _, stat := range initialStats {
if stat.ID == mockNodeID {
initialActiveTasks = stat.ActiveTasks
initialActiveTasks = stat.ActiveBackups
break
}
}
@@ -618,16 +617,16 @@ func Test_OnBackupCompleted_WhenTaskIsNotBackup_SkipsProcessing(t *testing.T) {
// Call onBackupCompleted with a random UUID (not a backup ID)
nonBackupTaskID := uuid.New()
GetBackupsScheduler().onBackupCompleted(mockNodeID.String(), nonBackupTaskID)
GetBackupsScheduler().onBackupCompleted(mockNodeID, nonBackupTaskID)
time.Sleep(100 * time.Millisecond)
// Verify: Active tasks counter should remain the same (not decremented)
stats, err := nodesRegistry.GetNodesStats()
stats, err := backupNodesRegistry.GetBackupNodesStats()
assert.NoError(t, err)
for _, stat := range stats {
if stat.ID == mockNodeID {
assert.Equal(t, initialActiveTasks, stat.ActiveTasks,
assert.Equal(t, initialActiveTasks, stat.ActiveBackups,
"Active tasks should not change for non-backup task")
}
}
@@ -658,9 +657,9 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
defer func() {
// Clean up all mock nodes
nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node1ID})
nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node2ID})
nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node3ID})
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node1ID})
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node2ID})
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node3ID})
cache_utils.ClearAllCache()
}()
@@ -672,17 +671,17 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
assert.NoError(t, err)
for range 5 {
err = nodesRegistry.IncrementTasksInProgress(node1ID.String())
err = backupNodesRegistry.IncrementBackupsInProgress(node1ID)
assert.NoError(t, err)
}
for range 2 {
err = nodesRegistry.IncrementTasksInProgress(node2ID.String())
err = backupNodesRegistry.IncrementBackupsInProgress(node2ID)
assert.NoError(t, err)
}
for range 8 {
err = nodesRegistry.IncrementTasksInProgress(node3ID.String())
err = backupNodesRegistry.IncrementBackupsInProgress(node3ID)
assert.NoError(t, err)
}
@@ -701,8 +700,8 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
defer func() {
// Clean up all mock nodes
nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node100MBsID})
nodesRegistry.UnregisterNodeFromRegistry(task_registry.TaskNode{ID: node50MBsID})
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node100MBsID})
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node50MBsID})
cache_utils.ClearAllCache()
}()
@@ -712,11 +711,11 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
assert.NoError(t, err)
for range 10 {
err = nodesRegistry.IncrementTasksInProgress(node100MBsID.String())
err = backupNodesRegistry.IncrementBackupsInProgress(node100MBsID)
assert.NoError(t, err)
}
err = nodesRegistry.IncrementTasksInProgress(node50MBsID.String())
err = backupNodesRegistry.IncrementBackupsInProgress(node50MBsID)
assert.NoError(t, err)
leastBusyNodeID, err := GetBackupsScheduler().calculateLeastBusyNode()
@@ -880,12 +879,12 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T
assert.NoError(t, err)
// Get initial active task count
stats, err := nodesRegistry.GetNodesStats()
stats, err := backupNodesRegistry.GetBackupNodesStats()
assert.NoError(t, err)
var initialActiveTasks int
for _, stat := range stats {
if stat.ID == backuperNode.nodeID {
initialActiveTasks = stat.ActiveTasks
initialActiveTasks = stat.ActiveBackups
break
}
}
@@ -913,12 +912,12 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T
assert.True(t, decreased, "Active task count should have decreased after backup completion")
// Verify final active task count equals initial count
finalStats, err := nodesRegistry.GetNodesStats()
finalStats, err := backupNodesRegistry.GetBackupNodesStats()
assert.NoError(t, err)
for _, stat := range finalStats {
if stat.ID == backuperNode.nodeID {
t.Logf("Final active tasks: %d", stat.ActiveTasks)
assert.Equal(t, initialActiveTasks, stat.ActiveTasks,
t.Logf("Final active tasks: %d", stat.ActiveBackups)
assert.Equal(t, initialActiveTasks, stat.ActiveBackups,
"Active task count should return to initial value after backup completion")
break
}
@@ -982,12 +981,12 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
assert.NoError(t, err)
// Get initial active task count
stats, err := nodesRegistry.GetNodesStats()
stats, err := backupNodesRegistry.GetBackupNodesStats()
assert.NoError(t, err)
var initialActiveTasks int
for _, stat := range stats {
if stat.ID == backuperNode.nodeID {
initialActiveTasks = stat.ActiveTasks
initialActiveTasks = stat.ActiveBackups
break
}
}
@@ -1019,12 +1018,12 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
assert.True(t, decreased, "Active task count should have decreased after backup failure")
// Verify final active task count equals initial count
finalStats, err := nodesRegistry.GetNodesStats()
finalStats, err := backupNodesRegistry.GetBackupNodesStats()
assert.NoError(t, err)
for _, stat := range finalStats {
if stat.ID == backuperNode.nodeID {
t.Logf("Final active tasks: %d", stat.ActiveTasks)
assert.Equal(t, initialActiveTasks, stat.ActiveTasks,
t.Logf("Final active tasks: %d", stat.ActiveBackups)
assert.Equal(t, initialActiveTasks, stat.ActiveBackups,
"Active task count should return to initial value after backup failure")
break
}

View File

@@ -12,7 +12,6 @@ import (
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
task_registry "databasus-backend/internal/features/tasks/registry"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_services "databasus-backend/internal/features/workspaces/services"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
@@ -44,7 +43,7 @@ func CreateTestBackuperNode() *BackuperNode {
storages.GetStorageService(),
notifiers.GetNotifierService(),
taskCancelManager,
nodesRegistry,
backupNodesRegistry,
logger.GetLogger(),
usecases.GetCreateBackupUsecase(),
uuid.New(),
@@ -114,7 +113,7 @@ func StartBackuperNodeForTest(t *testing.T, backuperNode *BackuperNode) context.
// Poll registry for node presence instead of fixed sleep
deadline := time.Now().UTC().Add(5 * time.Second)
for time.Now().UTC().Before(deadline) {
nodes, err := nodesRegistry.GetAvailableNodes()
nodes, err := backupNodesRegistry.GetAvailableNodes()
if err == nil {
for _, node := range nodes {
if node.ID == backuperNode.nodeID {
@@ -175,7 +174,7 @@ func StopBackuperNodeForTest(t *testing.T, cancel context.CancelFunc, backuperNo
// Wait for node to unregister from registry
deadline := time.Now().UTC().Add(2 * time.Second)
for time.Now().UTC().Before(deadline) {
nodes, err := nodesRegistry.GetAvailableNodes()
nodes, err := backupNodesRegistry.GetAvailableNodes()
if err == nil {
found := false
for _, node := range nodes {
@@ -196,13 +195,13 @@ func StopBackuperNodeForTest(t *testing.T, cancel context.CancelFunc, backuperNo
}
func CreateMockNodeInRegistry(nodeID uuid.UUID, throughputMBs int, lastHeartbeat time.Time) error {
backupNode := task_registry.TaskNode{
backupNode := BackupNode{
ID: nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: lastHeartbeat,
}
return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
return backupNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
}
func UpdateNodeHeartbeatDirectly(
@@ -210,17 +209,17 @@ func UpdateNodeHeartbeatDirectly(
throughputMBs int,
lastHeartbeat time.Time,
) error {
backupNode := task_registry.TaskNode{
backupNode := BackupNode{
ID: nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: lastHeartbeat,
}
return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
return backupNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
}
func GetNodeFromRegistry(nodeID uuid.UUID) (*task_registry.TaskNode, error) {
nodes, err := nodesRegistry.GetAvailableNodes()
func GetNodeFromRegistry(nodeID uuid.UUID) (*BackupNode, error) {
nodes, err := backupNodesRegistry.GetAvailableNodes()
if err != nil {
return nil, err
}
@@ -246,7 +245,7 @@ func WaitForActiveTasksDecrease(
deadline := time.Now().UTC().Add(timeout)
for time.Now().UTC().Before(deadline) {
stats, err := nodesRegistry.GetNodesStats()
stats, err := backupNodesRegistry.GetBackupNodesStats()
if err != nil {
t.Logf("WaitForActiveTasksDecrease: error getting node stats: %v", err)
time.Sleep(500 * time.Millisecond)
@@ -257,14 +256,14 @@ func WaitForActiveTasksDecrease(
if stat.ID == nodeID {
t.Logf(
"WaitForActiveTasksDecrease: current active tasks = %d (initial = %d)",
stat.ActiveTasks,
stat.ActiveBackups,
initialCount,
)
if stat.ActiveTasks < initialCount {
if stat.ActiveBackups < initialCount {
t.Logf(
"WaitForActiveTasksDecrease: active tasks decreased from %d to %d",
initialCount,
stat.ActiveTasks,
stat.ActiveBackups,
)
return true
}

View File

@@ -75,3 +75,23 @@ func WaitForBackupCompletion(
t.Logf("WaitForBackupCompletion: timeout waiting for backup to complete")
}
// CreateTestBackup persists and returns a minimal completed backup record
// for use in tests. The record is linked to the given database and storage
// and carries fixed size/duration values; it panics if persistence fails,
// which is acceptable in test-only helpers.
func CreateTestBackup(databaseID, storageID uuid.UUID) *backups_core.Backup {
	record := &backups_core.Backup{
		ID:               uuid.New(),
		DatabaseID:       databaseID,
		StorageID:        storageID,
		Status:           backups_core.BackupStatusCompleted,
		BackupSizeMb:     10.5,
		BackupDurationMs: 1000,
		CreatedAt:        time.Now().UTC(),
	}
	repository := &backups_core.BackupRepository{}
	if err := repository.Save(record); err != nil {
		panic(err)
	}
	return record
}

View File

@@ -1,38 +0,0 @@
package restores
import (
"context"
"databasus-backend/internal/features/restores/enums"
"log/slog"
)
// RestoreBackgroundService performs one-time startup maintenance for
// restores: any restore left in the in-progress state by a previous
// process run is marked as failed (see failRestoresInProgress).
type RestoreBackgroundService struct {
	// restoreRepository is the persistence layer for restore records.
	restoreRepository *RestoreRepository
	// logger receives structured error output.
	logger *slog.Logger
}
// Run executes the service's startup pass: it fails any restores that were
// still in progress when the process last stopped. A failure here is treated
// as fatal (panic), since the service cannot guarantee a consistent restore
// state otherwise. The ctx parameter is currently unused.
func (s *RestoreBackgroundService) Run(ctx context.Context) {
	err := s.failRestoresInProgress()
	if err == nil {
		return
	}
	s.logger.Error("Failed to fail restores in progress", "error", err)
	panic(err)
}
// failRestoresInProgress marks every restore currently recorded as
// in-progress as failed, attaching an explanatory failure message.
// It is intended for startup recovery: after a restart no in-flight
// restore can complete. Returns the first repository error encountered.
func (s *RestoreBackgroundService) failRestoresInProgress() error {
	stuckRestores, err := s.restoreRepository.FindByStatus(enums.RestoreStatusInProgress)
	if err != nil {
		return err
	}
	for _, stuck := range stuckRestores {
		failMessage := "Restore failed due to application restart"
		stuck.Status = enums.RestoreStatusFailed
		stuck.FailMessage = &failMessage
		if saveErr := s.restoreRepository.Save(stuck); saveErr != nil {
			return saveErr
		}
	}
	return nil
}

View File

@@ -1,6 +1,7 @@
package restores
import (
restores_core "databasus-backend/internal/features/restores/core"
users_middleware "databasus-backend/internal/features/users/middleware"
"net/http"
@@ -23,7 +24,7 @@ func (c *RestoreController) RegisterRoutes(router *gin.RouterGroup) {
// @Tags restores
// @Produce json
// @Param backupId path string true "Backup ID"
// @Success 200 {array} models.Restore
// @Success 200 {array} restores_core.Restore
// @Failure 400
// @Failure 401
// @Router /restores/{backupId} [get]
@@ -71,7 +72,7 @@ func (c *RestoreController) RestoreBackup(ctx *gin.Context) {
return
}
var requestDTO RestoreBackupRequest
var requestDTO restores_core.RestoreBackupRequest
if err := ctx.ShouldBindJSON(&requestDTO); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return

View File

@@ -18,20 +18,18 @@ import (
"databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/mysql"
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/restores/models"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
local_storage "databasus-backend/internal/features/storages/models/local"
users_dto "databasus-backend/internal/features/users/dto"
users_enums "databasus-backend/internal/features/users/enums"
users_services "databasus-backend/internal/features/users/services"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_models "databasus-backend/internal/features/workspaces/models"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
util_encryption "databasus-backend/internal/util/encryption"
@@ -46,7 +44,7 @@ func Test_GetRestores_WhenUserIsWorkspaceMember_RestoresReturned(t *testing.T) {
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
var restores []*models.Restore
var restores []*restores_core.Restore
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
@@ -90,7 +88,7 @@ func Test_GetRestores_WhenUserIsGlobalAdmin_RestoresReturned(t *testing.T) {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
var restores []*models.Restore
var restores []*restores_core.Restore
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
@@ -105,12 +103,16 @@ func Test_GetRestores_WhenUserIsGlobalAdmin_RestoresReturned(t *testing.T) {
func Test_RestoreBackup_WhenUserIsWorkspaceMember_RestoreInitiated(t *testing.T) {
router := createTestRouter()
_, cleanup := SetupMockRestoreNode(t)
defer cleanup()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
request := RestoreBackupRequest{
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
@@ -141,7 +143,7 @@ func Test_RestoreBackup_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
request := RestoreBackupRequest{
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
@@ -165,12 +167,16 @@ func Test_RestoreBackup_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing
func Test_RestoreBackup_WithIsExcludeExtensions_FlagPassedCorrectly(t *testing.T) {
router := createTestRouter()
_, cleanup := SetupMockRestoreNode(t)
defer cleanup()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
request := RestoreBackupRequest{
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
@@ -195,12 +201,16 @@ func Test_RestoreBackup_WithIsExcludeExtensions_FlagPassedCorrectly(t *testing.T
func Test_RestoreBackup_AuditLogWritten(t *testing.T) {
router := createTestRouter()
_, cleanup := SetupMockRestoreNode(t)
defer cleanup()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
request := RestoreBackupRequest{
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
@@ -272,15 +282,22 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
router := createTestRouter()
// Setup mock node for tests that skip disk validation and reach scheduler
if !tc.expectDiskValidated {
_, cleanup := SetupMockRestoreNode(t)
defer cleanup()
}
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
var backup *backups_core.Backup
var request RestoreBackupRequest
var request restores_core.RestoreBackupRequest
if tc.dbType == databases.DatabaseTypePostgres {
_, backup = createTestDatabaseWithBackupForRestore(workspace, owner, router)
request = RestoreBackupRequest{
request = restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
@@ -310,7 +327,7 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
assert.NoError(t, err)
backup = createTestBackup(mysqlDB, owner)
request = RestoreBackupRequest{
request = restores_core.RestoreBackupRequest{
MysqlDatabase: &mysql.MysqlDatabase{
Version: tools.MysqlVersion80,
Host: "localhost",
@@ -354,15 +371,7 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
}
func createTestRouter() *gin.Engine {
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
databases.GetDatabaseController(),
backups_config.GetBackupConfigController(),
backups.GetBackupController(),
GetRestoreController(),
)
return router
return CreateTestRouter()
}
func createTestDatabaseWithBackupForRestore(

View File

@@ -1,4 +1,4 @@
package restores
package restores_core
import (
"databasus-backend/internal/features/databases/databases/mariadb"

View File

@@ -1,4 +1,4 @@
package enums
package restores_core
type RestoreStatus string

View File

@@ -0,0 +1,20 @@
package restores_core
import (
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/storages"
)
// RestoreBackupUsecase restores a backup into a target database.
// Implementations live outside this core package so that the restore
// orchestration code can depend on the interface only.
type RestoreBackupUsecase interface {
	// Execute performs the restore described by restore, reading the backup
	// artifact from storage and applying it to restoringToDB. originalDB is
	// the database the backup was taken from; isExcludeExtensions controls
	// whether database extensions are skipped during the restore.
	// Returns an error if the restore fails.
	Execute(
		backupConfig *backups_config.BackupConfig,
		restore Restore,
		originalDB *databases.Database,
		restoringToDB *databases.Database,
		backup *backups_core.Backup,
		storage *storages.Storage,
		isExcludeExtensions bool,
	) error
}

View File

@@ -0,0 +1,30 @@
package restores_core
import (
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/databases/databases/mariadb"
"databasus-backend/internal/features/databases/databases/mongodb"
"databasus-backend/internal/features/databases/databases/mysql"
"databasus-backend/internal/features/databases/databases/postgresql"
"time"
"github.com/google/uuid"
)
// Restore represents a single attempt to restore a backup into a target
// database. The engine-specific *Database fields are marked gorm:"-" and
// are therefore not persisted with the record; presumably exactly one of
// them is set, matching the target engine — confirm against callers.
type Restore struct {
	// ID is the primary key of the restore record.
	ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
	// Status is the current lifecycle state of the restore.
	Status RestoreStatus `json:"status" gorm:"column:status;type:text;not null"`
	// BackupID references the backup being restored.
	BackupID uuid.UUID `json:"backupId" gorm:"column:backup_id;type:uuid;not null"`
	// Backup is the associated backup record. NOTE(review): it has no json
	// tag, so it serializes under the key "Backup" — confirm this is intended.
	Backup *backups_core.Backup
	// Target connection details per engine; not stored in the restores table.
	PostgresqlDatabase *postgresql.PostgresqlDatabase `json:"postgresqlDatabase" gorm:"-"`
	MysqlDatabase *mysql.MysqlDatabase `json:"mysqlDatabase" gorm:"-"`
	MariadbDatabase *mariadb.MariadbDatabase `json:"mariadbDatabase" gorm:"-"`
	MongodbDatabase *mongodb.MongodbDatabase `json:"mongodbDatabase" gorm:"-"`
	// FailMessage holds a human-readable reason when Status is failed.
	FailMessage *string `json:"failMessage" gorm:"column:fail_message"`
	// RestoreDurationMs is the total restore time in milliseconds.
	RestoreDurationMs int64 `json:"restoreDurationMs" gorm:"column:restore_duration_ms;default:0"`
	// CreatedAt is set by the database to now() when the row is inserted.
	CreatedAt time.Time `json:"createdAt" gorm:"column:created_at;default:now()"`
}

View File

@@ -1,8 +1,6 @@
package restores
package restores_core
import (
"databasus-backend/internal/features/restores/enums"
"databasus-backend/internal/features/restores/models"
"databasus-backend/internal/storage"
"github.com/google/uuid"
@@ -10,24 +8,24 @@ import (
type RestoreRepository struct{}
func (r *RestoreRepository) Save(restore *models.Restore) error {
func (r *RestoreRepository) Save(restore *Restore) error {
db := storage.GetDb()
isNew := restore.ID == uuid.Nil
if isNew {
restore.ID = uuid.New()
return db.Create(restore).
Omit("Backup").
Omit("Backup", "PostgresqlDatabase", "MysqlDatabase", "MariadbDatabase", "MongodbDatabase").
Error
}
return db.Save(restore).
Omit("Backup").
Omit("Backup", "PostgresqlDatabase", "MysqlDatabase", "MariadbDatabase", "MongodbDatabase").
Error
}
func (r *RestoreRepository) FindByBackupID(backupID uuid.UUID) ([]*models.Restore, error) {
var restores []*models.Restore
func (r *RestoreRepository) FindByBackupID(backupID uuid.UUID) ([]*Restore, error) {
var restores []*Restore
if err := storage.
GetDb().
@@ -41,8 +39,8 @@ func (r *RestoreRepository) FindByBackupID(backupID uuid.UUID) ([]*models.Restor
return restores, nil
}
func (r *RestoreRepository) FindByID(id uuid.UUID) (*models.Restore, error) {
var restore models.Restore
func (r *RestoreRepository) FindByID(id uuid.UUID) (*Restore, error) {
var restore Restore
if err := storage.
GetDb().
@@ -55,8 +53,8 @@ func (r *RestoreRepository) FindByID(id uuid.UUID) (*models.Restore, error) {
return &restore, nil
}
func (r *RestoreRepository) FindByStatus(status enums.RestoreStatus) ([]*models.Restore, error) {
var restores []*models.Restore
func (r *RestoreRepository) FindByStatus(status RestoreStatus) ([]*Restore, error) {
var restores []*Restore
if err := storage.
GetDb().
@@ -71,5 +69,5 @@ func (r *RestoreRepository) FindByStatus(status enums.RestoreStatus) ([]*models.
}
func (r *RestoreRepository) DeleteByID(id uuid.UUID) error {
return storage.GetDb().Delete(&models.Restore{}, "id = ?", id).Error
return storage.GetDb().Delete(&Restore{}, "id = ?", id).Error
}

View File

@@ -6,6 +6,7 @@ import (
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/disk"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/restores/usecases"
"databasus-backend/internal/features/storages"
workspaces_services "databasus-backend/internal/features/workspaces/services"
@@ -13,7 +14,7 @@ import (
"databasus-backend/internal/util/logger"
)
var restoreRepository = &RestoreRepository{}
var restoreRepository = &restores_core.RestoreRepository{}
var restoreService = &RestoreService{
backups.GetBackupService(),
restoreRepository,
@@ -31,19 +32,10 @@ var restoreController = &RestoreController{
restoreService,
}
var restoreBackgroundService = &RestoreBackgroundService{
restoreRepository,
logger.GetLogger(),
}
func GetRestoreController() *RestoreController {
return restoreController
}
func GetRestoreBackgroundService() *RestoreBackgroundService {
return restoreBackgroundService
}
func SetupDependencies() {
backups.GetBackupService().AddBackupRemoveListener(restoreService)
}

View File

@@ -1,22 +0,0 @@
package models
import (
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/restores/enums"
"time"
"github.com/google/uuid"
)
type Restore struct {
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
Status enums.RestoreStatus `json:"status" gorm:"column:status;type:text;not null"`
BackupID uuid.UUID `json:"backupId" gorm:"column:backup_id;type:uuid;not null"`
Backup *backups_core.Backup
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`
RestoreDurationMs int64 `json:"restoreDurationMs" gorm:"column:restore_duration_ms;default:0"`
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at;default:now()"`
}

View File

@@ -0,0 +1,73 @@
package restoring
import (
"time"
"github.com/google/uuid"
"databasus-backend/internal/features/backups/backups"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/restores/usecases"
"databasus-backend/internal/features/storages"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
)
// Package-level wiring for the restoring feature: a shared repository,
// the node registry backed by Valkey, a per-process restorer node, and
// the scheduler that assigns restores to nodes.

// restoreRepository is the shared persistence layer for restore records.
var restoreRepository = &restores_core.RestoreRepository{}

// restoreNodesRegistry tracks live restore nodes in Valkey.
// NOTE(review): positional initialization — arguments appear to be
// (valkey client, logger, cache timeout, submit pub/sub, completion
// pub/sub); confirm against the RestoreNodesRegistry definition.
var restoreNodesRegistry = &RestoreNodesRegistry{
	cache_utils.GetValkeyClient(),
	logger.GetLogger(),
	cache_utils.DefaultCacheTimeout,
	cache_utils.NewPubSubManager(),
	cache_utils.NewPubSubManager(),
}

// restoreDatabaseCache stores target-database connection details in Valkey
// under the "restore_db:" key prefix.
var restoreDatabaseCache = cache_utils.NewCacheUtil[RestoreDatabaseCache](
	cache_utils.GetValkeyClient(),
	"restore_db:",
)

// restorerNode is this process's restore worker, identified by a fresh
// UUID per process start. NOTE(review): positional initialization —
// confirm field order against the RestorerNode definition; the final
// time.Time{} presumably seeds a "last activity/heartbeat" zero value.
var restorerNode = &RestorerNode{
	uuid.New(),
	databases.GetDatabaseService(),
	backups.GetBackupService(),
	encryption.GetFieldEncryptor(),
	restoreRepository,
	backups_config.GetBackupConfigService(),
	storages.GetStorageService(),
	restoreNodesRegistry,
	logger.GetLogger(),
	usecases.GetRestoreBackupUsecase(),
	restoreDatabaseCache,
	time.Time{},
}

// restoresScheduler distributes pending restores across available nodes.
// completionSubscriptionID starts as uuid.Nil, i.e. not yet subscribed.
var restoresScheduler = &RestoresScheduler{
	restoreRepository:        restoreRepository,
	backupService:            backups.GetBackupService(),
	storageService:           storages.GetStorageService(),
	backupConfigService:      backups_config.GetBackupConfigService(),
	restoreNodesRegistry:     restoreNodesRegistry,
	lastCheckTime:            time.Now().UTC(),
	logger:                   logger.GetLogger(),
	restoreToNodeRelations:   make(map[uuid.UUID]RestoreToNodeRelation),
	restorerNode:             restorerNode,
	cacheUtil:                restoreDatabaseCache,
	completionSubscriptionID: uuid.Nil,
}

// GetRestoresScheduler returns the process-wide restores scheduler.
func GetRestoresScheduler() *RestoresScheduler {
	return restoresScheduler
}

// GetRestorerNode returns this process's restore worker node.
func GetRestorerNode() *RestorerNode {
	return restorerNode
}

// GetRestoreNodesRegistry returns the shared restore node registry.
func GetRestoreNodesRegistry() *RestoreNodesRegistry {
	return restoreNodesRegistry
}

View File

@@ -0,0 +1,45 @@
package restoring
import (
"databasus-backend/internal/features/databases/databases/mariadb"
"databasus-backend/internal/features/databases/databases/mongodb"
"databasus-backend/internal/features/databases/databases/mysql"
"databasus-backend/internal/features/databases/databases/postgresql"
"time"
"github.com/google/uuid"
)
// RestoreDatabaseCache carries the connection settings of the database a
// restore writes into, cached in Valkey keyed by restore ID and consumed
// once when the restore starts. Presumably exactly one engine-specific
// field is non-nil, matching the target database's engine — TODO confirm.
type RestoreDatabaseCache struct {
	PostgresqlDatabase *postgresql.PostgresqlDatabase `json:"postgresqlDatabase,omitempty"`
	MysqlDatabase      *mysql.MysqlDatabase           `json:"mysqlDatabase,omitempty"`
	MariadbDatabase    *mariadb.MariadbDatabase       `json:"mariadbDatabase,omitempty"`
	MongodbDatabase    *mongodb.MongodbDatabase       `json:"mongodbDatabase,omitempty"`
}
// RestoreToNodeRelation associates a restore node with the restores assigned
// to it.
type RestoreToNodeRelation struct {
	NodeID     uuid.UUID   `json:"nodeId"`
	RestoreIDs []uuid.UUID `json:"restoreIds"`
}
// RestoreNode is the registry's serialized view of a restore worker.
// LastHeartbeat is refreshed periodically by the node; entries whose
// heartbeat is zero or older than the dead-node threshold are treated as
// gone by the registry's listing, stats, and cleanup functions.
type RestoreNode struct {
	ID            uuid.UUID `json:"id"`
	ThroughputMBs int       `json:"throughputMBs"` // node network throughput, from env config
	LastHeartbeat time.Time `json:"lastHeartbeat"`
}
// RestoreNodeStats reports the current load (in-flight restores) of one
// live restore node.
type RestoreNodeStats struct {
	ID             uuid.UUID `json:"id"`
	ActiveRestores int       `json:"activeRestores"`
}
// RestoreSubmitMessage is broadcast on the restore submit channel to assign
// a restore to one node; subscribers filter by NodeID, so only the addressed
// node acts on it.
type RestoreSubmitMessage struct {
	NodeID         uuid.UUID `json:"nodeId"`
	RestoreID      uuid.UUID `json:"restoreId"`
	IsCallNotifier bool      `json:"isCallNotifier"`
}
// RestoreCompletionMessage is published by a node when it has finished
// processing a restore (regardless of the restore's final status).
type RestoreCompletionMessage struct {
	NodeID    uuid.UUID `json:"nodeId"`
	RestoreID uuid.UUID `json:"restoreId"`
}

View File

@@ -0,0 +1,61 @@
package restoring
import (
"errors"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
)
// MockSuccessRestoreUsecase is a test double for the restore usecase that
// always reports success.
type MockSuccessRestoreUsecase struct{}

// Execute ignores all arguments and returns nil (restore "succeeded").
func (uc *MockSuccessRestoreUsecase) Execute(
	backupConfig *backups_config.BackupConfig,
	restore restores_core.Restore,
	originalDB *databases.Database,
	restoringToDB *databases.Database,
	backup *backups_core.Backup,
	storage *storages.Storage,
	isExcludeExtensions bool,
) error {
	return nil
}
// MockFailedRestoreUsecase is a test double for the restore usecase that
// always fails.
type MockFailedRestoreUsecase struct{}

// Execute ignores all arguments and returns a fixed "restore failed" error.
func (uc *MockFailedRestoreUsecase) Execute(
	backupConfig *backups_config.BackupConfig,
	restore restores_core.Restore,
	originalDB *databases.Database,
	restoringToDB *databases.Database,
	backup *backups_core.Backup,
	storage *storages.Storage,
	isExcludeExtensions bool,
) error {
	return errors.New("restore failed")
}
// MockCaptureCredentialsRestoreUsecase is a test double that publishes the
// target database it was invoked with on CalledChan, letting tests inspect
// the credentials passed to the usecase. ShouldSucceed selects the outcome.
type MockCaptureCredentialsRestoreUsecase struct {
	// CalledChan receives restoringToDB on every Execute call.
	// NOTE(review): the send blocks until a reader receives unless the
	// channel is buffered — tests must drain it.
	CalledChan    chan *databases.Database
	ShouldSucceed bool
}

// Execute forwards restoringToDB to CalledChan, then succeeds or fails
// according to ShouldSucceed.
func (uc *MockCaptureCredentialsRestoreUsecase) Execute(
	backupConfig *backups_config.BackupConfig,
	restore restores_core.Restore,
	originalDB *databases.Database,
	restoringToDB *databases.Database,
	backup *backups_core.Backup,
	storage *storages.Storage,
	isExcludeExtensions bool,
) error {
	uc.CalledChan <- restoringToDB
	if uc.ShouldSucceed {
		return nil
	}
	return errors.New("mock restore failed")
}

View File

@@ -0,0 +1,634 @@
package restoring
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"strconv"
	"strings"
	"time"

	cache_utils "databasus-backend/internal/util/cache"

	"github.com/google/uuid"
	"github.com/valkey-io/valkey-go"
)
const (
	// Valkey key layout: "restore:node:<uuid>:info" holds the serialized
	// RestoreNode, "restore:node:<uuid>:active_restores" holds an integer
	// counter of in-flight restores for that node.
	nodeInfoKeyPrefix        = "restore:node:"
	nodeInfoKeySuffix        = ":info"
	nodeActiveRestoresPrefix = "restore:node:"
	nodeActiveRestoresSuffix = ":active_restores"
	// Pub/sub channels: scheduler -> nodes (submit), nodes -> scheduler
	// (completion).
	restoreSubmitChannel     = "restore:submit"
	restoreCompletionChannel = "restore:completion"
	// deadNodeThreshold is the max heartbeat age before a node is treated
	// as dead; cleanupTickerInterval is how often dead nodes are reaped.
	deadNodeThreshold     = 2 * time.Minute
	cleanupTickerInterval = 1 * time.Second
)
// RestoreNodesRegistry helps to sync restores scheduler and restore nodes.
//
// Features:
//   - Track node availability and load level
//   - Assign from scheduler to node restores needed to be processed
//   - Notify scheduler from node about restore completion
//
// Important things to remember:
//   - Nodes without heartbeat for more than 2 minutes are not included
//     in available nodes list and stats
//
// Cleanup dead nodes performed on 2 levels:
//   - List and stats functions do not return dead nodes
//   - Periodically dead nodes are cleaned up in cache (to not
//     accumulate too many dead nodes in cache)
type RestoreNodesRegistry struct {
	client  valkey.Client // shared Valkey connection for all registry keys
	logger  *slog.Logger
	timeout time.Duration // per-operation deadline for Valkey round trips
	// Separate managers so closing the assignments subscription does not
	// tear down the completions subscription (and vice versa).
	pubsubRestores    *cache_utils.PubSubManager
	pubsubCompletions *cache_utils.PubSubManager
}
// Run reaps dead nodes once at startup and then on a fixed cadence until the
// context is cancelled. Cleanup failures are logged and retried on the next
// tick rather than aborting the loop.
func (r *RestoreNodesRegistry) Run(ctx context.Context) {
	reap := func(errMsg string) {
		if err := r.cleanupDeadNodes(); err != nil {
			r.logger.Error(errMsg, "error", err)
		}
	}
	reap("Failed to cleanup dead nodes on startup")

	ticker := time.NewTicker(cleanupTickerInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			reap("Failed to cleanup dead nodes")
		}
	}
}
// GetAvailableNodes lists the restore nodes currently considered alive: the
// node's info key must exist, deserialize to a RestoreNode, and carry a
// non-zero heartbeat no older than deadNodeThreshold. Returns an empty slice
// when no info keys exist (a nil slice may be returned when keys exist but
// all nodes are dead).
func (r *RestoreNodesRegistry) GetAvailableNodes() ([]RestoreNode, error) {
	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
	defer cancel()
	// Gather every node info key with cursor-based SCAN (non-blocking on
	// large keyspaces, unlike KEYS).
	var allKeys []string
	cursor := uint64(0)
	pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix
	for {
		result := r.client.Do(
			ctx,
			r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
		)
		if result.Error() != nil {
			return nil, fmt.Errorf("failed to scan node keys: %w", result.Error())
		}
		scanResult, err := result.AsScanEntry()
		if err != nil {
			return nil, fmt.Errorf("failed to parse scan result: %w", err)
		}
		allKeys = append(allKeys, scanResult.Elements...)
		cursor = scanResult.Cursor
		if cursor == 0 {
			break
		}
	}
	if len(allKeys) == 0 {
		return []RestoreNode{}, nil
	}
	// Fetch all node payloads in a single pipelined round trip.
	keyDataMap, err := r.pipelineGetKeys(allKeys)
	if err != nil {
		return nil, fmt.Errorf("failed to pipeline get node keys: %w", err)
	}
	threshold := time.Now().UTC().Add(-deadNodeThreshold)
	var nodes []RestoreNode
	for key, data := range keyDataMap {
		// Skip if the key doesn't exist (data is empty)
		if len(data) == 0 {
			continue
		}
		var node RestoreNode
		if err := json.Unmarshal(data, &node); err != nil {
			r.logger.Warn("Failed to unmarshal node data", "key", key, "error", err)
			continue
		}
		// Skip nodes with zero/uninitialized heartbeat
		if node.LastHeartbeat.IsZero() {
			continue
		}
		if node.LastHeartbeat.Before(threshold) {
			continue
		}
		nodes = append(nodes, node)
	}
	return nodes, nil
}
// GetRestoreNodesStats returns, for every live node, the current value of
// its active-restores counter. Liveness is checked against the node's info
// key exactly as in GetAvailableNodes; dead, malformed, or unparsable
// entries are skipped with a warning.
//
// NOTE(review): SCAN batch size is 100 here vs 1_000 in GetAvailableNodes —
// consider unifying.
func (r *RestoreNodesRegistry) GetRestoreNodesStats() ([]RestoreNodeStats, error) {
	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
	defer cancel()
	// Collect every per-node counter key.
	var allKeys []string
	cursor := uint64(0)
	pattern := nodeActiveRestoresPrefix + "*" + nodeActiveRestoresSuffix
	for {
		result := r.client.Do(
			ctx,
			r.client.B().Scan().Cursor(cursor).Match(pattern).Count(100).Build(),
		)
		if result.Error() != nil {
			return nil, fmt.Errorf("failed to scan active restores keys: %w", result.Error())
		}
		scanResult, err := result.AsScanEntry()
		if err != nil {
			return nil, fmt.Errorf("failed to parse scan result: %w", err)
		}
		allKeys = append(allKeys, scanResult.Elements...)
		cursor = scanResult.Cursor
		if cursor == 0 {
			break
		}
	}
	if len(allKeys) == 0 {
		return []RestoreNodeStats{}, nil
	}
	keyDataMap, err := r.pipelineGetKeys(allKeys)
	if err != nil {
		return nil, fmt.Errorf("failed to pipeline get active restores keys: %w", err)
	}
	// Derive the matching info key for each counter key so liveness can be
	// verified before reporting stats.
	var nodeInfoKeys []string
	nodeIDToStatsKey := make(map[string]string)
	for key := range keyDataMap {
		nodeID := r.extractNodeIDFromKey(key, nodeActiveRestoresPrefix, nodeActiveRestoresSuffix)
		nodeIDStr := nodeID.String()
		infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeIDStr, nodeInfoKeySuffix)
		nodeInfoKeys = append(nodeInfoKeys, infoKey)
		nodeIDToStatsKey[infoKey] = key
	}
	nodeInfoMap, err := r.pipelineGetKeys(nodeInfoKeys)
	if err != nil {
		return nil, fmt.Errorf("failed to pipeline get node info keys: %w", err)
	}
	threshold := time.Now().UTC().Add(-deadNodeThreshold)
	var stats []RestoreNodeStats
	for infoKey, nodeData := range nodeInfoMap {
		// Skip if the info key doesn't exist (nodeData is empty)
		if len(nodeData) == 0 {
			continue
		}
		var node RestoreNode
		if err := json.Unmarshal(nodeData, &node); err != nil {
			r.logger.Warn("Failed to unmarshal node data", "key", infoKey, "error", err)
			continue
		}
		// Skip nodes with zero/uninitialized heartbeat
		if node.LastHeartbeat.IsZero() {
			continue
		}
		if node.LastHeartbeat.Before(threshold) {
			continue
		}
		statsKey := nodeIDToStatsKey[infoKey]
		tasksData := keyDataMap[statsKey]
		count, err := r.parseIntFromBytes(tasksData)
		if err != nil {
			r.logger.Warn("Failed to parse active restores count", "key", statsKey, "error", err)
			continue
		}
		stat := RestoreNodeStats{
			ID:             node.ID,
			ActiveRestores: int(count),
		}
		stats = append(stats, stat)
	}
	return stats, nil
}
// IncrementRestoresInProgress bumps the per-node in-flight restores counter
// in Valkey (INCR creates the key at 1 when absent).
func (r *RestoreNodesRegistry) IncrementRestoresInProgress(nodeID uuid.UUID) error {
	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
	defer cancel()

	key := nodeActiveRestoresPrefix + nodeID.String() + nodeActiveRestoresSuffix
	if resp := r.client.Do(ctx, r.client.B().Incr().Key(key).Build()); resp.Error() != nil {
		return fmt.Errorf(
			"failed to increment restores in progress for node %s: %w",
			nodeID,
			resp.Error(),
		)
	}
	return nil
}
// DecrementRestoresInProgress decrements the per-node in-flight restores
// counter. If the value drops below zero (e.g. after a lost increment or a
// counter reset), it is clamped back to zero and a warning is logged.
//
// NOTE: the clamp is a separate SET, not atomic with the DECR; a concurrent
// INCR between the two commands can be overwritten. Acceptable for a
// best-effort load metric.
func (r *RestoreNodesRegistry) DecrementRestoresInProgress(nodeID uuid.UUID) error {
	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
	defer cancel()
	key := fmt.Sprintf(
		"%s%s%s",
		nodeActiveRestoresPrefix,
		nodeID.String(),
		nodeActiveRestoresSuffix,
	)
	result := r.client.Do(ctx, r.client.B().Decr().Key(key).Build())
	if result.Error() != nil {
		return fmt.Errorf(
			"failed to decrement restores in progress for node %s: %w",
			nodeID,
			result.Error(),
		)
	}
	newValue, err := result.AsInt64()
	if err != nil {
		return fmt.Errorf("failed to parse decremented value for node %s: %w", nodeID, err)
	}
	if newValue < 0 {
		setCtx, setCancel := context.WithTimeout(context.Background(), r.timeout)
		setResult := r.client.Do(setCtx, r.client.B().Set().Key(key).Value("0").Build())
		setCancel()
		if setResult.Error() != nil {
			// The decrement itself succeeded, so do not fail the call; the
			// counter simply stays negative until the next reset attempt.
			r.logger.Error(
				"Failed to reset negative restores counter",
				"nodeID", nodeID,
				"error", setResult.Error(),
			)
		}
		r.logger.Warn("Active restores counter went below 0, reset to 0", "nodeID", nodeID)
	}
	return nil
}
// HearthbeatNodeInRegistry upserts the node's info key with LastHeartbeat
// set to now; it serves both for initial registration and for periodic
// heartbeats. A zero timestamp is rejected because liveness checks treat a
// zero heartbeat as "never alive" and would silently drop the node.
//
// NOTE(review): method name has a typo ("Hearthbeat" -> "Heartbeat"); kept
// as-is to avoid breaking callers.
// NOTE(review): the key is stored without a TTL — stale entries rely on the
// periodic cleanupDeadNodes pass for removal.
func (r *RestoreNodesRegistry) HearthbeatNodeInRegistry(
	now time.Time,
	restoreNode RestoreNode,
) error {
	if now.IsZero() {
		return fmt.Errorf("cannot register node with zero heartbeat timestamp")
	}
	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
	defer cancel()
	restoreNode.LastHeartbeat = now
	data, err := json.Marshal(restoreNode)
	if err != nil {
		return fmt.Errorf("failed to marshal restore node: %w", err)
	}
	key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, restoreNode.ID.String(), nodeInfoKeySuffix)
	result := r.client.Do(
		ctx,
		r.client.B().Set().Key(key).Value(string(data)).Build(),
	)
	if result.Error() != nil {
		return fmt.Errorf("failed to register node %s: %w", restoreNode.ID, result.Error())
	}
	return nil
}
// UnregisterNodeFromRegistry removes both keys belonging to the node (info
// and active-restores counter) in a single DEL round trip.
func (r *RestoreNodesRegistry) UnregisterNodeFromRegistry(restoreNode RestoreNode) error {
	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
	defer cancel()

	id := restoreNode.ID.String()
	infoKey := nodeInfoKeyPrefix + id + nodeInfoKeySuffix
	counterKey := nodeActiveRestoresPrefix + id + nodeActiveRestoresSuffix

	resp := r.client.Do(ctx, r.client.B().Del().Key(infoKey, counterKey).Build())
	if resp.Error() != nil {
		return fmt.Errorf("failed to unregister node %s: %w", restoreNode.ID, resp.Error())
	}
	r.logger.Info("Unregistered node from registry", "nodeID", restoreNode.ID)
	return nil
}
// AssignRestoreToNode broadcasts a restore assignment on the shared submit
// channel. All nodes receive the message; the target node is selected by
// NodeID on the subscriber side.
func (r *RestoreNodesRegistry) AssignRestoreToNode(
	targetNodeID uuid.UUID,
	restoreID uuid.UUID,
	isCallNotifier bool,
) error {
	payload, err := json.Marshal(RestoreSubmitMessage{
		NodeID:         targetNodeID,
		RestoreID:      restoreID,
		IsCallNotifier: isCallNotifier,
	})
	if err != nil {
		return fmt.Errorf("failed to marshal restore submit message: %w", err)
	}
	if err := r.pubsubRestores.Publish(
		context.Background(),
		restoreSubmitChannel,
		string(payload),
	); err != nil {
		return fmt.Errorf("failed to publish restore submit message: %w", err)
	}
	return nil
}
// SubscribeNodeForRestoresAssignment registers a handler for restore
// assignments addressed to nodeID. Messages for other nodes and messages
// that fail to decode are ignored (the latter with a warning).
func (r *RestoreNodesRegistry) SubscribeNodeForRestoresAssignment(
	nodeID uuid.UUID,
	handler func(restoreID uuid.UUID, isCallNotifier bool),
) error {
	onMessage := func(raw string) {
		var msg RestoreSubmitMessage
		if err := json.Unmarshal([]byte(raw), &msg); err != nil {
			r.logger.Warn("Failed to unmarshal restore submit message", "error", err)
			return
		}
		// The submit channel is shared: act only on messages addressed to us.
		if msg.NodeID == nodeID {
			handler(msg.RestoreID, msg.IsCallNotifier)
		}
	}
	err := r.pubsubRestores.Subscribe(context.Background(), restoreSubmitChannel, onMessage)
	if err != nil {
		return fmt.Errorf("failed to subscribe to restore submit channel: %w", err)
	}
	r.logger.Info("Subscribed to restore submit channel", "nodeID", nodeID)
	return nil
}
// UnsubscribeNodeForRestoresAssignments stops receiving restore assignments.
// NOTE(review): this closes the entire pubsubRestores manager rather than
// removing one subscription — presumably re-subscribing through this manager
// is not possible afterwards; confirm this is only used at node shutdown.
func (r *RestoreNodesRegistry) UnsubscribeNodeForRestoresAssignments() error {
	err := r.pubsubRestores.Close()
	if err != nil {
		return fmt.Errorf("failed to unsubscribe from restore submit channel: %w", err)
	}
	r.logger.Info("Unsubscribed from restore submit channel")
	return nil
}
// PublishRestoreCompletion announces on the completion channel that nodeID
// has finished processing restoreID.
func (r *RestoreNodesRegistry) PublishRestoreCompletion(
	nodeID uuid.UUID,
	restoreID uuid.UUID,
) error {
	payload, err := json.Marshal(RestoreCompletionMessage{
		NodeID:    nodeID,
		RestoreID: restoreID,
	})
	if err != nil {
		return fmt.Errorf("failed to marshal restore completion message: %w", err)
	}
	if err := r.pubsubCompletions.Publish(
		context.Background(),
		restoreCompletionChannel,
		string(payload),
	); err != nil {
		return fmt.Errorf("failed to publish restore completion message: %w", err)
	}
	return nil
}
// SubscribeForRestoresCompletions registers a handler invoked for every
// completion message published by any node; undecodable messages are
// dropped with a warning.
func (r *RestoreNodesRegistry) SubscribeForRestoresCompletions(
	handler func(nodeID uuid.UUID, restoreID uuid.UUID),
) error {
	onMessage := func(raw string) {
		var msg RestoreCompletionMessage
		if err := json.Unmarshal([]byte(raw), &msg); err != nil {
			r.logger.Warn("Failed to unmarshal restore completion message", "error", err)
			return
		}
		handler(msg.NodeID, msg.RestoreID)
	}
	err := r.pubsubCompletions.Subscribe(
		context.Background(),
		restoreCompletionChannel,
		onMessage,
	)
	if err != nil {
		return fmt.Errorf("failed to subscribe to restore completion channel: %w", err)
	}
	r.logger.Info("Subscribed to restore completion channel")
	return nil
}
// UnsubscribeForRestoresCompletions stops receiving completion messages.
// NOTE(review): like the assignments counterpart, this closes the whole
// pubsubCompletions manager — confirm it is only invoked at shutdown.
func (r *RestoreNodesRegistry) UnsubscribeForRestoresCompletions() error {
	err := r.pubsubCompletions.Close()
	if err != nil {
		return fmt.Errorf("failed to unsubscribe from restore completion channel: %w", err)
	}
	r.logger.Info("Unsubscribed from restore completion channel")
	return nil
}
// extractNodeIDFromKey parses the node UUID embedded between prefix and
// suffix of a registry key. On malformed input it logs a warning and
// returns uuid.Nil.
func (r *RestoreNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uuid.UUID {
	trimmed := strings.TrimSuffix(strings.TrimPrefix(key, prefix), suffix)
	id, err := uuid.Parse(trimmed)
	if err != nil {
		r.logger.Warn("Failed to parse node ID from key", "key", key, "error", err)
		return uuid.Nil
	}
	return id
}
// pipelineGetKeys GETs every key in one DoMulti round trip and returns a
// key -> raw-value map. Per-key failures are logged and that key is simply
// omitted from the result; the returned error is currently always nil.
// NOTE(review): a missing key presumably surfaces as a nil-reply error from
// the client and is dropped here, which is why callers also tolerate empty
// values — confirm against valkey-go semantics.
func (r *RestoreNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, error) {
	if len(keys) == 0 {
		return make(map[string][]byte), nil
	}
	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
	defer cancel()
	// Build one GET per key and send them as a single pipelined batch.
	commands := make([]valkey.Completed, 0, len(keys))
	for _, key := range keys {
		commands = append(commands, r.client.B().Get().Key(key).Build())
	}
	results := r.client.DoMulti(ctx, commands...)
	keyDataMap := make(map[string][]byte, len(keys))
	for i, result := range results {
		// Results arrive in command order, so results[i] matches keys[i].
		if result.Error() != nil {
			r.logger.Warn("Failed to get key in pipeline", "key", keys[i], "error", result.Error())
			continue
		}
		data, err := result.AsBytes()
		if err != nil {
			r.logger.Warn("Failed to parse key data in pipeline", "key", keys[i], "error", err)
			continue
		}
		keyDataMap[keys[i]] = data
	}
	return keyDataMap, nil
}
// parseIntFromBytes parses a base-10 signed integer from a raw Valkey value
// (the active-restores counter is always a plain integer string).
//
// strconv.ParseInt is both stricter and cheaper than the previous
// fmt.Sscanf("%d"): trailing garbage such as "12x" is now rejected instead
// of being silently truncated to 12. Leading/trailing whitespace is still
// tolerated, matching Sscanf's leniency.
func (r *RestoreNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
	count, err := strconv.ParseInt(strings.TrimSpace(string(data)), 10, 64)
	if err != nil {
		return 0, fmt.Errorf("failed to parse integer from bytes: %w", err)
	}
	return count, nil
}
// cleanupDeadNodes scans all node info keys and deletes, in one DEL, both
// keys (info + counter) of every node whose heartbeat is older than
// deadNodeThreshold. Entries with a zero heartbeat or undecodable payloads
// are left alone (the latter only logged) rather than deleted.
func (r *RestoreNodesRegistry) cleanupDeadNodes() error {
	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
	defer cancel()
	// Collect every node info key via cursor-based SCAN.
	var allKeys []string
	cursor := uint64(0)
	pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix
	for {
		result := r.client.Do(
			ctx,
			r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
		)
		if result.Error() != nil {
			return fmt.Errorf("failed to scan node keys: %w", result.Error())
		}
		scanResult, err := result.AsScanEntry()
		if err != nil {
			return fmt.Errorf("failed to parse scan result: %w", err)
		}
		allKeys = append(allKeys, scanResult.Elements...)
		cursor = scanResult.Cursor
		if cursor == 0 {
			break
		}
	}
	if len(allKeys) == 0 {
		return nil
	}
	keyDataMap, err := r.pipelineGetKeys(allKeys)
	if err != nil {
		return fmt.Errorf("failed to pipeline get node keys: %w", err)
	}
	threshold := time.Now().UTC().Add(-deadNodeThreshold)
	var deadNodeKeys []string
	for key, data := range keyDataMap {
		// Skip if the key doesn't exist (data is empty)
		if len(data) == 0 {
			continue
		}
		var node RestoreNode
		if err := json.Unmarshal(data, &node); err != nil {
			r.logger.Warn("Failed to unmarshal node data during cleanup", "key", key, "error", err)
			continue
		}
		// Skip nodes with zero/uninitialized heartbeat
		if node.LastHeartbeat.IsZero() {
			continue
		}
		if node.LastHeartbeat.Before(threshold) {
			// Queue both of the node's keys for deletion; the keys are
			// rebuilt from the deserialized node ID, not the scanned key.
			nodeID := node.ID.String()
			infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeID, nodeInfoKeySuffix)
			statsKey := fmt.Sprintf(
				"%s%s%s",
				nodeActiveRestoresPrefix,
				nodeID,
				nodeActiveRestoresSuffix,
			)
			deadNodeKeys = append(deadNodeKeys, infoKey, statsKey)
			r.logger.Info(
				"Marking node for cleanup",
				"nodeID", nodeID,
				"lastHeartbeat", node.LastHeartbeat,
				"threshold", threshold,
			)
		}
	}
	if len(deadNodeKeys) == 0 {
		return nil
	}
	// Use a fresh timeout for the DEL: the scan/get phase may have consumed
	// most of the first context's budget.
	delCtx, delCancel := context.WithTimeout(context.Background(), r.timeout)
	defer delCancel()
	result := r.client.Do(
		delCtx,
		r.client.B().Del().Key(deadNodeKeys...).Build(),
	)
	if result.Error() != nil {
		return fmt.Errorf("failed to delete dead node keys: %w", result.Error())
	}
	deletedCount, err := result.AsInt64()
	if err != nil {
		return fmt.Errorf("failed to parse deleted count: %w", err)
	}
	r.logger.Info("Cleaned up dead nodes", "deletedKeysCount", deletedCount)
	return nil
}

View File

@@ -0,0 +1,262 @@
package restoring
import (
"context"
"fmt"
"log/slog"
"time"
"github.com/google/uuid"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
cache_utils "databasus-backend/internal/util/cache"
util_encryption "databasus-backend/internal/util/encryption"
)
const (
	// heartbeatTickerInterval is how often the node refreshes its registry
	// heartbeat (well under the registry's 2-minute dead-node threshold).
	heartbeatTickerInterval = 15 * time.Second
	// restorerHealthcheckThreshold bounds how stale lastHeartbeat may be
	// before IsRestorerRunning reports false.
	restorerHealthcheckThreshold = 5 * time.Minute
)
// RestorerNode executes restores assigned to this process by the scheduler:
// it heartbeats into the registry, listens for assignments, and runs the
// restore usecase against credentials cached per restore.
type RestorerNode struct {
	nodeID               uuid.UUID
	databaseService      *databases.DatabaseService
	backupService        *backups.BackupService
	fieldEncryptor       util_encryption.FieldEncryptor
	restoreRepository    *restores_core.RestoreRepository
	backupConfigService  *backups_config.BackupConfigService
	storageService       *storages.StorageService
	restoreNodesRegistry *RestoreNodesRegistry
	logger               *slog.Logger
	restoreBackupUsecase restores_core.RestoreBackupUsecase
	// cacheUtil holds per-restore target-DB credentials, consumed exactly
	// once via GetAndDelete in MakeRestore.
	cacheUtil *cache_utils.CacheUtil[RestoreDatabaseCache]
	// lastHeartbeat is written by Run/sendHeartbeat and read by
	// IsRestorerRunning. NOTE(review): unsynchronized — a data race if
	// IsRestorerRunning is called from another goroutine; consider an
	// atomic value or mutex (would require changing initializers).
	lastHeartbeat time.Time
}
// Run registers this node in the registry, subscribes for restore
// assignments, and heartbeats every heartbeatTickerInterval until ctx is
// cancelled; on shutdown the node unsubscribes and unregisters itself.
//
// Panics when initial registration or subscription fails: without either,
// the node can never receive work, so startup must abort.
func (n *RestorerNode) Run(ctx context.Context) {
	n.lastHeartbeat = time.Now().UTC()
	throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
	restoreNode := RestoreNode{
		ID:            n.nodeID,
		ThroughputMBs: throughputMBs,
	}
	// The first heartbeat doubles as registration (same upsert path).
	if err := n.restoreNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), restoreNode); err != nil {
		n.logger.Error("Failed to register node in registry", "error", err)
		panic(err)
	}
	// Completion is published regardless of the restore's outcome —
	// MakeRestore records failures itself and returns nothing.
	// NOTE(review): isCallNotifier is received but unused here — confirm
	// this is intended.
	restoreHandler := func(restoreID uuid.UUID, isCallNotifier bool) {
		n.MakeRestore(restoreID)
		if err := n.restoreNodesRegistry.PublishRestoreCompletion(n.nodeID, restoreID); err != nil {
			n.logger.Error(
				"Failed to publish restore completion",
				"error",
				err,
				"restoreID",
				restoreID,
			)
		}
	}
	err := n.restoreNodesRegistry.SubscribeNodeForRestoresAssignment(
		n.nodeID,
		restoreHandler,
	)
	if err != nil {
		n.logger.Error("Failed to subscribe to restore assignments", "error", err)
		panic(err)
	}
	defer func() {
		if err := n.restoreNodesRegistry.UnsubscribeNodeForRestoresAssignments(); err != nil {
			n.logger.Error("Failed to unsubscribe from restore assignments", "error", err)
		}
	}()
	ticker := time.NewTicker(heartbeatTickerInterval)
	defer ticker.Stop()
	n.logger.Info("Restore node started", "nodeID", n.nodeID, "throughput", throughputMBs)
	for {
		select {
		case <-ctx.Done():
			// Graceful shutdown: drop our registry keys so the scheduler
			// stops assigning restores to this node.
			n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
			if err := n.restoreNodesRegistry.UnregisterNodeFromRegistry(restoreNode); err != nil {
				n.logger.Error("Failed to unregister node from registry", "error", err)
			}
			return
		case <-ticker.C:
			n.sendHeartbeat(&restoreNode)
		}
	}
}
// IsRestorerRunning reports whether the node's heartbeat loop has fired
// within restorerHealthcheckThreshold.
// NOTE(review): reads lastHeartbeat without synchronization while Run's
// goroutine writes it — a data race if called concurrently; confirm caller
// context.
func (n *RestorerNode) IsRestorerRunning() bool {
	return n.lastHeartbeat.After(time.Now().UTC().Add(-restorerHealthcheckThreshold))
}
// MakeRestore executes a single restore end to end: it consumes the cached
// target-DB credentials (one-shot), loads the restore, backup, source
// database, backup config, and storage records, runs the restore usecase,
// and persists the final status with the measured duration. Errors never
// propagate to the caller — each failure path records its outcome in the
// repository and/or the log, then returns.
func (n *RestorerNode) MakeRestore(restoreID uuid.UUID) {
	// Get and delete cached DB credentials atomically
	dbCache := n.cacheUtil.GetAndDelete(restoreID.String())
	if dbCache == nil {
		// Cache miss - fail immediately
		restore, err := n.restoreRepository.FindByID(restoreID)
		if err != nil {
			n.logger.Error(
				"Failed to get restore by ID after cache miss",
				"restoreId",
				restoreID,
				"error",
				err,
			)
			return
		}
		errMsg := "Database credentials expired or missing from cache (most likely due to instance restart)"
		restore.FailMessage = &errMsg
		restore.Status = restores_core.RestoreStatusFailed
		if err := n.restoreRepository.Save(restore); err != nil {
			n.logger.Error("Failed to save restore after cache miss", "error", err)
		}
		n.logger.Error("Restore failed: cache miss", "restoreId", restoreID)
		return
	}
	// NOTE(review): lookup failures below (restore, backup, database,
	// config, storage) only log and return, leaving the restore in its
	// current status — confirm the scheduler eventually fails such restores.
	restore, err := n.restoreRepository.FindByID(restoreID)
	if err != nil {
		n.logger.Error("Failed to get restore by ID", "restoreId", restoreID, "error", err)
		return
	}
	backup, err := n.backupService.GetBackup(restore.BackupID)
	if err != nil {
		n.logger.Error("Failed to get backup by ID", "backupId", restore.BackupID, "error", err)
		return
	}
	databaseID := backup.DatabaseID
	database, err := n.databaseService.GetDatabaseByID(databaseID)
	if err != nil {
		n.logger.Error("Failed to get database by ID", "databaseId", databaseID, "error", err)
		return
	}
	backupConfig, err := n.backupConfigService.GetBackupConfigByDbId(databaseID)
	if err != nil {
		n.logger.Error("Failed to get backup config by database ID", "error", err)
		return
	}
	if backupConfig.StorageID == nil {
		n.logger.Error("Backup config storage ID is not defined")
		return
	}
	storage, err := n.storageService.GetStorageByID(*backupConfig.StorageID)
	if err != nil {
		n.logger.Error("Failed to get storage by ID", "error", err)
		return
	}
	// Duration is measured from here: it covers target-DB preparation and
	// usecase execution, not the record lookups above.
	start := time.Now().UTC()
	// Create restoring database from cached credentials
	restoringToDB := &databases.Database{
		Type:       database.Type,
		Postgresql: dbCache.PostgresqlDatabase,
		Mysql:      dbCache.MysqlDatabase,
		Mariadb:    dbCache.MariadbDatabase,
		Mongodb:    dbCache.MongodbDatabase,
	}
	if err := restoringToDB.PopulateDbData(n.logger, n.fieldEncryptor); err != nil {
		errMsg := fmt.Sprintf("failed to auto-detect database data: %v", err)
		restore.FailMessage = &errMsg
		restore.Status = restores_core.RestoreStatusFailed
		restore.RestoreDurationMs = time.Since(start).Milliseconds()
		if err := n.restoreRepository.Save(restore); err != nil {
			n.logger.Error("Failed to save restore", "error", err)
		}
		return
	}
	// Extension exclusion only applies to PostgreSQL targets.
	isExcludeExtensions := false
	if dbCache.PostgresqlDatabase != nil {
		isExcludeExtensions = dbCache.PostgresqlDatabase.IsExcludeExtensions
	}
	err = n.restoreBackupUsecase.Execute(
		backupConfig,
		*restore,
		database,
		restoringToDB,
		backup,
		storage,
		isExcludeExtensions,
	)
	if err != nil {
		errMsg := err.Error()
		n.logger.Error("Restore execution failed",
			"restoreId", restore.ID,
			"backupId", backup.ID,
			"databaseId", databaseID,
			"databaseType", database.Type,
			"storageId", storage.ID,
			"storageType", storage.Type,
			"error", err,
			"errorMessage", errMsg,
		)
		restore.FailMessage = &errMsg
		restore.Status = restores_core.RestoreStatusFailed
		restore.RestoreDurationMs = time.Since(start).Milliseconds()
		if err := n.restoreRepository.Save(restore); err != nil {
			n.logger.Error("Failed to save restore", "error", err)
		}
		return
	}
	restore.Status = restores_core.RestoreStatusCompleted
	restore.RestoreDurationMs = time.Since(start).Milliseconds()
	if err := n.restoreRepository.Save(restore); err != nil {
		n.logger.Error("Failed to save restore", "error", err)
		return
	}
	n.logger.Info(
		"Restore completed successfully",
		"restoreId", restore.ID,
		"backupId", backup.ID,
		"durationMs", restore.RestoreDurationMs,
	)
}
// sendHeartbeat refreshes the local liveness marker and publishes the
// heartbeat to the registry; a publish failure is logged and retried on the
// next tick.
func (n *RestorerNode) sendHeartbeat(restoreNode *RestoreNode) {
	n.lastHeartbeat = time.Now().UTC()
	err := n.restoreNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *restoreNode)
	if err != nil {
		n.logger.Error("Failed to send heartbeat", "error", err)
	}
}

View File

@@ -0,0 +1,163 @@
package restoring
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/notifiers"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
cache_utils "databasus-backend/internal/util/cache"
)
// Test_MakeRestore_WhenCacheMissed_RestoreFails verifies that a restore whose
// cached DB credentials are absent (e.g. after an instance restart) is marked
// failed with an explanatory message.
func Test_MakeRestore_WhenCacheMissed_RestoreFails(t *testing.T) {
	cache_utils.ClearAllCache()
	// Arrange: a workspace with a database that has backups enabled.
	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
	backups_config.EnableBackupsForTestDatabase(database.ID, storage)
	defer func() {
		// Teardown in reverse dependency order: backups and restores first,
		// then database, notifier, storage, workspace.
		backupRepo := backups_core.BackupRepository{}
		backupsList, _ := backupRepo.FindByDatabaseID(database.ID)
		for _, backup := range backupsList {
			backupRepo.DeleteByID(backup.ID)
		}
		restoreRepo := restores_core.RestoreRepository{}
		restoresInProgress, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
		for _, restore := range restoresInProgress {
			restoreRepo.DeleteByID(restore.ID)
		}
		restoresFailed, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusFailed)
		for _, restore := range restoresFailed {
			restoreRepo.DeleteByID(restore.ID)
		}
		databases.RemoveTestDatabase(database)
		// NOTE(review): presumably lets async database removal settle before
		// dependent teardown — confirm whether this sleep is still needed.
		time.Sleep(50 * time.Millisecond)
		notifiers.RemoveTestNotifier(notifier)
		storages.RemoveTestStorage(storage.ID)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
		cache_utils.ClearAllCache()
	}()
	backup := backups.CreateTestBackup(database.ID, storage.ID)
	// Create restore but DON'T cache DB credentials
	// Also don't set embedded DB fields to avoid schema issues
	restore := &restores_core.Restore{
		BackupID: backup.ID,
		Status:   restores_core.RestoreStatusInProgress,
	}
	err := restoreRepository.Save(restore)
	assert.NoError(t, err)
	// Create restorer and execute restore (should fail due to cache miss)
	restorerNode := CreateTestRestorerNode()
	restorerNode.MakeRestore(restore.ID)
	// Verify restore failed with appropriate error message
	updatedRestore, err := restoreRepository.FindByID(restore.ID)
	assert.NoError(t, err)
	assert.Equal(t, restores_core.RestoreStatusFailed, updatedRestore.Status)
	assert.NotNil(t, updatedRestore.FailMessage)
	assert.Contains(
		t,
		*updatedRestore.FailMessage,
		"Database credentials expired or missing from cache",
	)
}
// Test_MakeRestore_WhenTaskStarts_CacheDeletedImmediately verifies that the
// cached DB credentials are consumed (GetAndDelete) as soon as MakeRestore
// starts, so credentials cannot be read twice.
func Test_MakeRestore_WhenTaskStarts_CacheDeletedImmediately(t *testing.T) {
	cache_utils.ClearAllCache()
	// Arrange: a workspace with a database that has backups enabled.
	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
	backups_config.EnableBackupsForTestDatabase(database.ID, storage)
	defer func() {
		// Teardown in reverse dependency order; restores in every terminal
		// state are removed because this test may complete or fail the run.
		backupRepo := backups_core.BackupRepository{}
		backupsList, _ := backupRepo.FindByDatabaseID(database.ID)
		for _, backup := range backupsList {
			backupRepo.DeleteByID(backup.ID)
		}
		restoreRepo := restores_core.RestoreRepository{}
		restoresInProgress, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
		for _, restore := range restoresInProgress {
			restoreRepo.DeleteByID(restore.ID)
		}
		restoresFailed, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusFailed)
		for _, restore := range restoresFailed {
			restoreRepo.DeleteByID(restore.ID)
		}
		restoresCompleted, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusCompleted)
		for _, restore := range restoresCompleted {
			restoreRepo.DeleteByID(restore.ID)
		}
		databases.RemoveTestDatabase(database)
		// NOTE(review): presumably lets async database removal settle —
		// confirm whether this sleep is still needed.
		time.Sleep(50 * time.Millisecond)
		notifiers.RemoveTestNotifier(notifier)
		storages.RemoveTestStorage(storage.ID)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
		cache_utils.ClearAllCache()
	}()
	backup := backups.CreateTestBackup(database.ID, storage.ID)
	// Create restore with cached DB credentials
	// Don't set embedded DB fields in the restore model itself
	restore := &restores_core.Restore{
		BackupID: backup.ID,
		Status:   restores_core.RestoreStatusInProgress,
	}
	err := restoreRepository.Save(restore)
	assert.NoError(t, err)
	// Cache DB credentials separately
	dbCache := &RestoreDatabaseCache{
		PostgresqlDatabase: &postgresql.PostgresqlDatabase{
			Host:     "localhost",
			Port:     5432,
			Username: "test",
			Password: "test",
			Database: stringPtr("testdb"),
			Version:  "16",
		},
	}
	restoreDatabaseCache.SetWithExpiration(restore.ID.String(), dbCache, 1*time.Hour)
	// Verify cache exists before restore starts
	cachedDB := restoreDatabaseCache.Get(restore.ID.String())
	assert.NotNil(t, cachedDB, "Cache should exist before restore starts")
	// Start restore (this will call GetAndDelete)
	restorerNode := CreateTestRestorerNode()
	restorerNode.MakeRestore(restore.ID)
	// Verify cache was deleted immediately
	cachedDBAfter := restoreDatabaseCache.Get(restore.ID.String())
	assert.Nil(t, cachedDBAfter, "Cache should be deleted immediately when task starts")
}

View File

@@ -0,0 +1,395 @@
package restoring
import (
"context"
"fmt"
"log/slog"
"time"
"github.com/google/uuid"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_config "databasus-backend/internal/features/backups/config"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
cache_utils "databasus-backend/internal/util/cache"
)
const (
	// schedulerStartupDelay gives other nodes time to register before the
	// scheduler starts (applied only in many-nodes mode).
	schedulerStartupDelay = 1 * time.Minute
	// schedulerTickerInterval is how often dead nodes are checked and their
	// restores failed.
	schedulerTickerInterval = 1 * time.Minute
	// schedulerHealthcheckThreshold bounds how stale lastCheckTime may be
	// before IsSchedulerRunning reports false.
	schedulerHealthcheckThreshold = 5 * time.Minute
)
// RestoresScheduler assigns restores to available restore nodes and reacts
// to their completion messages; it also fails restores orphaned by dead
// nodes or an application restart.
type RestoresScheduler struct {
	restoreRepository    *restores_core.RestoreRepository
	backupService        *backups.BackupService
	storageService       *storages.StorageService
	backupConfigService  *backups_config.BackupConfigService
	restoreNodesRegistry *RestoreNodesRegistry
	// lastCheckTime is refreshed by the scheduler loop; read by
	// IsSchedulerRunning as a liveness signal.
	lastCheckTime time.Time
	logger        *slog.Logger
	// restoreToNodeRelations tracks restore-to-node assignments.
	// NOTE(review): key semantics (node ID vs restore ID) not visible in
	// this chunk — confirm against the map's writers.
	restoreToNodeRelations   map[uuid.UUID]RestoreToNodeRelation
	restorerNode             *RestorerNode
	cacheUtil                *cache_utils.CacheUtil[RestoreDatabaseCache]
	completionSubscriptionID uuid.UUID
}
// Run starts the scheduler: it optionally waits for other nodes to come up,
// fails restores orphaned by a previous process, subscribes for completion
// messages, and then checks for dead nodes every schedulerTickerInterval
// until ctx is cancelled.
//
// Panics when startup-critical steps (orphan cleanup, subscription) fail —
// the scheduler cannot operate without them.
func (s *RestoresScheduler) Run(ctx context.Context) {
	s.lastCheckTime = time.Now().UTC()
	if config.GetEnv().IsManyNodesMode {
		// Wait for other nodes to start, but stay responsive to shutdown:
		// a plain time.Sleep would block cancellation for the full delay.
		select {
		case <-ctx.Done():
			return
		case <-time.After(schedulerStartupDelay):
		}
	}
	if err := s.failRestoresInProgress(); err != nil {
		s.logger.Error("Failed to fail restores in progress", "error", err)
		panic(err)
	}
	err := s.restoreNodesRegistry.SubscribeForRestoresCompletions(s.onRestoreCompleted)
	if err != nil {
		s.logger.Error("Failed to subscribe to restore completions", "error", err)
		panic(err)
	}
	defer func() {
		if err := s.restoreNodesRegistry.UnsubscribeForRestoresCompletions(); err != nil {
			s.logger.Error("Failed to unsubscribe from restore completions", "error", err)
		}
	}()
	if ctx.Err() != nil {
		return
	}
	ticker := time.NewTicker(schedulerTickerInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			if err := s.checkDeadNodesAndFailRestores(); err != nil {
				s.logger.Error("Failed to check dead nodes and fail restores", "error", err)
			}
			s.lastCheckTime = time.Now().UTC()
		}
	}
}
// IsSchedulerRunning reports whether the scheduler loop has ticked within
// schedulerHealthcheckThreshold of now.
func (s *RestoresScheduler) IsSchedulerRunning() bool {
	sinceLastTick := time.Now().UTC().Sub(s.lastCheckTime)
	return sinceLastTick < schedulerHealthcheckThreshold
}
// failRestoresInProgress marks every restore currently in the "in progress"
// state as failed. It is invoked on startup: any restore still in progress at
// that point was orphaned by an application restart.
func (s *RestoresScheduler) failRestoresInProgress() error {
	orphaned, err := s.restoreRepository.FindByStatus(
		restores_core.RestoreStatusInProgress,
	)
	if err != nil {
		return err
	}
	for _, orphan := range orphaned {
		message := "Restore failed due to application restart"
		orphan.FailMessage = &message
		orphan.Status = restores_core.RestoreStatusFailed
		if saveErr := s.restoreRepository.Save(orphan); saveErr != nil {
			return saveErr
		}
	}
	return nil
}
// StartRestore caches the restore's database credentials and assigns the
// restore to the least busy available node.
//
// If dbCache is nil the credentials are loaded from the restore row itself
// (backward compatibility / testing). On success the node's busy counter is
// incremented and the restore is recorded in restoreToNodeRelations so that
// completion events and dead-node checks can find it later.
func (s *RestoresScheduler) StartRestore(restoreID uuid.UUID, dbCache *RestoreDatabaseCache) error {
	// If dbCache not provided, try to fetch from DB (for backward compatibility/testing)
	if dbCache == nil {
		restore, err := s.restoreRepository.FindByID(restoreID)
		if err != nil {
			s.logger.Error("Failed to find restore by ID", "restoreId", restoreID, "error", err)
			return err
		}
		// Create cache DTO from restore (fields may be nil if not in DB)
		dbCache = &RestoreDatabaseCache{
			PostgresqlDatabase: restore.PostgresqlDatabase,
			MysqlDatabase:      restore.MysqlDatabase,
			MariadbDatabase:    restore.MariadbDatabase,
			MongodbDatabase:    restore.MongodbDatabase,
		}
	}
	leastBusyNodeID, err := s.calculateLeastBusyNode()
	if err != nil {
		s.logger.Error("Failed to calculate least busy node", "restoreId", restoreID, "error", err)
		return err
	}
	// Cache database credentials (1-hour expiration) only after a node is
	// known to be available. Caching before node selection left credentials
	// sitting in the cache when no node could be chosen. The node must read
	// this entry when it picks up the task, so it is set before assignment.
	s.cacheUtil.SetWithExpiration(restoreID.String(), dbCache, 1*time.Hour)
	if err := s.restoreNodesRegistry.IncrementRestoresInProgress(*leastBusyNodeID); err != nil {
		s.logger.Error(
			"Failed to increment restores in progress",
			"nodeId", leastBusyNodeID,
			"restoreId", restoreID,
			"error", err,
		)
		return err
	}
	if err := s.restoreNodesRegistry.AssignRestoreToNode(*leastBusyNodeID, restoreID, false); err != nil {
		s.logger.Error(
			"Failed to submit restore",
			"nodeId", leastBusyNodeID,
			"restoreId", restoreID,
			"error", err,
		)
		// Roll back the counter bump so the node does not look busier than it is.
		if decrementErr := s.restoreNodesRegistry.DecrementRestoresInProgress(*leastBusyNodeID); decrementErr != nil {
			s.logger.Error(
				"Failed to decrement restores in progress after submit failure",
				"nodeId", leastBusyNodeID,
				"error", decrementErr,
			)
		}
		return err
	}
	// Record the assignment so completions and dead-node checks can find it.
	if relation, exists := s.restoreToNodeRelations[*leastBusyNodeID]; exists {
		relation.RestoreIDs = append(relation.RestoreIDs, restoreID)
		s.restoreToNodeRelations[*leastBusyNodeID] = relation
	} else {
		s.restoreToNodeRelations[*leastBusyNodeID] = RestoreToNodeRelation{
			NodeID:     *leastBusyNodeID,
			RestoreIDs: []uuid.UUID{restoreID},
		}
	}
	s.logger.Info("Successfully triggered restore", "restoreId", restoreID, "nodeId", leastBusyNodeID)
	return nil
}
// calculateLeastBusyNode picks the node with the lowest load score. The score
// is the node's active restore count normalized by its declared throughput;
// nodes reporting zero throughput are heavily penalized (x1000) so they are
// chosen only when nothing better exists. Ties keep the earlier node.
func (s *RestoresScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
	nodes, err := s.restoreNodesRegistry.GetAvailableNodes()
	if err != nil {
		return nil, fmt.Errorf("failed to get available nodes: %w", err)
	}
	if len(nodes) == 0 {
		return nil, fmt.Errorf("no nodes available")
	}
	stats, err := s.restoreNodesRegistry.GetRestoreNodesStats()
	if err != nil {
		return nil, fmt.Errorf("failed to get restore nodes stats: %w", err)
	}
	activeByNode := make(map[uuid.UUID]int, len(stats))
	for _, stat := range stats {
		activeByNode[stat.ID] = stat.ActiveRestores
	}
	scoreOf := func(node *RestoreNode) float64 {
		active := activeByNode[node.ID]
		if node.ThroughputMBs > 0 {
			return float64(active) / float64(node.ThroughputMBs)
		}
		return float64(active) * 1000
	}
	best := &nodes[0]
	bestScore := scoreOf(best)
	for i := 1; i < len(nodes); i++ {
		candidate := &nodes[i]
		if score := scoreOf(candidate); score < bestScore {
			best = candidate
			bestScore = score
		}
	}
	return &best.ID, nil
}
// onRestoreCompleted handles a completion event from a restorer node: it
// removes the restore from the node's tracked relation and decrements the
// node's busy counter.
//
// The registry delivers completions for multiple task types, so the ID is
// first checked against the restore repository and ignored when it does not
// belong to a restore.
//
// NOTE(review): restoreToNodeRelations is mutated here without a lock; if
// this callback can run on a goroutine other than the scheduler loop's, the
// map needs synchronization — confirm the delivery model.
func (s *RestoresScheduler) onRestoreCompleted(nodeID uuid.UUID, restoreID uuid.UUID) {
	// Verify this task is actually a restore (registry contains multiple task types).
	if _, err := s.restoreRepository.FindByID(restoreID); err != nil {
		return
	}
	relation, tracked := s.restoreToNodeRelations[nodeID]
	if !tracked {
		s.logger.Warn("Received completion for unknown node", "nodeId", nodeID, "restoreId", restoreID)
		return
	}
	removed := false
	remaining := make([]uuid.UUID, 0, len(relation.RestoreIDs))
	for _, id := range relation.RestoreIDs {
		if id == restoreID {
			removed = true
		} else {
			remaining = append(remaining, id)
		}
	}
	if !removed {
		s.logger.Warn("Restore not found in node's restore list", "nodeId", nodeID, "restoreId", restoreID)
		return
	}
	if len(remaining) > 0 {
		relation.RestoreIDs = remaining
		s.restoreToNodeRelations[nodeID] = relation
	} else {
		delete(s.restoreToNodeRelations, nodeID)
	}
	if err := s.restoreNodesRegistry.DecrementRestoresInProgress(nodeID); err != nil {
		s.logger.Error(
			"Failed to decrement restores in progress",
			"nodeId", nodeID,
			"restoreId", restoreID,
			"error", err,
		)
	}
}
// checkDeadNodesAndFailRestores fails every restore assigned to a node that
// is no longer present in the registry, then drops the node's relation entry.
func (s *RestoresScheduler) checkDeadNodesAndFailRestores() error {
	nodes, err := s.restoreNodesRegistry.GetAvailableNodes()
	if err != nil {
		return fmt.Errorf("failed to get available nodes: %w", err)
	}
	alive := make(map[uuid.UUID]bool, len(nodes))
	for _, node := range nodes {
		alive[node.ID] = true
	}
	for nodeID, relation := range s.restoreToNodeRelations {
		if alive[nodeID] {
			continue
		}
		s.logger.Warn(
			"Node is dead, failing its restores",
			"nodeId", nodeID,
			"restoreCount", len(relation.RestoreIDs),
		)
		for _, restoreID := range relation.RestoreIDs {
			s.failRestoreOnDeadNode(nodeID, restoreID)
		}
		delete(s.restoreToNodeRelations, nodeID)
	}
	return nil
}

// failRestoreOnDeadNode marks one restore as failed because its node went
// away and decrements that node's busy counter. Best-effort: every error is
// logged and the next restore is processed regardless.
func (s *RestoresScheduler) failRestoreOnDeadNode(nodeID uuid.UUID, restoreID uuid.UUID) {
	restore, err := s.restoreRepository.FindByID(restoreID)
	if err != nil {
		s.logger.Error(
			"Failed to find restore for dead node",
			"nodeId", nodeID,
			"restoreId", restoreID,
			"error", err,
		)
		return
	}
	failMessage := "Restore failed due to node unavailability"
	restore.FailMessage = &failMessage
	restore.Status = restores_core.RestoreStatusFailed
	if err := s.restoreRepository.Save(restore); err != nil {
		s.logger.Error(
			"Failed to save failed restore for dead node",
			"nodeId", nodeID,
			"restoreId", restoreID,
			"error", err,
		)
		return
	}
	if err := s.restoreNodesRegistry.DecrementRestoresInProgress(nodeID); err != nil {
		s.logger.Error(
			"Failed to decrement restores in progress for dead node",
			"nodeId", nodeID,
			"restoreId", restoreID,
			"error", err,
		)
	}
	s.logger.Info("Failed restore due to dead node", "nodeId", nodeID, "restoreId", restoreID)
}

View File

@@ -0,0 +1,852 @@
package restoring
import (
"testing"
"time"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/notifiers"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/encryption"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
// Test_CheckDeadNodesAndFailRestores_NodeDies_FailsRestoreAndCleansUpRegistry
// assigns a restore to a mock node, ages the node's heartbeat past the
// registry's liveness threshold, and verifies that the dead-node check fails
// the restore and decrements the node's active-restore counter.
func Test_CheckDeadNodesAndFailRestores_NodeDies_FailsRestoreAndCleansUpRegistry(t *testing.T) {
	cache_utils.ClearAllCache()
	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
	var mockNodeID uuid.UUID
	defer func() {
		// Best-effort cleanup of everything this test created (errors ignored).
		backupRepo := backups_core.BackupRepository{}
		backups, _ := backupRepo.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepo.DeleteByID(backup.ID)
		}
		restoreRepo := restores_core.RestoreRepository{}
		restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
		for _, restore := range restores {
			restoreRepo.DeleteByID(restore.ID)
		}
		restores, _ = restoreRepo.FindByStatus(restores_core.RestoreStatusFailed)
		for _, restore := range restores {
			restoreRepo.DeleteByID(restore.ID)
		}
		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		storages.RemoveTestStorage(storage.ID)
		notifiers.RemoveTestNotifier(notifier)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
		// Clean up mock node
		if mockNodeID != uuid.Nil {
			restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: mockNodeID})
		}
		cache_utils.ClearAllCache()
	}()
	backups_config.EnableBackupsForTestDatabase(database.ID, storage)
	// Create a test backup
	backup := backups.CreateTestBackup(database.ID, storage.ID)
	var err error
	// Register mock node without subscribing to restores (simulates node crash after registration)
	mockNodeID = uuid.New()
	err = CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
	assert.NoError(t, err)
	// Create restore and assign to mock node
	restore := &restores_core.Restore{
		BackupID: backup.ID,
		Status:   restores_core.RestoreStatusInProgress,
	}
	err = restoreRepository.Save(restore)
	assert.NoError(t, err)
	// Scheduler assigns restore to mock node
	err = GetRestoresScheduler().StartRestore(restore.ID, nil)
	assert.NoError(t, err)
	time.Sleep(100 * time.Millisecond)
	// Verify Valkey counter was incremented when restore was assigned
	stats, err := restoreNodesRegistry.GetRestoreNodesStats()
	assert.NoError(t, err)
	foundStat := false
	for _, stat := range stats {
		if stat.ID == mockNodeID {
			assert.Equal(t, 1, stat.ActiveRestores)
			foundStat = true
			break
		}
	}
	assert.True(t, foundStat, "Node stats should be present")
	// Simulate node death by setting heartbeat older than 2-minute threshold
	oldHeartbeat := time.Now().UTC().Add(-3 * time.Minute)
	err = UpdateNodeHeartbeatDirectly(mockNodeID, 100, oldHeartbeat)
	assert.NoError(t, err)
	// Trigger dead node detection
	err = GetRestoresScheduler().checkDeadNodesAndFailRestores()
	assert.NoError(t, err)
	// Verify restore was failed with appropriate error message
	failedRestore, err := restoreRepository.FindByID(restore.ID)
	assert.NoError(t, err)
	assert.Equal(t, restores_core.RestoreStatusFailed, failedRestore.Status)
	assert.NotNil(t, failedRestore.FailMessage)
	assert.Contains(t, *failedRestore.FailMessage, "node unavailability")
	// Verify Valkey counter was decremented after restore failed
	stats, err = restoreNodesRegistry.GetRestoreNodesStats()
	assert.NoError(t, err)
	for _, stat := range stats {
		if stat.ID == mockNodeID {
			assert.Equal(t, 0, stat.ActiveRestores)
		}
	}
	time.Sleep(200 * time.Millisecond)
}
// Test_OnRestoreCompleted_TaskIsNotRestore_SkipsProcessing verifies that a
// completion event whose ID is not a restore (the registry carries multiple
// task types) leaves the node's counters, the restore row, and the
// scheduler's restore-to-node relations untouched.
func Test_OnRestoreCompleted_TaskIsNotRestore_SkipsProcessing(t *testing.T) {
	cache_utils.ClearAllCache()
	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
	var mockNodeID uuid.UUID
	defer func() {
		// Best-effort cleanup of everything this test created (errors ignored).
		backupRepo := backups_core.BackupRepository{}
		backups, _ := backupRepo.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepo.DeleteByID(backup.ID)
		}
		restoreRepo := restores_core.RestoreRepository{}
		restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
		for _, restore := range restores {
			restoreRepo.DeleteByID(restore.ID)
		}
		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		storages.RemoveTestStorage(storage.ID)
		notifiers.RemoveTestNotifier(notifier)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
		// Clean up mock node
		if mockNodeID != uuid.Nil {
			restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: mockNodeID})
		}
		cache_utils.ClearAllCache()
	}()
	backups_config.EnableBackupsForTestDatabase(database.ID, storage)
	// Create a test backup
	backup := backups.CreateTestBackup(database.ID, storage.ID)
	// Register mock node
	mockNodeID = uuid.New()
	err := CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
	assert.NoError(t, err)
	// Create restore and assign to the node
	restore := &restores_core.Restore{
		BackupID: backup.ID,
		Status:   restores_core.RestoreStatusInProgress,
	}
	err = restoreRepository.Save(restore)
	assert.NoError(t, err)
	err = GetRestoresScheduler().StartRestore(restore.ID, nil)
	assert.NoError(t, err)
	time.Sleep(100 * time.Millisecond)
	// Get initial state of the registry
	initialStats, err := restoreNodesRegistry.GetRestoreNodesStats()
	assert.NoError(t, err)
	var initialActiveTasks int
	for _, stat := range initialStats {
		if stat.ID == mockNodeID {
			initialActiveTasks = stat.ActiveRestores
			break
		}
	}
	assert.Equal(t, 1, initialActiveTasks, "Should have 1 active task")
	// Call onRestoreCompleted with a random UUID (not a restore ID)
	nonRestoreTaskID := uuid.New()
	GetRestoresScheduler().onRestoreCompleted(mockNodeID, nonRestoreTaskID)
	time.Sleep(100 * time.Millisecond)
	// Verify: Active tasks counter should remain the same (not decremented)
	stats, err := restoreNodesRegistry.GetRestoreNodesStats()
	assert.NoError(t, err)
	for _, stat := range stats {
		if stat.ID == mockNodeID {
			assert.Equal(t, initialActiveTasks, stat.ActiveRestores,
				"Active tasks should not change for non-restore task")
		}
	}
	// Verify: restore should still be in progress (not modified)
	unchangedRestore, err := restoreRepository.FindByID(restore.ID)
	assert.NoError(t, err)
	assert.Equal(t, restores_core.RestoreStatusInProgress, unchangedRestore.Status,
		"Restore status should not change for non-restore task completion")
	// Verify: restoreToNodeRelations should still contain the node
	scheduler := GetRestoresScheduler()
	_, exists := scheduler.restoreToNodeRelations[mockNodeID]
	assert.True(t, exists, "Node should still be in restoreToNodeRelations")
	time.Sleep(200 * time.Millisecond)
}
// Test_CalculateLeastBusyNode_SelectsNodeWithBestScore exercises the node
// scoring: with equal throughput the node with the fewest active restores
// wins; with different throughput the score (active/throughput) decides.
func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
	t.Run("Nodes with same throughput", func(t *testing.T) {
		cache_utils.ClearAllCache()
		node1ID := uuid.New()
		node2ID := uuid.New()
		node3ID := uuid.New()
		now := time.Now().UTC()
		defer func() {
			// Clean up all mock nodes
			restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node1ID})
			restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node2ID})
			restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node3ID})
			cache_utils.ClearAllCache()
		}()
		err := CreateMockNodeInRegistry(node1ID, 100, now)
		assert.NoError(t, err)
		err = CreateMockNodeInRegistry(node2ID, 100, now)
		assert.NoError(t, err)
		err = CreateMockNodeInRegistry(node3ID, 100, now)
		assert.NoError(t, err)
		// Load: node1 = 5 active, node2 = 2 active, node3 = 8 active.
		for range 5 {
			err = restoreNodesRegistry.IncrementRestoresInProgress(node1ID)
			assert.NoError(t, err)
		}
		for range 2 {
			err = restoreNodesRegistry.IncrementRestoresInProgress(node2ID)
			assert.NoError(t, err)
		}
		for range 8 {
			err = restoreNodesRegistry.IncrementRestoresInProgress(node3ID)
			assert.NoError(t, err)
		}
		leastBusyNodeID, err := GetRestoresScheduler().calculateLeastBusyNode()
		assert.NoError(t, err)
		assert.NotNil(t, leastBusyNodeID)
		assert.Equal(t, node2ID, *leastBusyNodeID)
	})
	t.Run("Nodes with different throughput", func(t *testing.T) {
		cache_utils.ClearAllCache()
		node100MBsID := uuid.New()
		node50MBsID := uuid.New()
		now := time.Now().UTC()
		defer func() {
			// Clean up all mock nodes
			restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node100MBsID})
			restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: node50MBsID})
			cache_utils.ClearAllCache()
		}()
		err := CreateMockNodeInRegistry(node100MBsID, 100, now)
		assert.NoError(t, err)
		err = CreateMockNodeInRegistry(node50MBsID, 50, now)
		assert.NoError(t, err)
		// Scores: 10/100 = 0.1 for the fast node vs 1/50 = 0.02 for the slow
		// one, so the slow-but-idle node should win.
		for range 10 {
			err = restoreNodesRegistry.IncrementRestoresInProgress(node100MBsID)
			assert.NoError(t, err)
		}
		err = restoreNodesRegistry.IncrementRestoresInProgress(node50MBsID)
		assert.NoError(t, err)
		leastBusyNodeID, err := GetRestoresScheduler().calculateLeastBusyNode()
		assert.NoError(t, err)
		assert.NotNil(t, leastBusyNodeID)
		assert.Equal(t, node50MBsID, *leastBusyNodeID)
	})
}
// Test_FailRestoresInProgress_SchedulerStarts_UpdatesStatus verifies that on
// scheduler startup every in-progress restore is marked failed with the
// "application restart" message, while completed restores are left untouched.
func Test_FailRestoresInProgress_SchedulerStarts_UpdatesStatus(t *testing.T) {
	cache_utils.ClearAllCache()
	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
	defer func() {
		// Best-effort cleanup of everything this test created (errors ignored).
		backupRepo := backups_core.BackupRepository{}
		backups, _ := backupRepo.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepo.DeleteByID(backup.ID)
		}
		restoreRepo := restores_core.RestoreRepository{}
		restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
		for _, restore := range restores {
			restoreRepo.DeleteByID(restore.ID)
		}
		restores, _ = restoreRepo.FindByStatus(restores_core.RestoreStatusFailed)
		for _, restore := range restores {
			restoreRepo.DeleteByID(restore.ID)
		}
		restores, _ = restoreRepo.FindByStatus(restores_core.RestoreStatusCompleted)
		for _, restore := range restores {
			restoreRepo.DeleteByID(restore.ID)
		}
		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		storages.RemoveTestStorage(storage.ID)
		notifiers.RemoveTestNotifier(notifier)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
		cache_utils.ClearAllCache()
	}()
	backups_config.EnableBackupsForTestDatabase(database.ID, storage)
	// Create a test backup
	backup := backups.CreateTestBackup(database.ID, storage.ID)
	// Create two in-progress restores that should be failed on scheduler restart
	restore1 := &restores_core.Restore{
		BackupID:  backup.ID,
		Status:    restores_core.RestoreStatusInProgress,
		CreatedAt: time.Now().UTC().Add(-30 * time.Minute),
	}
	err := restoreRepository.Save(restore1)
	assert.NoError(t, err)
	restore2 := &restores_core.Restore{
		BackupID:  backup.ID,
		Status:    restores_core.RestoreStatusInProgress,
		CreatedAt: time.Now().UTC().Add(-15 * time.Minute),
	}
	err = restoreRepository.Save(restore2)
	assert.NoError(t, err)
	// Create a completed restore to verify it's not affected by failRestoresInProgress
	completedRestore := &restores_core.Restore{
		BackupID:  backup.ID,
		Status:    restores_core.RestoreStatusCompleted,
		CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
	}
	err = restoreRepository.Save(completedRestore)
	assert.NoError(t, err)
	// Trigger the scheduler's failRestoresInProgress logic
	// This should mark in-progress restores as failed
	err = GetRestoresScheduler().failRestoresInProgress()
	assert.NoError(t, err)
	// Verify all restores exist and were processed correctly
	allRestores1, err := restoreRepository.FindByID(restore1.ID)
	assert.NoError(t, err)
	allRestores2, err := restoreRepository.FindByID(restore2.ID)
	assert.NoError(t, err)
	allRestores3, err := restoreRepository.FindByID(completedRestore.ID)
	assert.NoError(t, err)
	var failedCount int
	var completedCount int
	restoresToCheck := []*restores_core.Restore{allRestores1, allRestores2, allRestores3}
	for _, restore := range restoresToCheck {
		switch restore.Status {
		case restores_core.RestoreStatusFailed:
			failedCount++
			// Verify fail message indicates application restart
			assert.NotNil(t, restore.FailMessage)
			assert.Equal(t, "Restore failed due to application restart", *restore.FailMessage)
		case restores_core.RestoreStatusCompleted:
			completedCount++
		}
	}
	// Verify correct number of restores in each state
	assert.Equal(t, 2, failedCount, "Should have 2 failed restores (originally in progress)")
	assert.Equal(t, 1, completedCount, "Should have 1 completed restore (unchanged)")
	time.Sleep(200 * time.Millisecond)
}
// Test_StartRestore_RestoreCompletes_DecrementsActiveTaskCount runs a real
// restorer node with a mock always-succeeding usecase and verifies that after
// the restore completes the node's active-task counter returns to its
// initial value.
func Test_StartRestore_RestoreCompletes_DecrementsActiveTaskCount(t *testing.T) {
	cache_utils.ClearAllCache()
	// Start scheduler so it can handle task completions
	schedulerCancel := StartSchedulerForTest(t)
	defer schedulerCancel()
	restorerNode := CreateTestRestorerNode()
	restorerNode.restoreBackupUsecase = &MockSuccessRestoreUsecase{}
	cancel := StartRestorerNodeForTest(t, restorerNode)
	defer StopRestorerNodeForTest(t, cancel, restorerNode)
	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
	defer func() {
		// Best-effort cleanup of everything this test created (errors ignored).
		backupRepo := backups_core.BackupRepository{}
		backups, _ := backupRepo.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepo.DeleteByID(backup.ID)
		}
		restoreRepo := restores_core.RestoreRepository{}
		restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusCompleted)
		for _, restore := range restores {
			restoreRepo.DeleteByID(restore.ID)
		}
		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		storages.RemoveTestStorage(storage.ID)
		notifiers.RemoveTestNotifier(notifier)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()
	backups_config.EnableBackupsForTestDatabase(database.ID, storage)
	// Create a test backup
	backup := backups.CreateTestBackup(database.ID, storage.ID)
	// Get initial active task count
	stats, err := restoreNodesRegistry.GetRestoreNodesStats()
	assert.NoError(t, err)
	var initialActiveTasks int
	for _, stat := range stats {
		if stat.ID == restorerNode.nodeID {
			initialActiveTasks = stat.ActiveRestores
			break
		}
	}
	t.Logf("Initial active tasks: %d", initialActiveTasks)
	// Create and start restore
	restore := &restores_core.Restore{
		BackupID: backup.ID,
		Status:   restores_core.RestoreStatusInProgress,
	}
	err = restoreRepository.Save(restore)
	assert.NoError(t, err)
	err = GetRestoresScheduler().StartRestore(restore.ID, nil)
	assert.NoError(t, err)
	// Wait for restore to complete
	WaitForRestoreCompletion(t, restore.ID, 10*time.Second)
	// Verify restore was completed
	completedRestore, err := restoreRepository.FindByID(restore.ID)
	assert.NoError(t, err)
	assert.Equal(t, restores_core.RestoreStatusCompleted, completedRestore.Status)
	// Wait for active task count to decrease
	decreased := WaitForActiveTasksDecrease(
		t,
		restorerNode.nodeID,
		initialActiveTasks+1,
		10*time.Second,
	)
	assert.True(t, decreased, "Active task count should have decreased after restore completion")
	// Verify final active task count equals initial count
	finalStats, err := restoreNodesRegistry.GetRestoreNodesStats()
	assert.NoError(t, err)
	for _, stat := range finalStats {
		if stat.ID == restorerNode.nodeID {
			t.Logf("Final active tasks: %d", stat.ActiveRestores)
			assert.Equal(t, initialActiveTasks, stat.ActiveRestores,
				"Active task count should return to initial value after restore completion")
			break
		}
	}
	time.Sleep(200 * time.Millisecond)
}
// Test_StartRestore_RestoreFails_DecrementsActiveTaskCount mirrors the
// success-path test but with an always-failing mock usecase: even when the
// restore fails, the node's active-task counter must return to its initial
// value.
func Test_StartRestore_RestoreFails_DecrementsActiveTaskCount(t *testing.T) {
	cache_utils.ClearAllCache()
	// Start scheduler so it can handle task completions
	schedulerCancel := StartSchedulerForTest(t)
	defer schedulerCancel()
	restorerNode := CreateTestRestorerNode()
	restorerNode.restoreBackupUsecase = &MockFailedRestoreUsecase{}
	cancel := StartRestorerNodeForTest(t, restorerNode)
	defer StopRestorerNodeForTest(t, cancel, restorerNode)
	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
	defer func() {
		// Best-effort cleanup of everything this test created (errors ignored).
		backupRepo := backups_core.BackupRepository{}
		backups, _ := backupRepo.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepo.DeleteByID(backup.ID)
		}
		restoreRepo := restores_core.RestoreRepository{}
		restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusFailed)
		for _, restore := range restores {
			restoreRepo.DeleteByID(restore.ID)
		}
		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		storages.RemoveTestStorage(storage.ID)
		notifiers.RemoveTestNotifier(notifier)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
	}()
	backups_config.EnableBackupsForTestDatabase(database.ID, storage)
	// Create a test backup
	backup := backups.CreateTestBackup(database.ID, storage.ID)
	// Get initial active task count
	stats, err := restoreNodesRegistry.GetRestoreNodesStats()
	assert.NoError(t, err)
	var initialActiveTasks int
	for _, stat := range stats {
		if stat.ID == restorerNode.nodeID {
			initialActiveTasks = stat.ActiveRestores
			break
		}
	}
	t.Logf("Initial active tasks: %d", initialActiveTasks)
	// Create and start restore
	restore := &restores_core.Restore{
		BackupID: backup.ID,
		Status:   restores_core.RestoreStatusInProgress,
	}
	err = restoreRepository.Save(restore)
	assert.NoError(t, err)
	err = GetRestoresScheduler().StartRestore(restore.ID, nil)
	assert.NoError(t, err)
	// Wait for restore to fail
	WaitForRestoreCompletion(t, restore.ID, 10*time.Second)
	// Verify restore failed
	failedRestore, err := restoreRepository.FindByID(restore.ID)
	assert.NoError(t, err)
	assert.Equal(t, restores_core.RestoreStatusFailed, failedRestore.Status)
	// Wait for active task count to decrease
	decreased := WaitForActiveTasksDecrease(
		t,
		restorerNode.nodeID,
		initialActiveTasks+1,
		10*time.Second,
	)
	assert.True(t, decreased, "Active task count should have decreased after restore failure")
	// Verify final active task count equals initial count
	finalStats, err := restoreNodesRegistry.GetRestoreNodesStats()
	assert.NoError(t, err)
	for _, stat := range finalStats {
		if stat.ID == restorerNode.nodeID {
			t.Logf("Final active tasks: %d", stat.ActiveRestores)
			assert.Equal(t, initialActiveTasks, stat.ActiveRestores,
				"Active task count should return to initial value after restore failure")
			break
		}
	}
	time.Sleep(200 * time.Millisecond)
}
// Test_StartRestore_CredentialsStoredEncryptedInCache verifies that
// StartRestore stores database credentials in the restore cache in their
// already-encrypted form: the cached password must differ from the plaintext
// and match the output of the field encryptor.
func Test_StartRestore_CredentialsStoredEncryptedInCache(t *testing.T) {
	cache_utils.ClearAllCache()
	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
	var mockNodeID uuid.UUID
	defer func() {
		// Best-effort cleanup of everything this test created (errors ignored).
		backupRepo := backups_core.BackupRepository{}
		backups, _ := backupRepo.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepo.DeleteByID(backup.ID)
		}
		restoreRepo := restores_core.RestoreRepository{}
		restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
		for _, restore := range restores {
			restoreRepo.DeleteByID(restore.ID)
		}
		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		storages.RemoveTestStorage(storage.ID)
		notifiers.RemoveTestNotifier(notifier)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
		// Clean up mock node
		if mockNodeID != uuid.Nil {
			restoreNodesRegistry.UnregisterNodeFromRegistry(RestoreNode{ID: mockNodeID})
		}
		cache_utils.ClearAllCache()
	}()
	backups_config.EnableBackupsForTestDatabase(database.ID, storage)
	// Create a test backup
	backup := backups.CreateTestBackup(database.ID, storage.ID)
	// Register mock node so scheduler can assign restore to it
	mockNodeID = uuid.New()
	err := CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
	assert.NoError(t, err)
	// Create restore with plaintext credentials
	plaintextPassword := "test_password_123"
	restore := &restores_core.Restore{
		BackupID: backup.ID,
		Status:   restores_core.RestoreStatusInProgress,
	}
	err = restoreRepository.Save(restore)
	assert.NoError(t, err)
	// Create PostgreSQL database credentials with plaintext password
	postgresDB := &postgresql.PostgresqlDatabase{
		Host:     "localhost",
		Port:     5432,
		Username: "testuser",
		Password: plaintextPassword,
		Database: stringPtr("testdb"),
		Version:  "16",
	}
	// Encrypt password using FieldEncryptor (same as production flow)
	encryptor := encryption.GetFieldEncryptor()
	err = postgresDB.EncryptSensitiveFields(database.ID, encryptor)
	assert.NoError(t, err)
	// Verify password was encrypted (different from plaintext)
	assert.NotEqual(t, plaintextPassword, postgresDB.Password,
		"Password should be encrypted, not plaintext")
	// Create cache with encrypted credentials
	dbCache := &RestoreDatabaseCache{
		PostgresqlDatabase: postgresDB,
	}
	// Call StartRestore to cache credentials (do NOT start restore node)
	err = GetRestoresScheduler().StartRestore(restore.ID, dbCache)
	assert.NoError(t, err)
	// Directly read from cache
	cachedData := restoreDatabaseCache.Get(restore.ID.String())
	assert.NotNil(t, cachedData, "Cache entry should exist")
	assert.NotNil(t, cachedData.PostgresqlDatabase, "PostgreSQL credentials should be cached")
	// Verify password in cache is encrypted (not plaintext)
	assert.NotEqual(t, plaintextPassword, cachedData.PostgresqlDatabase.Password,
		"Cached password should be encrypted, not plaintext")
	assert.Equal(t, postgresDB.Password, cachedData.PostgresqlDatabase.Password,
		"Cached password should match the encrypted version")
	// Verify other fields are present
	assert.Equal(t, "localhost", cachedData.PostgresqlDatabase.Host)
	assert.Equal(t, 5432, cachedData.PostgresqlDatabase.Port)
	assert.Equal(t, "testuser", cachedData.PostgresqlDatabase.Username)
	assert.Equal(t, "testdb", *cachedData.PostgresqlDatabase.Database)
	time.Sleep(200 * time.Millisecond)
}
// Test_StartRestore_CredentialsRemovedAfterRestoreStarts verifies that database
// credentials cached for a restore are removed from the cache as soon as the
// restore execution starts, while the restorer node's usecase still receives
// usable credentials.
func Test_StartRestore_CredentialsRemovedAfterRestoreStarts(t *testing.T) {
	cache_utils.ClearAllCache()
	// Start scheduler so it can handle task assignments
	schedulerCancel := StartSchedulerForTest(t)
	defer schedulerCancel()
	// Create mock restorer node with credential capture usecase; the mock
	// publishes the *databases.Database it received on CalledChan.
	restorerNode := CreateTestRestorerNode()
	calledChan := make(chan *databases.Database, 1)
	restorerNode.restoreBackupUsecase = &MockCaptureCredentialsRestoreUsecase{
		CalledChan:    calledChan,
		ShouldSucceed: true,
	}
	cancel := StartRestorerNodeForTest(t, restorerNode)
	defer StopRestorerNodeForTest(t, cancel, restorerNode)
	// Full fixture chain: user -> workspace -> storage/notifier -> database.
	user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
	router := CreateTestRouter()
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
	storage := storages.CreateTestStorage(workspace.ID)
	notifier := notifiers.CreateTestNotifier(workspace.ID)
	database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
	defer func() {
		backupRepo := backups_core.BackupRepository{}
		backups, _ := backupRepo.FindByDatabaseID(database.ID)
		for _, backup := range backups {
			backupRepo.DeleteByID(backup.ID)
		}
		restoreRepo := restores_core.RestoreRepository{}
		// NOTE(review): this deletes ALL completed restores, not just the one
		// created by this test — confirm tests in this package never run in
		// parallel against shared state.
		restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusCompleted)
		for _, restore := range restores {
			restoreRepo.DeleteByID(restore.ID)
		}
		databases.RemoveTestDatabase(database)
		time.Sleep(50 * time.Millisecond)
		storages.RemoveTestStorage(storage.ID)
		notifiers.RemoveTestNotifier(notifier)
		workspaces_testing.RemoveTestWorkspace(workspace, router)
		cache_utils.ClearAllCache()
	}()
	backups_config.EnableBackupsForTestDatabase(database.ID, storage)
	// Create a test backup
	backup := backups.CreateTestBackup(database.ID, storage.ID)
	// Create restore with credentials
	plaintextPassword := "test_password_456"
	restore := &restores_core.Restore{
		BackupID: backup.ID,
		Status:   restores_core.RestoreStatusInProgress,
	}
	err := restoreRepository.Save(restore)
	assert.NoError(t, err)
	// Create PostgreSQL database credentials
	// Database field is nil to avoid PopulateDbData trying to connect
	postgresDB := &postgresql.PostgresqlDatabase{
		Host:     "localhost",
		Port:     5432,
		Username: "testuser",
		Password: plaintextPassword,
		Database: nil,
		Version:  "16",
	}
	// Encrypt password (same as production flow)
	encryptor := encryption.GetFieldEncryptor()
	err = postgresDB.EncryptSensitiveFields(database.ID, encryptor)
	assert.NoError(t, err)
	encryptedPassword := postgresDB.Password
	// Create cache with encrypted credentials
	dbCache := &RestoreDatabaseCache{
		PostgresqlDatabase: postgresDB,
	}
	// Call StartRestore to cache credentials and trigger restore
	err = GetRestoresScheduler().StartRestore(restore.ID, dbCache)
	assert.NoError(t, err)
	// Wait for mock usecase to be called (with timeout)
	var capturedDB *databases.Database
	select {
	case capturedDB = <-calledChan:
		t.Log("Mock usecase was called, credentials captured")
	case <-time.After(10 * time.Second):
		t.Fatal("Timeout waiting for mock usecase to be called")
	}
	// Verify cache is empty after restore starts (credentials were deleted)
	cacheAfterExecution := restoreDatabaseCache.Get(restore.ID.String())
	assert.Nil(t, cacheAfterExecution, "Cache should be empty after restore execution starts")
	// Verify mock received valid credentials
	assert.NotNil(t, capturedDB, "Captured database should not be nil")
	assert.NotNil(t, capturedDB.Postgresql, "PostgreSQL credentials should be provided to usecase")
	assert.Equal(t, "localhost", capturedDB.Postgresql.Host)
	assert.Equal(t, 5432, capturedDB.Postgresql.Port)
	assert.Equal(t, "testuser", capturedDB.Postgresql.Username)
	assert.NotEmpty(t, capturedDB.Postgresql.Password, "Password should be provided to usecase")
	// Note: Password at this point may still be encrypted because PopulateDbData
	// is called after the mock captures it. The important thing is that credentials
	// were provided to the usecase despite cache being deleted.
	t.Logf("Encrypted password in cache: %s", encryptedPassword)
	t.Logf("Password received by usecase: %s", capturedDB.Postgresql.Password)
	// Wait for restore to complete
	WaitForRestoreCompletion(t, restore.ID, 10*time.Second)
	// Verify restore was completed
	completedRestore, err := restoreRepository.FindByID(restore.ID)
	assert.NoError(t, err)
	assert.Equal(t, restores_core.RestoreStatusCompleted, completedRestore.Status)
	// NOTE(review): presumably lets async pub/sub handlers settle before the
	// next test touches shared cache — confirm whether this sleep is required.
	time.Sleep(200 * time.Millisecond)
}

View File

@@ -0,0 +1,297 @@
package restoring
import (
"context"
"fmt"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/postgresql"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/restores/usecases"
"databasus-backend/internal/features/storages"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
)
// CreateTestRouter builds a test HTTP router wired with the workspace,
// membership, database, and backup-config controllers.
func CreateTestRouter() *gin.Engine {
	return workspaces_testing.CreateTestRouter(
		workspaces_controllers.GetWorkspaceController(),
		workspaces_controllers.GetMembershipController(),
		databases.GetDatabaseController(),
		backups_config.GetBackupConfigController(),
	)
}
// CreateTestRestorerNode constructs a RestorerNode with a fresh random node ID
// and the package's production singleton dependencies (services, registry,
// repository, cache, logger).
//
// NOTE(review): this is a positional struct literal — the value order must
// match the RestorerNode field order exactly; re-check whenever fields change.
func CreateTestRestorerNode() *RestorerNode {
	return &RestorerNode{
		uuid.New(),
		databases.GetDatabaseService(),
		backups.GetBackupService(),
		encryption.GetFieldEncryptor(),
		restoreRepository,
		backups_config.GetBackupConfigService(),
		storages.GetStorageService(),
		restoreNodesRegistry,
		logger.GetLogger(),
		usecases.GetRestoreBackupUsecase(),
		restoreDatabaseCache,
		time.Time{}, // zero last-heartbeat time
	}
}
// WaitForRestoreCompletion polls the restore repository until the restore with
// the given ID reaches a terminal status (completed or failed), or the timeout
// elapses. On timeout it only logs — callers assert on the status themselves.
func WaitForRestoreCompletion(
	t *testing.T,
	restoreID uuid.UUID,
	timeout time.Duration,
) {
	const pollInterval = 50 * time.Millisecond
	deadline := time.Now().UTC().Add(timeout)
	for time.Now().UTC().Before(deadline) {
		restore, err := restoreRepository.FindByID(restoreID)
		if err != nil {
			t.Logf("WaitForRestoreCompletion: error finding restore: %v", err)
			time.Sleep(pollInterval)
			continue
		}
		t.Logf("WaitForRestoreCompletion: restore status: %s", restore.Status)
		isTerminal := restore.Status == restores_core.RestoreStatusCompleted ||
			restore.Status == restores_core.RestoreStatusFailed
		if isTerminal {
			t.Logf(
				"WaitForRestoreCompletion: restore finished with status %s",
				restore.Status,
			)
			return
		}
		time.Sleep(pollInterval)
	}
	t.Logf("WaitForRestoreCompletion: timeout waiting for restore to complete")
}
// StartRestorerNodeForTest starts a RestorerNode in a goroutine for testing.
// The node registers itself in the registry and subscribes to restore assignments.
// It polls the registry until the node appears (up to 5s) instead of sleeping.
// Returns a cancel function that should be deferred to stop the node; the
// returned function also waits up to 2s for the node goroutine to exit.
//
// Fix: on registration timeout the node goroutine is now stopped BEFORE the
// test is failed, so a failed start no longer leaks a running node into
// subsequent tests.
func StartRestorerNodeForTest(t *testing.T, restorerNode *RestorerNode) context.CancelFunc {
	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan struct{})
	go func() {
		restorerNode.Run(ctx)
		close(done)
	}()
	// stop cancels the node and waits (bounded) for its goroutine to finish.
	stop := func() {
		cancel()
		select {
		case <-done:
			t.Log("RestorerNode stopped gracefully")
		case <-time.After(2 * time.Second):
			t.Log("RestorerNode stop timeout")
		}
	}
	// Poll registry for node presence instead of fixed sleep
	deadline := time.Now().UTC().Add(5 * time.Second)
	for time.Now().UTC().Before(deadline) {
		nodes, err := restoreNodesRegistry.GetAvailableNodes()
		if err == nil {
			for _, node := range nodes {
				if node.ID == restorerNode.nodeID {
					t.Logf("RestorerNode registered in registry: %s", restorerNode.nodeID)
					return stop
				}
			}
		}
		time.Sleep(50 * time.Millisecond)
	}
	// Registration never happened: shut the node down before failing so the
	// goroutine does not outlive this test.
	stop()
	t.Fatalf("RestorerNode failed to register in registry within timeout")
	return nil
}
// StartSchedulerForTest starts the RestoresScheduler in a goroutine for testing.
// The scheduler subscribes to task completions and manages restore lifecycle.
// Returns a context cancel function that should be deferred to stop the scheduler.
func StartSchedulerForTest(t *testing.T) context.CancelFunc {
	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan struct{})
	go func() {
		GetRestoresScheduler().Run(ctx)
		close(done)
	}()
	// Give scheduler time to subscribe to completions
	// NOTE(review): fixed sleep — no readiness signal is visible to poll on.
	time.Sleep(100 * time.Millisecond)
	t.Log("RestoresScheduler started")
	stop := func() {
		cancel()
		select {
		case <-done:
			t.Log("RestoresScheduler stopped gracefully")
		case <-time.After(2 * time.Second):
			t.Log("RestoresScheduler stop timeout")
		}
	}
	return stop
}
// StopRestorerNodeForTest stops the RestorerNode by canceling its context.
// It waits up to 2s for the node to unregister from the registry, logging
// either outcome.
func StopRestorerNodeForTest(t *testing.T, cancel context.CancelFunc, restorerNode *RestorerNode) {
	cancel()
	// isStillRegistered reports whether the node is present in the registry.
	isStillRegistered := func() (bool, error) {
		nodes, err := restoreNodesRegistry.GetAvailableNodes()
		if err != nil {
			return false, err
		}
		for _, node := range nodes {
			if node.ID == restorerNode.nodeID {
				return true, nil
			}
		}
		return false, nil
	}
	// Wait for node to unregister from registry
	deadline := time.Now().UTC().Add(2 * time.Second)
	for time.Now().UTC().Before(deadline) {
		registered, err := isStillRegistered()
		if err == nil && !registered {
			t.Logf("RestorerNode unregistered from registry: %s", restorerNode.nodeID)
			return
		}
		time.Sleep(50 * time.Millisecond)
	}
	t.Logf("RestorerNode stop completed for %s", restorerNode.nodeID)
}
// CreateMockNodeInRegistry registers a synthetic restore node in the registry
// by issuing a single heartbeat with the given throughput and timestamp.
func CreateMockNodeInRegistry(nodeID uuid.UUID, throughputMBs int, lastHeartbeat time.Time) error {
	return restoreNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, RestoreNode{
		ID:            nodeID,
		ThroughputMBs: throughputMBs,
		LastHeartbeat: lastHeartbeat,
	})
}
// UpdateNodeHeartbeatDirectly writes a heartbeat for the given node straight
// into the registry, bypassing a running RestorerNode.
//
// Fix: its body was a byte-for-byte duplicate of CreateMockNodeInRegistry;
// it now delegates to that helper so the registry call exists in one place.
func UpdateNodeHeartbeatDirectly(
	nodeID uuid.UUID,
	throughputMBs int,
	lastHeartbeat time.Time,
) error {
	return CreateMockNodeInRegistry(nodeID, throughputMBs, lastHeartbeat)
}
// GetNodeFromRegistry looks up a node by ID among the currently available
// registry nodes. It returns an error when the registry cannot be read or the
// node is absent.
func GetNodeFromRegistry(nodeID uuid.UUID) (*RestoreNode, error) {
	nodes, err := restoreNodesRegistry.GetAvailableNodes()
	if err != nil {
		return nil, err
	}
	for i := range nodes {
		if nodes[i].ID == nodeID {
			// Point into the slice rather than returning the address of a
			// range-loop copy; avoids an extra struct copy per match.
			return &nodes[i], nil
		}
	}
	return nil, fmt.Errorf("node not found")
}
// WaitForActiveTasksDecrease waits for the active task count to decrease below the initial count.
// It polls the registry every 500ms until the count decreases or the timeout is reached.
// Returns true if the count decreased, false if timeout was reached.
func WaitForActiveTasksDecrease(
	t *testing.T,
	nodeID uuid.UUID,
	initialCount int,
	timeout time.Duration,
) bool {
	const pollInterval = 500 * time.Millisecond
	deadline := time.Now().UTC().Add(timeout)
	for time.Now().UTC().Before(deadline) {
		stats, err := restoreNodesRegistry.GetRestoreNodesStats()
		if err != nil {
			t.Logf("WaitForActiveTasksDecrease: error getting node stats: %v", err)
			time.Sleep(pollInterval)
			continue
		}
		for _, stat := range stats {
			if stat.ID != nodeID {
				continue
			}
			t.Logf(
				"WaitForActiveTasksDecrease: current active tasks = %d (initial = %d)",
				stat.ActiveRestores,
				initialCount,
			)
			if stat.ActiveRestores < initialCount {
				t.Logf(
					"WaitForActiveTasksDecrease: active tasks decreased from %d to %d",
					initialCount,
					stat.ActiveRestores,
				)
				return true
			}
			break
		}
		time.Sleep(pollInterval)
	}
	t.Logf("WaitForActiveTasksDecrease: timeout waiting for active tasks to decrease")
	return false
}
// CreateTestRestore persists and returns a restore for the given backup with
// the requested status, using fixed local PostgreSQL credentials. Fails the
// test if the save does not succeed.
func CreateTestRestore(
	t *testing.T,
	backup *backups_core.Backup,
	status restores_core.RestoreStatus,
) *restores_core.Restore {
	pgDatabase := &postgresql.PostgresqlDatabase{
		Host:     "localhost",
		Port:     5432,
		Username: "test",
		Password: "test",
		Database: stringPtr("testdb"),
		Version:  "16",
	}
	restore := &restores_core.Restore{
		BackupID:           backup.ID,
		Status:             status,
		PostgresqlDatabase: pgDatabase,
	}
	if err := restoreRepository.Save(restore); err != nil {
		t.Fatalf("Failed to create test restore: %v", err)
	}
	return restore
}
// stringPtr returns a pointer to a copy of s, for populating optional
// *string fields in test fixtures.
func stringPtr(s string) *string {
	v := s
	return &v
}

View File

@@ -7,8 +7,8 @@ import (
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/disk"
"databasus-backend/internal/features/restores/enums"
"databasus-backend/internal/features/restores/models"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/restores/restoring"
"databasus-backend/internal/features/restores/usecases"
"databasus-backend/internal/features/storages"
users_models "databasus-backend/internal/features/users/models"
@@ -25,7 +25,7 @@ import (
type RestoreService struct {
backupService *backups.BackupService
restoreRepository *RestoreRepository
restoreRepository *restores_core.RestoreRepository
storageService *storages.StorageService
backupConfigService *backups_config.BackupConfigService
restoreBackupUsecase *usecases.RestoreBackupUsecase
@@ -44,7 +44,7 @@ func (s *RestoreService) OnBeforeBackupRemove(backup *backups_core.Backup) error
}
for _, restore := range restores {
if restore.Status == enums.RestoreStatusInProgress {
if restore.Status == restores_core.RestoreStatusInProgress {
return errors.New("restore is in progress, backup cannot be removed")
}
}
@@ -61,7 +61,7 @@ func (s *RestoreService) OnBeforeBackupRemove(backup *backups_core.Backup) error
func (s *RestoreService) GetRestores(
user *users_models.User,
backupID uuid.UUID,
) ([]*models.Restore, error) {
) ([]*restores_core.Restore, error) {
backup, err := s.backupService.GetBackup(backupID)
if err != nil {
return nil, err
@@ -93,7 +93,7 @@ func (s *RestoreService) GetRestores(
func (s *RestoreService) RestoreBackupWithAuth(
user *users_models.User,
backupID uuid.UUID,
requestDTO RestoreBackupRequest,
requestDTO restores_core.RestoreBackupRequest,
) error {
backup, err := s.backupService.GetBackup(backupID)
if err != nil {
@@ -134,11 +134,45 @@ func (s *RestoreService) RestoreBackupWithAuth(
return err
}
go func() {
if err := s.RestoreBackup(backup, requestDTO); err != nil {
s.logger.Error("Failed to restore backup", "error", err)
// Create restore record with the request configuration
restore := restores_core.Restore{
ID: uuid.New(),
Status: restores_core.RestoreStatusInProgress,
BackupID: backup.ID,
Backup: backup,
CreatedAt: time.Now().UTC(),
RestoreDurationMs: 0,
FailMessage: nil,
PostgresqlDatabase: requestDTO.PostgresqlDatabase,
MysqlDatabase: requestDTO.MysqlDatabase,
MariadbDatabase: requestDTO.MariadbDatabase,
MongodbDatabase: requestDTO.MongodbDatabase,
}
if err := s.restoreRepository.Save(&restore); err != nil {
return err
}
// Prepare database cache with credentials from the request
dbCache := &restoring.RestoreDatabaseCache{
PostgresqlDatabase: requestDTO.PostgresqlDatabase,
MysqlDatabase: requestDTO.MysqlDatabase,
MariadbDatabase: requestDTO.MariadbDatabase,
MongodbDatabase: requestDTO.MongodbDatabase,
}
// Trigger restore via scheduler
scheduler := restoring.GetRestoresScheduler()
if err := scheduler.StartRestore(restore.ID, dbCache); err != nil {
// Mark restore as failed if we can't schedule it
failMsg := fmt.Sprintf("Failed to schedule restore: %v", err)
restore.FailMessage = &failMsg
restore.Status = restores_core.RestoreStatusFailed
if saveErr := s.restoreRepository.Save(&restore); saveErr != nil {
s.logger.Error("Failed to save restore after scheduling error", "error", saveErr)
}
}()
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
@@ -153,127 +187,9 @@ func (s *RestoreService) RestoreBackupWithAuth(
return nil
}
func (s *RestoreService) RestoreBackup(
backup *backups_core.Backup,
requestDTO RestoreBackupRequest,
) error {
if backup.Status != backups_core.BackupStatusCompleted {
return errors.New("backup is not completed")
}
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return err
}
switch database.Type {
case databases.DatabaseTypePostgres:
if requestDTO.PostgresqlDatabase == nil {
return errors.New("postgresql database is required")
}
case databases.DatabaseTypeMysql:
if requestDTO.MysqlDatabase == nil {
return errors.New("mysql database is required")
}
case databases.DatabaseTypeMariadb:
if requestDTO.MariadbDatabase == nil {
return errors.New("mariadb database is required")
}
case databases.DatabaseTypeMongodb:
if requestDTO.MongodbDatabase == nil {
return errors.New("mongodb database is required")
}
}
restore := models.Restore{
ID: uuid.New(),
Status: enums.RestoreStatusInProgress,
BackupID: backup.ID,
Backup: backup,
CreatedAt: time.Now().UTC(),
RestoreDurationMs: 0,
FailMessage: nil,
}
// Save the restore first
if err := s.restoreRepository.Save(&restore); err != nil {
return err
}
// Save the restore again to include the postgresql database
if err := s.restoreRepository.Save(&restore); err != nil {
return err
}
storage, err := s.storageService.GetStorageByID(backup.StorageID)
if err != nil {
return err
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(
database.ID,
)
if err != nil {
return err
}
start := time.Now().UTC()
restoringToDB := &databases.Database{
Type: database.Type,
Postgresql: requestDTO.PostgresqlDatabase,
Mysql: requestDTO.MysqlDatabase,
Mariadb: requestDTO.MariadbDatabase,
Mongodb: requestDTO.MongodbDatabase,
}
if err := restoringToDB.PopulateDbData(s.logger, s.fieldEncryptor); err != nil {
return fmt.Errorf("failed to auto-detect database data: %w", err)
}
isExcludeExtensions := false
if requestDTO.PostgresqlDatabase != nil {
isExcludeExtensions = requestDTO.PostgresqlDatabase.IsExcludeExtensions
}
err = s.restoreBackupUsecase.Execute(
backupConfig,
restore,
database,
restoringToDB,
backup,
storage,
isExcludeExtensions,
)
if err != nil {
errMsg := err.Error()
restore.FailMessage = &errMsg
restore.Status = enums.RestoreStatusFailed
restore.RestoreDurationMs = time.Since(start).Milliseconds()
if err := s.restoreRepository.Save(&restore); err != nil {
return err
}
return err
}
restore.Status = enums.RestoreStatusCompleted
restore.RestoreDurationMs = time.Since(start).Milliseconds()
if err := s.restoreRepository.Save(&restore); err != nil {
return err
}
return nil
}
func (s *RestoreService) validateVersionCompatibility(
backupDatabase *databases.Database,
requestDTO RestoreBackupRequest,
requestDTO restores_core.RestoreBackupRequest,
) error {
// populate version
if requestDTO.MariadbDatabase != nil {
@@ -372,7 +288,7 @@ func (s *RestoreService) validateVersionCompatibility(
func (s *RestoreService) validateDiskSpace(
backup *backups_core.Backup,
requestDTO RestoreBackupRequest,
requestDTO restores_core.RestoreBackupRequest,
) error {
// Only validate disk space for PostgreSQL when file-based restore is needed:
// - CPU > 1 (parallel jobs require file)

View File

@@ -0,0 +1,51 @@
package restores
import (
"context"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"databasus-backend/internal/features/backups/backups"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/restores/restoring"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
)
// CreateTestRouter builds a test HTTP router wired with workspace, database,
// backup-config, backup, and restore controllers, and additionally mounts the
// backup controller's public routes under /api/v1.
func CreateTestRouter() *gin.Engine {
	backupController := backups.GetBackupController()
	router := workspaces_testing.CreateTestRouter(
		workspaces_controllers.GetWorkspaceController(),
		workspaces_controllers.GetMembershipController(),
		databases.GetDatabaseController(),
		backups_config.GetBackupConfigController(),
		backupController,
		GetRestoreController(),
	)
	backupController.RegisterPublicRoutes(router.Group("/api/v1"))
	return router
}
// SetupMockRestoreNode registers a mock restore node (100 MB/s throughput,
// current heartbeat) in the registry and returns its ID plus a cleanup
// function. The cleanup is deliberately a no-op: the node simply expires from
// the registry once its heartbeat goes stale.
func SetupMockRestoreNode(t *testing.T) (uuid.UUID, context.CancelFunc) {
	nodeID := uuid.New()
	if err := restoring.CreateMockNodeInRegistry(nodeID, 100, time.Now().UTC()); err != nil {
		t.Fatalf("Failed to create mock node: %v", err)
	}
	// Node will expire naturally from registry
	noop := func() {}
	return nodeID, noop
}

View File

@@ -24,7 +24,7 @@ import (
"databasus-backend/internal/features/databases"
mariadbtypes "databasus-backend/internal/features/databases/databases/mariadb"
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
"databasus-backend/internal/features/restores/models"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
util_encryption "databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/tools"
@@ -39,7 +39,7 @@ func (uc *RestoreMariadbBackupUsecase) Execute(
originalDB *databases.Database,
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
restore restores_core.Restore,
backup *backups_core.Backup,
storage *storages.Storage,
) error {

View File

@@ -20,7 +20,7 @@ import (
"databasus-backend/internal/features/databases"
mongodbtypes "databasus-backend/internal/features/databases/databases/mongodb"
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
"databasus-backend/internal/features/restores/models"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
util_encryption "databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/tools"
@@ -39,7 +39,7 @@ func (uc *RestoreMongodbBackupUsecase) Execute(
originalDB *databases.Database,
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
restore restores_core.Restore,
backup *backups_core.Backup,
storage *storages.Storage,
) error {

View File

@@ -24,7 +24,7 @@ import (
"databasus-backend/internal/features/databases"
mysqltypes "databasus-backend/internal/features/databases/databases/mysql"
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
"databasus-backend/internal/features/restores/models"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
util_encryption "databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/tools"
@@ -39,7 +39,7 @@ func (uc *RestoreMysqlBackupUsecase) Execute(
originalDB *databases.Database,
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
restore restores_core.Restore,
backup *backups_core.Backup,
storage *storages.Storage,
) error {

View File

@@ -21,7 +21,7 @@ import (
"databasus-backend/internal/features/databases"
pgtypes "databasus-backend/internal/features/databases/databases/postgresql"
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
"databasus-backend/internal/features/restores/models"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
util_encryption "databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/tools"
@@ -38,7 +38,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
originalDB *databases.Database,
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
restore restores_core.Restore,
backup *backups_core.Backup,
storage *storages.Storage,
isExcludeExtensions bool,

View File

@@ -6,7 +6,7 @@ import (
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/restores/models"
restores_core "databasus-backend/internal/features/restores/core"
usecases_mariadb "databasus-backend/internal/features/restores/usecases/mariadb"
usecases_mongodb "databasus-backend/internal/features/restores/usecases/mongodb"
usecases_mysql "databasus-backend/internal/features/restores/usecases/mysql"
@@ -23,7 +23,7 @@ type RestoreBackupUsecase struct {
func (uc *RestoreBackupUsecase) Execute(
backupConfig *backups_config.BackupConfig,
restore models.Restore,
restore restores_core.Restore,
originalDB *databases.Database,
restoringToDB *databases.Database,
backup *backups_core.Backup,

View File

@@ -37,7 +37,7 @@ func (s *HealthcheckService) IsHealthy() error {
}
}
if config.GetEnv().IsBackupNode {
if config.GetEnv().IsProcessingNode {
if !s.backuperNode.IsBackuperRunning() {
return errors.New("backuper node is not running for more than 5 minutes")
}

View File

@@ -1,18 +0,0 @@
package task_registry
import (
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/logger"
)
var taskNodesRegistry = &TaskNodesRegistry{
cache_utils.GetValkeyClient(),
logger.GetLogger(),
cache_utils.DefaultCacheTimeout,
cache_utils.NewPubSubManager(),
cache_utils.NewPubSubManager(),
}
func GetTaskNodesRegistry() *TaskNodesRegistry {
return taskNodesRegistry
}

View File

@@ -1,29 +0,0 @@
package task_registry
import (
"time"
"github.com/google/uuid"
)
type TaskNode struct {
ID uuid.UUID `json:"id"`
ThroughputMBs int `json:"throughputMBs"`
LastHeartbeat time.Time `json:"lastHeartbeat"`
}
type TaskNodeStats struct {
ID uuid.UUID `json:"id"`
ActiveTasks int `json:"activeTasks"`
}
type TaskSubmitMessage struct {
NodeID string `json:"nodeId"`
TaskID string `json:"taskId"`
IsCallNotifier bool `json:"isCallNotifier"`
}
type TaskCompletionMessage struct {
NodeID string `json:"nodeId"`
TaskID string `json:"taskId"`
}

View File

@@ -21,9 +21,7 @@ import (
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
mariadbtypes "databasus-backend/internal/features/databases/databases/mariadb"
"databasus-backend/internal/features/restores"
restores_enums "databasus-backend/internal/features/restores/enums"
restores_models "databasus-backend/internal/features/restores/models"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
@@ -213,7 +211,7 @@ func testMariadbBackupRestoreForVersion(
)
restore := waitForMariadbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
var tableExists int
err = newDB.Get(
@@ -311,7 +309,7 @@ func testMariadbBackupRestoreWithEncryptionForVersion(
)
restore := waitForMariadbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
var tableExists int
err = newDB.Get(
@@ -418,7 +416,7 @@ func testMariadbBackupRestoreWithReadOnlyUserForVersion(
)
restore := waitForMariadbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
var tableExists int
err = newDB.Get(
@@ -506,7 +504,7 @@ func createMariadbRestoreViaAPI(
version tools.MariadbVersion,
token string,
) {
request := restores.RestoreBackupRequest{
request := restores_core.RestoreBackupRequest{
MariadbDatabase: &mariadbtypes.MariadbDatabase{
Host: host,
Port: port,
@@ -533,7 +531,7 @@ func waitForMariadbRestoreCompletion(
backupID uuid.UUID,
token string,
timeout time.Duration,
) *restores_models.Restore {
) *restores_core.Restore {
startTime := time.Now()
pollInterval := 500 * time.Millisecond
@@ -542,7 +540,7 @@ func waitForMariadbRestoreCompletion(
t.Fatalf("Timeout waiting for MariaDB restore completion after %v", timeout)
}
var restoresList []*restores_models.Restore
var restoresList []*restores_core.Restore
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
@@ -553,10 +551,10 @@ func waitForMariadbRestoreCompletion(
)
for _, restore := range restoresList {
if restore.Status == restores_enums.RestoreStatusCompleted {
if restore.Status == restores_core.RestoreStatusCompleted {
return restore
}
if restore.Status == restores_enums.RestoreStatusFailed {
if restore.Status == restores_core.RestoreStatusFailed {
failMsg := "unknown error"
if restore.FailMessage != nil {
failMsg = *restore.FailMessage

View File

@@ -23,9 +23,7 @@ import (
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
mongodbtypes "databasus-backend/internal/features/databases/databases/mongodb"
"databasus-backend/internal/features/restores"
restores_enums "databasus-backend/internal/features/restores/enums"
restores_models "databasus-backend/internal/features/restores/models"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
@@ -175,7 +173,7 @@ func testMongodbBackupRestoreForVersion(
)
restore := waitForMongodbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
verifyMongodbDataIntegrity(t, container, newDBName)
@@ -254,7 +252,7 @@ func testMongodbBackupRestoreWithEncryptionForVersion(
)
restore := waitForMongodbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
verifyMongodbDataIntegrity(t, container, newDBName)
@@ -342,7 +340,7 @@ func testMongodbBackupRestoreWithReadOnlyUserForVersion(
)
restore := waitForMongodbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
verifyMongodbDataIntegrity(t, container, newDBName)
@@ -431,7 +429,7 @@ func createMongodbRestoreViaAPI(
version tools.MongodbVersion,
token string,
) {
request := restores.RestoreBackupRequest{
request := restores_core.RestoreBackupRequest{
MongodbDatabase: &mongodbtypes.MongodbDatabase{
Host: host,
Port: port,
@@ -461,7 +459,7 @@ func waitForMongodbRestoreCompletion(
backupID uuid.UUID,
token string,
timeout time.Duration,
) *restores_models.Restore {
) *restores_core.Restore {
startTime := time.Now()
pollInterval := 500 * time.Millisecond
@@ -470,7 +468,7 @@ func waitForMongodbRestoreCompletion(
t.Fatalf("Timeout waiting for MongoDB restore completion after %v", timeout)
}
var restoresList []*restores_models.Restore
var restoresList []*restores_core.Restore
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
@@ -481,10 +479,10 @@ func waitForMongodbRestoreCompletion(
)
for _, restore := range restoresList {
if restore.Status == restores_enums.RestoreStatusCompleted {
if restore.Status == restores_core.RestoreStatusCompleted {
return restore
}
if restore.Status == restores_enums.RestoreStatusFailed {
if restore.Status == restores_core.RestoreStatusFailed {
failMsg := "unknown error"
if restore.FailMessage != nil {
failMsg = *restore.FailMessage

View File

@@ -21,9 +21,7 @@ import (
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
mysqltypes "databasus-backend/internal/features/databases/databases/mysql"
"databasus-backend/internal/features/restores"
restores_enums "databasus-backend/internal/features/restores/enums"
restores_models "databasus-backend/internal/features/restores/models"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
@@ -188,7 +186,7 @@ func testMysqlBackupRestoreForVersion(t *testing.T, mysqlVersion tools.MysqlVers
)
restore := waitForMysqlRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
var tableExists int
err = newDB.Get(
@@ -286,7 +284,7 @@ func testMysqlBackupRestoreWithEncryptionForVersion(
)
restore := waitForMysqlRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
var tableExists int
err = newDB.Get(
@@ -393,7 +391,7 @@ func testMysqlBackupRestoreWithReadOnlyUserForVersion(
)
restore := waitForMysqlRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
var tableExists int
err = newDB.Get(
@@ -481,7 +479,7 @@ func createMysqlRestoreViaAPI(
version tools.MysqlVersion,
token string,
) {
request := restores.RestoreBackupRequest{
request := restores_core.RestoreBackupRequest{
MysqlDatabase: &mysqltypes.MysqlDatabase{
Host: host,
Port: port,
@@ -508,7 +506,7 @@ func waitForMysqlRestoreCompletion(
backupID uuid.UUID,
token string,
timeout time.Duration,
) *restores_models.Restore {
) *restores_core.Restore {
startTime := time.Now()
pollInterval := 500 * time.Millisecond
@@ -517,7 +515,7 @@ func waitForMysqlRestoreCompletion(
t.Fatalf("Timeout waiting for MySQL restore completion after %v", timeout)
}
var restoresList []*restores_models.Restore
var restoresList []*restores_core.Restore
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
@@ -528,10 +526,10 @@ func waitForMysqlRestoreCompletion(
)
for _, restore := range restoresList {
if restore.Status == restores_enums.RestoreStatusCompleted {
if restore.Status == restores_core.RestoreStatusCompleted {
return restore
}
if restore.Status == restores_enums.RestoreStatusFailed {
if restore.Status == restores_core.RestoreStatusFailed {
failMsg := "unknown error"
if restore.FailMessage != nil {
failMsg = *restore.FailMessage

View File

@@ -23,8 +23,7 @@ import (
"databasus-backend/internal/features/databases"
pgtypes "databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/restores"
restores_enums "databasus-backend/internal/features/restores/enums"
restores_models "databasus-backend/internal/features/restores/models"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
@@ -212,7 +211,7 @@ func Test_BackupAndRestoreSupabase_PublicSchemaOnly_RestoreIsSuccessful(t *testi
)
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
var countAfterRestore int
err = supabaseDB.Get(
@@ -439,7 +438,7 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string, cp
)
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
var tableExists bool
err = newDB.Get(
@@ -555,7 +554,7 @@ func testSchemaSelectionAllSchemasForVersion(t *testing.T, pgVersion string, por
)
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
var publicTableExists bool
err = newDB.Get(&publicTableExists, `
@@ -689,7 +688,7 @@ func testBackupRestoreWithExcludeExtensionsForVersion(t *testing.T, pgVersion st
)
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
// Verify the table was restored
var tableExists bool
@@ -829,7 +828,7 @@ func testBackupRestoreWithoutExcludeExtensionsForVersion(
)
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
// Verify the extension was recovered
var extensionExists bool
@@ -956,7 +955,7 @@ func testBackupRestoreWithReadOnlyUserForVersion(t *testing.T, pgVersion string,
)
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
var tableExists bool
err = newDB.Get(
@@ -1076,7 +1075,7 @@ func testSchemaSelectionOnlySpecifiedSchemasForVersion(
)
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
var publicTableExists bool
err = newDB.Get(&publicTableExists, `
@@ -1190,7 +1189,7 @@ func testBackupRestoreWithEncryptionForVersion(t *testing.T, pgVersion string, p
)
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
var tableExists bool
err = newDB.Get(
@@ -1286,7 +1285,7 @@ func waitForRestoreCompletion(
backupID uuid.UUID,
token string,
timeout time.Duration,
) *restores_models.Restore {
) *restores_core.Restore {
startTime := time.Now()
pollInterval := 500 * time.Millisecond
@@ -1295,7 +1294,7 @@ func waitForRestoreCompletion(
t.Fatalf("Timeout waiting for restore completion after %v", timeout)
}
var restores []*restores_models.Restore
var restores []*restores_core.Restore
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
@@ -1306,10 +1305,10 @@ func waitForRestoreCompletion(
)
for _, restore := range restores {
if restore.Status == restores_enums.RestoreStatusCompleted {
if restore.Status == restores_core.RestoreStatusCompleted {
return restore
}
if restore.Status == restores_enums.RestoreStatusFailed {
if restore.Status == restores_core.RestoreStatusFailed {
failMsg := "unknown error"
if restore.FailMessage != nil {
failMsg = *restore.FailMessage
@@ -1476,7 +1475,7 @@ func createRestoreWithCpuCountViaAPI(
cpuCount int,
token string,
) {
request := restores.RestoreBackupRequest{
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &pgtypes.PostgresqlDatabase{
Host: host,
Port: port,
@@ -1509,7 +1508,7 @@ func createRestoreWithOptionsViaAPI(
isExcludeExtensions bool,
token string,
) {
request := restores.RestoreBackupRequest{
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &pgtypes.PostgresqlDatabase{
Host: host,
Port: port,
@@ -1647,7 +1646,7 @@ func createSupabaseRestoreViaAPI(
database string,
token string,
) {
request := restores.RestoreBackupRequest{
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &pgtypes.PostgresqlDatabase{
Host: host,
Port: port,

View File

@@ -5,6 +5,7 @@ import (
"testing"
"databasus-backend/internal/features/backups/backups/backuping"
"databasus-backend/internal/features/restores/restoring"
cache_utils "databasus-backend/internal/util/cache"
)
@@ -12,11 +13,15 @@ func TestMain(m *testing.M) {
cache_utils.ClearAllCache()
backuperNode := backuping.CreateTestBackuperNode()
cancel := backuping.StartBackuperNodeForTest(&testing.T{}, backuperNode)
cancelBackup := backuping.StartBackuperNodeForTest(&testing.T{}, backuperNode)
restorerNode := restoring.CreateTestRestorerNode()
cancelRestore := restoring.StartRestorerNodeForTest(&testing.T{}, restorerNode)
exitCode := m.Run()
backuping.StopBackuperNodeForTest(&testing.T{}, cancel, backuperNode)
backuping.StopBackuperNodeForTest(&testing.T{}, cancelBackup, backuperNode)
restoring.StopRestorerNodeForTest(&testing.T{}, cancelRestore, restorerNode)
os.Exit(exitCode)
}

View File

@@ -1,7 +1,9 @@
package cache_utils
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
@@ -49,3 +51,43 @@ func Test_ClearAllCache_AfterClear_CacheIsEmpty(t *testing.T) {
assert.Nil(t, retrieved, "Key %s should be deleted after clearing", tk.prefix+tk.key)
}
}
// Test_SetWithExpiration_SetsCorrectTTL verifies that SetWithExpiration
// applies the caller-provided TTL (instead of the cache's default expiry)
// by reading the key's remaining TTL straight from Valkey.
func Test_SetWithExpiration_SetsCorrectTTL(t *testing.T) {
	client := getCache()

	testPrefix := "test:ttl:"
	cacheUtil := NewCacheUtil[string](client, testPrefix)

	// Store a value with an explicit 1-hour expiration.
	testKey := "key1"
	testValue := "test value"
	cacheUtil.SetWithExpiration(testKey, &testValue, 1*time.Hour)

	// The value must round-trip before we inspect its TTL.
	retrieved := cacheUtil.Get(testKey)
	assert.NotNil(t, retrieved, "Value should be stored")
	assert.Equal(t, testValue, *retrieved, "Retrieved value should match")

	// Query the remaining TTL directly via the Valkey TTL command.
	ctx, cancel := context.WithTimeout(context.Background(), DefaultCacheTimeout)
	defer cancel()

	fullKey := testPrefix + testKey
	ttlResult := client.Do(ctx, client.B().Ttl().Key(fullKey).Build())
	assert.NoError(t, ttlResult.Error(), "TTL command should not error")

	ttlSeconds, err := ttlResult.AsInt64()
	assert.NoError(t, err, "TTL should be retrievable as int64")

	// TTL counts down in whole seconds; allow up to 10s of wall-clock
	// drift between the SET and the TTL query, but never more than 3600.
	expectedTTL := int64(3600)
	assert.GreaterOrEqual(t, ttlSeconds, expectedTTL-10, "TTL should be close to 1 hour")
	assert.LessOrEqual(t, ttlSeconds, expectedTTL, "TTL should not exceed 1 hour")

	// Clean up so the key does not leak into other tests.
	cacheUtil.Invalidate(testKey)
}

View File

@@ -67,6 +67,43 @@ func (c *CacheUtil[T]) Set(key string, item *T) {
c.client.Do(ctx, c.client.B().Set().Key(fullKey).Value(string(data)).Ex(c.expiry).Build())
}
// SetWithExpiration stores item under key with a caller-supplied TTL,
// overriding the cache's default expiry. Like Set, it is best-effort:
// marshal failures and transport errors are silently dropped.
func (c *CacheUtil[T]) SetWithExpiration(key string, item *T, expiry time.Duration) {
	payload, err := json.Marshal(item)
	if err != nil {
		// Unmarshalable items are skipped rather than cached, matching Set.
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
	defer cancel()

	c.client.Do(ctx, c.client.B().Set().Key(c.prefix+key).Value(string(payload)).Ex(expiry).Build())
}
// GetAndDelete atomically fetches and removes the cached value for key
// using Valkey's GETDEL command. It returns nil when the key is absent,
// on transport errors, or when the stored payload cannot be unmarshaled
// into T.
func (c *CacheUtil[T]) GetAndDelete(key string) *T {
	ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
	defer cancel()

	result := c.client.Do(ctx, c.client.B().Getdel().Key(c.prefix+key).Build())
	if result.Error() != nil {
		return nil
	}

	raw, err := result.AsBytes()
	if err != nil {
		return nil
	}

	item := new(T)
	if json.Unmarshal(raw, item) != nil {
		return nil
	}
	return item
}
func (c *CacheUtil[T]) Invalidate(key string) {
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
defer cancel()