FEATURE (restores): Add cancellation of restore process

This commit is contained in:
Rostislav Dugin
2026-01-17 22:35:47 +03:00
parent c39bd34d5e
commit f957abc9db
44 changed files with 1625 additions and 380 deletions

188
AGENTS.md
View File

@@ -237,6 +237,66 @@ func (c *ProjectController) extractProjectID(ctx *gin.Context) uuid.UUID {
---
### Boolean Naming
**Always prefix boolean variables with verbs like `is`, `has`, `was`, `should`, `can`, etc.**
This makes the code more readable and clearly indicates that the variable represents a true/false state.
#### Good Examples:
```go
type User struct {
IsActive bool
IsVerified bool
HasAccess bool
WasNotified bool
}
type BackupConfig struct {
IsEnabled bool
ShouldCompress bool
CanRetry bool
}
// Variables
isInProgress := true
wasCompleted := false
hasPermission := checkPermissions()
```
#### Bad Examples:
```go
type User struct {
Active bool // Should be: IsActive
Verified bool // Should be: IsVerified
Access bool // Should be: HasAccess
}
type BackupConfig struct {
Enabled bool // Should be: IsEnabled
Compress bool // Should be: ShouldCompress
Retry bool // Should be: CanRetry
}
// Variables
inProgress := true // Should be: isInProgress
completed := false // Should be: wasCompleted
permission := true // Should be: hasPermission
```
#### Common Boolean Prefixes:
- **is** - current state (IsActive, IsValid, IsEnabled)
- **has** - possession or presence (HasAccess, HasPermission, HasError)
- **was** - past state (WasCompleted, WasNotified, WasDeleted)
- **should** - intention or recommendation (ShouldRetry, ShouldCompress)
- **can** - capability or permission (CanRetry, CanDelete, CanEdit)
- **will** - future state (WillExpire, WillRetry)
---
### Comments
#### Guidelines
@@ -489,6 +549,134 @@ func GetOrderRepository() *repositories.OrderRepository {
}
```
#### SetupDependencies() Pattern
**All `SetupDependencies()` functions must use sync.Once to ensure idempotent execution.**
This pattern allows `SetupDependencies()` to be safely called multiple times (especially in tests) while ensuring the actual setup logic executes only once.
**Implementation Pattern:**
```go
package feature
import (
"sync"
"sync/atomic"
"databasus-backend/internal/util/logger"
)
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
// Initialize dependencies here
someService.SetDependency(otherService)
anotherService.AddListener(listener)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}
```
**Why This Pattern:**
- **Tests can call multiple times**: Test setup often calls `SetupDependencies()` multiple times without issues
- **Thread-safe**: Works correctly with concurrent calls (nanoseconds or seconds apart)
- **Idempotent**: Subsequent calls are safe, only log warning
- **No panics**: Does not break tests or production code on multiple calls
**Key Points:**
1. Check `isSetup.Load()` **before** calling `Do()` to detect previous executions
2. Set `isSetup.Store(true)` **inside** the `Do()` closure after setup completes
3. Log warning if already setup (helps identify unnecessary duplicate calls)
4. All setup logic must be inside the `Do()` closure
---
### Background Services
**All background service `Run()` methods must panic if called multiple times to prevent corrupted states.**
Background services run infinite loops and must never be started twice on the same instance. Multiple calls indicate a serious bug that would cause duplicate goroutines, resource leaks, and data corruption.
**Implementation Pattern:**
```go
package feature
import (
"context"
"fmt"
"sync"
"sync/atomic"
)
type BackgroundService struct {
// ... existing fields ...
runOnce sync.Once
hasRun atomic.Bool
}
func (s *BackgroundService) Run(ctx context.Context) {
wasAlreadyRun := s.hasRun.Load()
s.runOnce.Do(func() {
s.hasRun.Store(true)
// Existing infinite loop logic
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.doWork()
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
}
```
**Why Panic Instead of Warning:**
- **Prevents corruption**: Multiple `Run()` calls would create duplicate goroutines consuming resources
- **Fails fast**: Catches critical bugs immediately in tests and production
- **Clear indication**: Panic clearly indicates a serious programming error
- **Applies everywhere**: Same protection in tests and production
**When This Applies:**
- All background services with infinite loops
- Registry services (BackupNodesRegistry, RestoreNodesRegistry)
- Scheduler services (BackupsScheduler, RestoresScheduler)
- Worker nodes (BackuperNode, RestorerNode)
- Cleanup services (AuditLogBackgroundService, DownloadTokenBackgroundService)
**Key Points:**
1. Check `hasRun.Load()` **before** calling `Do()` to detect previous executions
2. Set `hasRun.Store(true)` **inside** the `Do()` closure before starting work
3. **Always panic** if already run (never just log warning)
4. All run logic must be inside the `Do()` closure
5. This pattern is **thread-safe** for any timing (concurrent or sequential calls)
---
### Migrations

View File

@@ -2,34 +2,50 @@ package audit_logs
import (
"context"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
)
type AuditLogBackgroundService struct {
auditLogService *AuditLogService
logger *slog.Logger
runOnce sync.Once
hasRun atomic.Bool
}
func (s *AuditLogBackgroundService) Run(ctx context.Context) {
s.logger.Info("Starting audit log cleanup background service")
wasAlreadyRun := s.hasRun.Load()
if ctx.Err() != nil {
return
}
s.runOnce.Do(func() {
s.hasRun.Store(true)
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
s.logger.Info("Starting audit log cleanup background service")
for {
select {
case <-ctx.Done():
if ctx.Err() != nil {
return
case <-ticker.C:
if err := s.cleanOldAuditLogs(); err != nil {
s.logger.Error("Failed to clean old audit logs", "error", err)
}
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.cleanOldAuditLogs(); err != nil {
s.logger.Error("Failed to clean old audit logs", "error", err)
}
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
}

View File

@@ -1,6 +1,9 @@
package audit_logs
import (
"sync"
"sync/atomic"
users_services "databasus-backend/internal/features/users/services"
"databasus-backend/internal/util/logger"
)
@@ -14,8 +17,10 @@ var auditLogController = &AuditLogController{
auditLogService,
}
var auditLogBackgroundService = &AuditLogBackgroundService{
auditLogService,
logger.GetLogger(),
auditLogService: auditLogService,
logger: logger.GetLogger(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
func GetAuditLogService() *AuditLogService {
@@ -30,8 +35,23 @@ func GetAuditLogBackgroundService() *AuditLogBackgroundService {
return auditLogBackgroundService
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
users_services.GetUserService().SetAuditLogWriter(auditLogService)
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
users_services.GetUserService().SetAuditLogWriter(auditLogService)
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}

View File

@@ -2,6 +2,17 @@ package backuping
import (
"context"
"errors"
"fmt"
"log/slog"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"databasus-backend/internal/config"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
@@ -10,14 +21,6 @@ import (
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
workspaces_services "databasus-backend/internal/features/workspaces/services"
util_encryption "databasus-backend/internal/util/encryption"
"errors"
"fmt"
"log/slog"
"slices"
"strings"
"time"
"github.com/google/uuid"
)
const (
@@ -40,66 +43,79 @@ type BackuperNode struct {
nodeID uuid.UUID
lastHeartbeat time.Time
runOnce sync.Once
hasRun atomic.Bool
}
func (n *BackuperNode) Run(ctx context.Context) {
n.lastHeartbeat = time.Now().UTC()
wasAlreadyRun := n.hasRun.Load()
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
n.runOnce.Do(func() {
n.hasRun.Store(true)
backupNode := BackupNode{
ID: n.nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: time.Now().UTC(),
}
n.lastHeartbeat = time.Now().UTC()
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
n.logger.Error("Failed to register node in registry", "error", err)
panic(err)
}
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
n.MakeBackup(backupID, isCallNotifier)
if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil {
n.logger.Error(
"Failed to publish backup completion",
"error",
err,
"backupID",
backupID,
)
backupNode := BackupNode{
ID: n.nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: time.Now().UTC(),
}
}
err := n.backupNodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID, backupHandler)
if err != nil {
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
panic(err)
}
defer func() {
if err := n.backupNodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
n.logger.Error("Failed to register node in registry", "error", err)
panic(err)
}
}()
ticker := time.NewTicker(heartbeatTickerInterval)
defer ticker.Stop()
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
for {
select {
case <-ctx.Done():
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
if err := n.backupNodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
n.logger.Error("Failed to unregister node from registry", "error", err)
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
n.MakeBackup(backupID, isCallNotifier)
if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil {
n.logger.Error(
"Failed to publish backup completion",
"error",
err,
"backupID",
backupID,
)
}
return
case <-ticker.C:
n.sendHeartbeat(&backupNode)
}
err := n.backupNodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID, backupHandler)
if err != nil {
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
panic(err)
}
defer func() {
if err := n.backupNodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
}
}()
ticker := time.NewTicker(heartbeatTickerInterval)
defer ticker.Stop()
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
for {
select {
case <-ctx.Done():
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
if err := n.backupNodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
n.logger.Error("Failed to unregister node from registry", "error", err)
}
return
case <-ticker.C:
n.sendHeartbeat(&backupNode)
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", n))
}
}

View File

@@ -1,6 +1,12 @@
package backuping
import (
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/usecases"
backups_config "databasus-backend/internal/features/backups/config"
@@ -12,9 +18,6 @@ import (
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
"time"
"github.com/google/uuid"
)
var backupRepository = &backups_core.BackupRepository{}
@@ -22,11 +25,13 @@ var backupRepository = &backups_core.BackupRepository{}
var taskCancelManager = tasks_cancellation.GetTaskCancelManager()
var backupNodesRegistry = &BackupNodesRegistry{
cache_utils.GetValkeyClient(),
logger.GetLogger(),
cache_utils.DefaultCacheTimeout,
cache_utils.NewPubSubManager(),
cache_utils.NewPubSubManager(),
client: cache_utils.GetValkeyClient(),
logger: logger.GetLogger(),
timeout: cache_utils.DefaultCacheTimeout,
pubsubBackups: cache_utils.NewPubSubManager(),
pubsubCompletions: cache_utils.NewPubSubManager(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
func getNodeID() uuid.UUID {
@@ -34,19 +39,21 @@ func getNodeID() uuid.UUID {
}
var backuperNode = &BackuperNode{
databases.GetDatabaseService(),
encryption.GetFieldEncryptor(),
workspaces_services.GetWorkspaceService(),
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
notifiers.GetNotifierService(),
taskCancelManager,
backupNodesRegistry,
logger.GetLogger(),
usecases.GetCreateBackupUsecase(),
getNodeID(),
time.Time{},
databaseService: databases.GetDatabaseService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
workspaceService: workspaces_services.GetWorkspaceService(),
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
notificationSender: notifiers.GetNotifierService(),
backupCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
logger: logger.GetLogger(),
createBackupUseCase: usecases.GetCreateBackupUsecase(),
nodeID: getNodeID(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
var backupsScheduler = &BackupsScheduler{
@@ -59,6 +66,8 @@ var backupsScheduler = &BackupsScheduler{
logger: logger.GetLogger(),
backupToNodeRelations: make(map[uuid.UUID]BackupToNodeRelation),
backuperNode: backuperNode,
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
func GetBackupsScheduler() *BackupsScheduler {

View File

@@ -6,6 +6,8 @@ import (
"fmt"
"log/slog"
"strings"
"sync"
"sync/atomic"
"time"
cache_utils "databasus-backend/internal/util/cache"
@@ -47,24 +49,37 @@ type BackupNodesRegistry struct {
timeout time.Duration
pubsubBackups *cache_utils.PubSubManager
pubsubCompletions *cache_utils.PubSubManager
runOnce sync.Once
hasRun atomic.Bool
}
func (r *BackupNodesRegistry) Run(ctx context.Context) {
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
}
wasAlreadyRun := r.hasRun.Load()
ticker := time.NewTicker(cleanupTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes", "error", err)
r.runOnce.Do(func() {
r.hasRun.Store(true)
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
}
ticker := time.NewTicker(cleanupTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes", "error", err)
}
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", r))
}
}

View File

@@ -4,6 +4,8 @@ import (
"context"
"encoding/json"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
@@ -594,11 +596,13 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) {
func createTestRegistry() *BackupNodesRegistry {
return &BackupNodesRegistry{
cache_utils.GetValkeyClient(),
logger.GetLogger(),
cache_utils.DefaultCacheTimeout,
cache_utils.NewPubSubManager(),
cache_utils.NewPubSubManager(),
client: cache_utils.GetValkeyClient(),
logger: logger.GetLogger(),
timeout: cache_utils.DefaultCacheTimeout,
pubsubBackups: cache_utils.NewPubSubManager(),
pubsubCompletions: cache_utils.NewPubSubManager(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}

View File

@@ -2,6 +2,14 @@ package backuping
import (
"context"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"databasus-backend/internal/config"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
@@ -9,11 +17,6 @@ import (
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/period"
"fmt"
"log/slog"
"time"
"github.com/google/uuid"
)
const (
@@ -34,59 +37,72 @@ type BackupsScheduler struct {
backupToNodeRelations map[uuid.UUID]BackupToNodeRelation
backuperNode *BackuperNode
runOnce sync.Once
hasRun atomic.Bool
}
func (s *BackupsScheduler) Run(ctx context.Context) {
s.lastBackupTime = time.Now().UTC()
wasAlreadyRun := s.hasRun.Load()
if config.GetEnv().IsManyNodesMode {
// wait other nodes to start
time.Sleep(schedulerStartupDelay)
}
s.runOnce.Do(func() {
s.hasRun.Store(true)
if err := s.failBackupsInProgress(); err != nil {
s.logger.Error("Failed to fail backups in progress", "error", err)
panic(err)
}
s.lastBackupTime = time.Now().UTC()
err := s.backupNodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted)
if err != nil {
s.logger.Error("Failed to subscribe to backup completions", "error", err)
panic(err)
}
defer func() {
if err := s.backupNodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
if config.GetEnv().IsManyNodesMode {
// wait other nodes to start
time.Sleep(schedulerStartupDelay)
}
}()
if ctx.Err() != nil {
return
}
if err := s.failBackupsInProgress(); err != nil {
s.logger.Error("Failed to fail backups in progress", "error", err)
panic(err)
}
ticker := time.NewTicker(schedulerTickerInterval)
defer ticker.Stop()
err := s.backupNodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted)
if err != nil {
s.logger.Error("Failed to subscribe to backup completions", "error", err)
panic(err)
}
for {
select {
case <-ctx.Done():
defer func() {
if err := s.backupNodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
}
}()
if ctx.Err() != nil {
return
case <-ticker.C:
if err := s.cleanOldBackups(); err != nil {
s.logger.Error("Failed to clean old backups", "error", err)
}
if err := s.checkDeadNodesAndFailBackups(); err != nil {
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
}
if err := s.runPendingBackups(); err != nil {
s.logger.Error("Failed to run pending backups", "error", err)
}
s.lastBackupTime = time.Now().UTC()
}
ticker := time.NewTicker(schedulerTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.cleanOldBackups(); err != nil {
s.logger.Error("Failed to clean old backups", "error", err)
}
if err := s.checkDeadNodesAndFailBackups(); err != nil {
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
}
if err := s.runPendingBackups(); err != nil {
s.logger.Error("Failed to run pending backups", "error", err)
}
s.lastBackupTime = time.Now().UTC()
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
}

View File

@@ -835,7 +835,8 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T
cache_utils.ClearAllCache()
// Start scheduler so it can handle task completions
schedulerCancel := StartSchedulerForTest(t)
scheduler := CreateTestScheduler()
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
backuperNode := CreateTestBackuperNode()
@@ -891,7 +892,7 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T
t.Logf("Initial active tasks: %d", initialActiveTasks)
// Start backup
GetBackupsScheduler().StartBackup(database.ID, false)
scheduler.StartBackup(database.ID, false)
// Wait for backup to complete
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
@@ -930,7 +931,8 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
cache_utils.ClearAllCache()
// Start scheduler so it can handle task completions
schedulerCancel := StartSchedulerForTest(t)
scheduler := CreateTestScheduler()
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
backuperNode := CreateTestBackuperNode()
@@ -993,7 +995,7 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
t.Logf("Initial active tasks: %d", initialActiveTasks)
// Start backup
GetBackupsScheduler().StartBackup(database.ID, false)
scheduler.StartBackup(database.ID, false)
// Wait for backup to fail
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)

View File

@@ -3,6 +3,8 @@ package backuping
import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
@@ -35,19 +37,37 @@ func CreateTestRouter() *gin.Engine {
func CreateTestBackuperNode() *BackuperNode {
return &BackuperNode{
databases.GetDatabaseService(),
encryption.GetFieldEncryptor(),
workspaces_services.GetWorkspaceService(),
databaseService: databases.GetDatabaseService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
workspaceService: workspaces_services.GetWorkspaceService(),
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
notificationSender: notifiers.GetNotifierService(),
backupCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
logger: logger.GetLogger(),
createBackupUseCase: usecases.GetCreateBackupUsecase(),
nodeID: uuid.New(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}
func CreateTestScheduler() *BackupsScheduler {
return &BackupsScheduler{
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
notifiers.GetNotifierService(),
taskCancelManager,
backupNodesRegistry,
time.Now().UTC(),
logger.GetLogger(),
usecases.GetCreateBackupUsecase(),
uuid.New(),
time.Time{},
make(map[uuid.UUID]BackupToNodeRelation),
CreateTestBackuperNode(),
sync.Once{},
atomic.Bool{},
}
}
@@ -141,13 +161,13 @@ func StartBackuperNodeForTest(t *testing.T, backuperNode *BackuperNode) context.
// StartSchedulerForTest starts the BackupsScheduler in a goroutine for testing.
// The scheduler subscribes to task completions and manages backup lifecycle.
// Returns a context cancel function that should be deferred to stop the scheduler.
func StartSchedulerForTest(t *testing.T) context.CancelFunc {
func StartSchedulerForTest(t *testing.T, scheduler *BackupsScheduler) context.CancelFunc {
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
GetBackupsScheduler().Run(ctx)
scheduler.Run(ctx)
close(done)
}()

View File

@@ -1,6 +1,9 @@
package backups
import (
"sync"
"sync/atomic"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/backuping"
backups_core "databasus-backend/internal/features/backups/backups/core"
@@ -52,11 +55,26 @@ func GetBackupController() *BackupController {
return backupController
}
func SetupDependencies() {
backups_config.
GetBackupConfigService().
SetDatabaseStorageChangeListener(backupService)
var (
setupOnce sync.Once
isSetup atomic.Bool
)
databases.GetDatabaseService().AddDbRemoveListener(backupService)
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
func SetupDependencies() {
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
backups_config.
GetBackupConfigService().
SetDatabaseStorageChangeListener(backupService)
databases.GetDatabaseService().AddDbRemoveListener(backupService)
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}

View File

@@ -2,33 +2,49 @@ package backups_download
import (
"context"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
)
type DownloadTokenBackgroundService struct {
downloadTokenService *DownloadTokenService
logger *slog.Logger
runOnce sync.Once
hasRun atomic.Bool
}
func (s *DownloadTokenBackgroundService) Run(ctx context.Context) {
s.logger.Info("Starting download token cleanup background service")
wasAlreadyRun := s.hasRun.Load()
if ctx.Err() != nil {
return
}
s.runOnce.Do(func() {
s.hasRun.Store(true)
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
s.logger.Info("Starting download token cleanup background service")
for {
select {
case <-ctx.Done():
if ctx.Err() != nil {
return
case <-ticker.C:
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
s.logger.Error("Failed to clean expired download tokens", "error", err)
}
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
s.logger.Error("Failed to clean expired download tokens", "error", err)
}
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
}

View File

@@ -1,6 +1,9 @@
package backups_download
import (
"sync"
"sync/atomic"
"databasus-backend/internal/config"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/logger"
@@ -30,8 +33,10 @@ func init() {
}
downloadTokenBackgroundService = &DownloadTokenBackgroundService{
downloadTokenService,
logger.GetLogger(),
downloadTokenService: downloadTokenService,
logger: logger.GetLogger(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}

View File

@@ -1,10 +1,14 @@
package backups_config
import (
"sync"
"sync/atomic"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/logger"
)
var backupConfigRepository = &BackupConfigRepository{}
@@ -28,6 +32,21 @@ func GetBackupConfigService() *BackupConfigService {
return backupConfigService
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService)
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}

View File

@@ -1,6 +1,9 @@
package databases
import (
"sync"
"sync/atomic"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/notifiers"
users_services "databasus-backend/internal/features/users/services"
@@ -37,7 +40,22 @@ func GetDatabaseController() *DatabaseController {
return databaseController
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService)
notifiers.GetNotifierService().SetNotifierDatabaseCounter(databaseService)
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService)
notifiers.GetNotifierService().SetNotifierDatabaseCounter(databaseService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}

View File

@@ -2,30 +2,47 @@ package healthcheck_attempt
import (
"context"
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
)
type HealthcheckAttemptBackgroundService struct {
healthcheckConfigService *healthcheck_config.HealthcheckConfigService
checkDatabaseHealthUseCase *CheckDatabaseHealthUseCase
logger *slog.Logger
runOnce sync.Once
hasRun atomic.Bool
}
func (s *HealthcheckAttemptBackgroundService) Run(ctx context.Context) {
// first healthcheck immediately
s.checkDatabases()
wasAlreadyRun := s.hasRun.Load()
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.checkDatabases()
s.runOnce.Do(func() {
s.hasRun.Store(true)
// first healthcheck immediately
s.checkDatabases()
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.checkDatabases()
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
}

View File

@@ -1,6 +1,9 @@
package healthcheck_attempt
import (
"sync"
"sync/atomic"
"databasus-backend/internal/features/databases"
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
"databasus-backend/internal/features/notifiers"
@@ -22,9 +25,11 @@ var checkDatabaseHealthUseCase = &CheckDatabaseHealthUseCase{
}
var healthcheckAttemptBackgroundService = &HealthcheckAttemptBackgroundService{
healthcheck_config.GetHealthcheckConfigService(),
checkDatabaseHealthUseCase,
logger.GetLogger(),
healthcheckConfigService: healthcheck_config.GetHealthcheckConfigService(),
checkDatabaseHealthUseCase: checkDatabaseHealthUseCase,
logger: logger.GetLogger(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
var healthcheckAttemptController = &HealthcheckAttemptController{
healthcheckAttemptService,

View File

@@ -1,6 +1,9 @@
package healthcheck_config
import (
"sync"
"sync/atomic"
"databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/databases"
workspaces_services "databasus-backend/internal/features/workspaces/services"
@@ -27,8 +30,23 @@ func GetHealthcheckConfigController() *HealthcheckConfigController {
return healthcheckConfigController
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
databases.
GetDatabaseService().
AddDbCreationListener(healthcheckConfigService)
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
databases.
GetDatabaseService().
AddDbCreationListener(healthcheckConfigService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}

View File

@@ -1,6 +1,9 @@
package notifiers
import (
"sync"
"sync/atomic"
audit_logs "databasus-backend/internal/features/audit_logs"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/encryption"
@@ -32,6 +35,22 @@ func GetNotifierService() *NotifierService {
func GetNotifierRepository() *NotifierRepository {
return notifierRepository
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService)
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}

View File

@@ -16,6 +16,7 @@ type RestoreController struct {
func (c *RestoreController) RegisterRoutes(router *gin.RouterGroup) {
router.GET("/restores/:backupId", c.GetRestores)
router.POST("/restores/:backupId/restore", c.RestoreBackup)
router.POST("/restores/cancel/:restoreId", c.CancelRestore)
}
// GetRestores
@@ -85,3 +86,33 @@ func (c *RestoreController) RestoreBackup(ctx *gin.Context) {
ctx.JSON(http.StatusOK, gin.H{"message": "restore started successfully"})
}
// CancelRestore
// @Summary Cancel an in-progress restore
// @Description Cancel a restore that is currently in progress
// @Tags restores
// @Param restoreId path string true "Restore ID"
// @Success 204
// @Failure 400
// @Failure 401
// @Router /restores/cancel/{restoreId} [post]
func (c *RestoreController) CancelRestore(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
restoreID, err := uuid.Parse(ctx.Param("restoreId"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid restore ID"})
return
}
if err := c.restoreService.CancelRestore(user, restoreID); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.Status(http.StatusNoContent)
}

View File

@@ -18,20 +18,25 @@ import (
"databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/mysql"
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/notifiers"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/restores/restoring"
"databasus-backend/internal/features/storages"
local_storage "databasus-backend/internal/features/storages/models/local"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
users_dto "databasus-backend/internal/features/users/dto"
users_enums "databasus-backend/internal/features/users/enums"
users_services "databasus-backend/internal/features/users/services"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_models "databasus-backend/internal/features/workspaces/models"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
cache_utils "databasus-backend/internal/util/cache"
util_encryption "databasus-backend/internal/util/encryption"
test_utils "databasus-backend/internal/util/testing"
"databasus-backend/internal/util/tools"
@@ -370,6 +375,142 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
}
}
func Test_CancelRestore_InProgressRestore_SuccessfullyCancelled(t *testing.T) {
cache_utils.ClearAllCache()
tasks_cancellation.SetupDependencies()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := createTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backupRepo := backups_core.BackupRepository{}
backups, _ := backupRepo.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepo.DeleteByID(backup.ID)
}
restoreRepo := restores_core.RestoreRepository{}
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
for _, restore := range restores {
restoreRepo.DeleteByID(restore.ID)
}
restores, _ = restoreRepo.FindByStatus(restores_core.RestoreStatusCanceled)
for _, restore := range restores {
restoreRepo.DeleteByID(restore.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
cache_utils.ClearAllCache()
}()
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backup := backups.CreateTestBackup(database.ID, storage.ID)
mockUsecase := &restoring.MockBlockingRestoreUsecase{
StartedChan: make(chan bool, 1),
}
restorerNode := restoring.CreateTestRestorerNodeWithUsecase(mockUsecase)
cancelNode := restoring.StartRestorerNodeForTest(t, restorerNode)
defer cancelNode()
time.Sleep(200 * time.Millisecond)
restoreRequest := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
},
}
var restoreResponse map[string]interface{}
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
"Bearer "+user.Token,
restoreRequest,
http.StatusOK,
&restoreResponse,
)
select {
case <-mockUsecase.StartedChan:
t.Log("Restore started and is blocking")
case <-time.After(2 * time.Second):
t.Fatal("Restore did not start within timeout")
}
restoreRepo := &restores_core.RestoreRepository{}
restores, err := restoreRepo.FindByBackupID(backup.ID)
assert.NoError(t, err)
assert.Greater(t, len(restores), 0, "At least one restore should exist")
var restoreID uuid.UUID
for _, r := range restores {
if r.Status == restores_core.RestoreStatusInProgress {
restoreID = r.ID
break
}
}
assert.NotEqual(t, uuid.Nil, restoreID, "Should find an in-progress restore")
resp := test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/restores/cancel/%s", restoreID.String()),
"Bearer "+user.Token,
nil,
http.StatusNoContent,
)
assert.Equal(t, http.StatusNoContent, resp.StatusCode)
deadline := time.Now().UTC().Add(3 * time.Second)
var restore *restores_core.Restore
for time.Now().UTC().Before(deadline) {
restore, err = restoreRepo.FindByID(restoreID)
assert.NoError(t, err)
if restore.Status == restores_core.RestoreStatusCanceled {
break
}
time.Sleep(100 * time.Millisecond)
}
assert.Equal(t, restores_core.RestoreStatusCanceled, restore.Status)
auditLogService := audit_logs.GetAuditLogService()
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
workspace.ID,
&audit_logs.GetAuditLogsRequest{Limit: 100, Offset: 0},
)
assert.NoError(t, err)
foundCancelLog := false
for _, log := range auditLogs.AuditLogs {
if strings.Contains(log.Message, "Restore cancelled") &&
strings.Contains(log.Message, database.Name) {
foundCancelLog = true
break
}
}
assert.True(t, foundCancelLog, "Cancel audit log should be created")
time.Sleep(200 * time.Millisecond)
}
func createTestRouter() *gin.Engine {
return CreateTestRouter()
}

View File

@@ -6,4 +6,5 @@ const (
RestoreStatusInProgress RestoreStatus = "IN_PROGRESS"
RestoreStatusCompleted RestoreStatus = "COMPLETED"
RestoreStatusFailed RestoreStatus = "FAILED"
RestoreStatusCanceled RestoreStatus = "CANCELED"
)

View File

@@ -1,6 +1,8 @@
package restores_core
import (
"context"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -9,6 +11,7 @@ import (
type RestoreBackupUsecase interface {
Execute(
ctx context.Context,
backupConfig *backups_config.BackupConfig,
restore Restore,
originalDB *databases.Database,

View File

@@ -1,6 +1,9 @@
package restores
import (
"sync"
"sync/atomic"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
backups_config "databasus-backend/internal/features/backups/config"
@@ -9,6 +12,7 @@ import (
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/restores/usecases"
"databasus-backend/internal/features/storages"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
@@ -27,6 +31,7 @@ var restoreService = &RestoreService{
audit_logs.GetAuditLogService(),
encryption.GetFieldEncryptor(),
disk.GetDiskService(),
tasks_cancellation.GetTaskCancelManager(),
}
var restoreController = &RestoreController{
restoreService,
@@ -36,6 +41,21 @@ func GetRestoreController() *RestoreController {
return restoreController
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
backups.GetBackupService().AddBackupRemoveListener(restoreService)
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
backups.GetBackupService().AddBackupRemoveListener(restoreService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}

View File

@@ -1,6 +1,8 @@
package restoring
import (
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
@@ -11,6 +13,7 @@ import (
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/restores/usecases"
"databasus-backend/internal/features/storages"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
@@ -19,11 +22,13 @@ import (
var restoreRepository = &restores_core.RestoreRepository{}
var restoreNodesRegistry = &RestoreNodesRegistry{
cache_utils.GetValkeyClient(),
logger.GetLogger(),
cache_utils.DefaultCacheTimeout,
cache_utils.NewPubSubManager(),
cache_utils.NewPubSubManager(),
client: cache_utils.GetValkeyClient(),
logger: logger.GetLogger(),
timeout: cache_utils.DefaultCacheTimeout,
pubsubRestores: cache_utils.NewPubSubManager(),
pubsubCompletions: cache_utils.NewPubSubManager(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
var restoreDatabaseCache = cache_utils.NewCacheUtil[RestoreDatabaseCache](
@@ -31,19 +36,24 @@ var restoreDatabaseCache = cache_utils.NewCacheUtil[RestoreDatabaseCache](
"restore_db:",
)
var restoreCancelManager = tasks_cancellation.GetTaskCancelManager()
var restorerNode = &RestorerNode{
uuid.New(),
databases.GetDatabaseService(),
backups.GetBackupService(),
encryption.GetFieldEncryptor(),
restoreRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
restoreNodesRegistry,
logger.GetLogger(),
usecases.GetRestoreBackupUsecase(),
restoreDatabaseCache,
time.Time{},
nodeID: uuid.New(),
databaseService: databases.GetDatabaseService(),
backupService: backups.GetBackupService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
restoreRepository: restoreRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
restoreNodesRegistry: restoreNodesRegistry,
logger: logger.GetLogger(),
restoreBackupUsecase: usecases.GetRestoreBackupUsecase(),
cacheUtil: restoreDatabaseCache,
restoreCancelManager: restoreCancelManager,
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
var restoresScheduler = &RestoresScheduler{
@@ -58,6 +68,8 @@ var restoresScheduler = &RestoresScheduler{
restorerNode: restorerNode,
cacheUtil: restoreDatabaseCache,
completionSubscriptionID: uuid.Nil,
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
func GetRestoresScheduler() *RestoresScheduler {

View File

@@ -1,6 +1,7 @@
package restoring
import (
"context"
"errors"
backups_core "databasus-backend/internal/features/backups/backups/core"
@@ -13,6 +14,7 @@ import (
type MockSuccessRestoreUsecase struct{}
func (uc *MockSuccessRestoreUsecase) Execute(
ctx context.Context,
backupConfig *backups_config.BackupConfig,
restore restores_core.Restore,
originalDB *databases.Database,
@@ -27,6 +29,7 @@ func (uc *MockSuccessRestoreUsecase) Execute(
type MockFailedRestoreUsecase struct{}
func (uc *MockFailedRestoreUsecase) Execute(
ctx context.Context,
backupConfig *backups_config.BackupConfig,
restore restores_core.Restore,
originalDB *databases.Database,
@@ -44,6 +47,7 @@ type MockCaptureCredentialsRestoreUsecase struct {
}
func (uc *MockCaptureCredentialsRestoreUsecase) Execute(
ctx context.Context,
backupConfig *backups_config.BackupConfig,
restore restores_core.Restore,
originalDB *databases.Database,
@@ -59,3 +63,26 @@ func (uc *MockCaptureCredentialsRestoreUsecase) Execute(
}
return errors.New("mock restore failed")
}
type MockBlockingRestoreUsecase struct {
StartedChan chan bool
}
func (uc *MockBlockingRestoreUsecase) Execute(
ctx context.Context,
backupConfig *backups_config.BackupConfig,
restore restores_core.Restore,
originalDB *databases.Database,
restoringToDB *databases.Database,
backup *backups_core.Backup,
storage *storages.Storage,
isExcludeExtensions bool,
) error {
if uc.StartedChan != nil {
uc.StartedChan <- true
}
<-ctx.Done()
return ctx.Err()
}

View File

@@ -6,6 +6,8 @@ import (
"fmt"
"log/slog"
"strings"
"sync"
"sync/atomic"
"time"
cache_utils "databasus-backend/internal/util/cache"
@@ -47,24 +49,37 @@ type RestoreNodesRegistry struct {
timeout time.Duration
pubsubRestores *cache_utils.PubSubManager
pubsubCompletions *cache_utils.PubSubManager
runOnce sync.Once
hasRun atomic.Bool
}
func (r *RestoreNodesRegistry) Run(ctx context.Context) {
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
}
wasAlreadyRun := r.hasRun.Load()
ticker := time.NewTicker(cleanupTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes", "error", err)
r.runOnce.Do(func() {
r.hasRun.Store(true)
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
}
ticker := time.NewTicker(cleanupTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes", "error", err)
}
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", r))
}
}

View File

@@ -4,6 +4,8 @@ import (
"context"
"encoding/json"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
@@ -594,11 +596,13 @@ func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) {
func createTestRegistry() *RestoreNodesRegistry {
return &RestoreNodesRegistry{
cache_utils.GetValkeyClient(),
logger.GetLogger(),
cache_utils.DefaultCacheTimeout,
cache_utils.NewPubSubManager(),
cache_utils.NewPubSubManager(),
client: cache_utils.GetValkeyClient(),
logger: logger.GetLogger(),
timeout: cache_utils.DefaultCacheTimeout,
pubsubRestores: cache_utils.NewPubSubManager(),
pubsubCompletions: cache_utils.NewPubSubManager(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}

View File

@@ -2,8 +2,12 @@ package restoring
import (
"context"
"errors"
"fmt"
"log/slog"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
@@ -14,6 +18,7 @@ import (
"databasus-backend/internal/features/databases"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
cache_utils "databasus-backend/internal/util/cache"
util_encryption "databasus-backend/internal/util/encryption"
)
@@ -36,70 +41,84 @@ type RestorerNode struct {
logger *slog.Logger
restoreBackupUsecase restores_core.RestoreBackupUsecase
cacheUtil *cache_utils.CacheUtil[RestoreDatabaseCache]
restoreCancelManager *tasks_cancellation.TaskCancelManager
lastHeartbeat time.Time
runOnce sync.Once
hasRun atomic.Bool
}
func (n *RestorerNode) Run(ctx context.Context) {
n.lastHeartbeat = time.Now().UTC()
wasAlreadyRun := n.hasRun.Load()
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
n.runOnce.Do(func() {
n.hasRun.Store(true)
restoreNode := RestoreNode{
ID: n.nodeID,
ThroughputMBs: throughputMBs,
}
n.lastHeartbeat = time.Now().UTC()
if err := n.restoreNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), restoreNode); err != nil {
n.logger.Error("Failed to register node in registry", "error", err)
panic(err)
}
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
restoreHandler := func(restoreID uuid.UUID, isCallNotifier bool) {
n.MakeRestore(restoreID)
if err := n.restoreNodesRegistry.PublishRestoreCompletion(n.nodeID, restoreID); err != nil {
n.logger.Error(
"Failed to publish restore completion",
"error",
err,
"restoreID",
restoreID,
)
restoreNode := RestoreNode{
ID: n.nodeID,
ThroughputMBs: throughputMBs,
}
}
err := n.restoreNodesRegistry.SubscribeNodeForRestoresAssignment(
n.nodeID,
restoreHandler,
)
if err != nil {
n.logger.Error("Failed to subscribe to restore assignments", "error", err)
panic(err)
}
defer func() {
if err := n.restoreNodesRegistry.UnsubscribeNodeForRestoresAssignments(); err != nil {
n.logger.Error("Failed to unsubscribe from restore assignments", "error", err)
if err := n.restoreNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), restoreNode); err != nil {
n.logger.Error("Failed to register node in registry", "error", err)
panic(err)
}
}()
ticker := time.NewTicker(heartbeatTickerInterval)
defer ticker.Stop()
n.logger.Info("Restore node started", "nodeID", n.nodeID, "throughput", throughputMBs)
for {
select {
case <-ctx.Done():
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
if err := n.restoreNodesRegistry.UnregisterNodeFromRegistry(restoreNode); err != nil {
n.logger.Error("Failed to unregister node from registry", "error", err)
restoreHandler := func(restoreID uuid.UUID, isCallNotifier bool) {
n.MakeRestore(restoreID)
if err := n.restoreNodesRegistry.PublishRestoreCompletion(n.nodeID, restoreID); err != nil {
n.logger.Error(
"Failed to publish restore completion",
"error",
err,
"restoreID",
restoreID,
)
}
return
case <-ticker.C:
n.sendHeartbeat(&restoreNode)
}
err := n.restoreNodesRegistry.SubscribeNodeForRestoresAssignment(
n.nodeID,
restoreHandler,
)
if err != nil {
n.logger.Error("Failed to subscribe to restore assignments", "error", err)
panic(err)
}
defer func() {
if err := n.restoreNodesRegistry.UnsubscribeNodeForRestoresAssignments(); err != nil {
n.logger.Error("Failed to unsubscribe from restore assignments", "error", err)
}
}()
ticker := time.NewTicker(heartbeatTickerInterval)
defer ticker.Stop()
n.logger.Info("Restore node started", "nodeID", n.nodeID, "throughput", throughputMBs)
for {
select {
case <-ctx.Done():
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
if err := n.restoreNodesRegistry.UnregisterNodeFromRegistry(restoreNode); err != nil {
n.logger.Error("Failed to unregister node from registry", "error", err)
}
return
case <-ticker.C:
n.sendHeartbeat(&restoreNode)
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", n))
}
}
@@ -176,6 +195,11 @@ func (n *RestorerNode) MakeRestore(restoreID uuid.UUID) {
start := time.Now().UTC()
// Create cancellable context
ctx, cancel := context.WithCancel(context.Background())
n.restoreCancelManager.RegisterTask(restore.ID, cancel)
defer n.restoreCancelManager.UnregisterTask(restore.ID)
// Create restoring database from cached credentials
restoringToDB := &databases.Database{
Type: database.Type,
@@ -204,6 +228,7 @@ func (n *RestorerNode) MakeRestore(restoreID uuid.UUID) {
}
err = n.restoreBackupUsecase.Execute(
ctx,
backupConfig,
*restore,
database,
@@ -216,6 +241,29 @@ func (n *RestorerNode) MakeRestore(restoreID uuid.UUID) {
if err != nil {
errMsg := err.Error()
// Check if restore was cancelled
isCancelled := strings.Contains(errMsg, "restore cancelled") ||
strings.Contains(errMsg, "context canceled") ||
errors.Is(err, context.Canceled)
isShutdown := strings.Contains(errMsg, "shutdown")
if isCancelled && !isShutdown {
n.logger.Warn("Restore was cancelled by user or system",
"restoreId", restore.ID,
"isCancelled", isCancelled,
"isShutdown", isShutdown,
)
restore.Status = restores_core.RestoreStatusCanceled
restore.RestoreDurationMs = time.Since(start).Milliseconds()
if err := n.restoreRepository.Save(restore); err != nil {
n.logger.Error("Failed to save cancelled restore", "error", err)
}
return
}
n.logger.Error("Restore execution failed",
"restoreId", restore.ID,
"backupId", backup.ID,

View File

@@ -4,6 +4,8 @@ import (
"context"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
@@ -34,51 +36,64 @@ type RestoresScheduler struct {
restorerNode *RestorerNode
cacheUtil *cache_utils.CacheUtil[RestoreDatabaseCache]
completionSubscriptionID uuid.UUID
runOnce sync.Once
hasRun atomic.Bool
}
func (s *RestoresScheduler) Run(ctx context.Context) {
s.lastCheckTime = time.Now().UTC()
wasAlreadyRun := s.hasRun.Load()
if config.GetEnv().IsManyNodesMode {
// wait other nodes to start
time.Sleep(schedulerStartupDelay)
}
s.runOnce.Do(func() {
s.hasRun.Store(true)
if err := s.failRestoresInProgress(); err != nil {
s.logger.Error("Failed to fail restores in progress", "error", err)
panic(err)
}
s.lastCheckTime = time.Now().UTC()
err := s.restoreNodesRegistry.SubscribeForRestoresCompletions(s.onRestoreCompleted)
if err != nil {
s.logger.Error("Failed to subscribe to restore completions", "error", err)
panic(err)
}
defer func() {
if err := s.restoreNodesRegistry.UnsubscribeForRestoresCompletions(); err != nil {
s.logger.Error("Failed to unsubscribe from restore completions", "error", err)
if config.GetEnv().IsManyNodesMode {
// wait other nodes to start
time.Sleep(schedulerStartupDelay)
}
}()
if ctx.Err() != nil {
return
}
if err := s.failRestoresInProgress(); err != nil {
s.logger.Error("Failed to fail restores in progress", "error", err)
panic(err)
}
ticker := time.NewTicker(schedulerTickerInterval)
defer ticker.Stop()
err := s.restoreNodesRegistry.SubscribeForRestoresCompletions(s.onRestoreCompleted)
if err != nil {
s.logger.Error("Failed to subscribe to restore completions", "error", err)
panic(err)
}
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.checkDeadNodesAndFailRestores(); err != nil {
s.logger.Error("Failed to check dead nodes and fail restores", "error", err)
defer func() {
if err := s.restoreNodesRegistry.UnsubscribeForRestoresCompletions(); err != nil {
s.logger.Error("Failed to unsubscribe from restore completions", "error", err)
}
}()
s.lastCheckTime = time.Now().UTC()
if ctx.Err() != nil {
return
}
ticker := time.NewTicker(schedulerTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.checkDeadNodesAndFailRestores(); err != nil {
s.logger.Error("Failed to check dead nodes and fail restores", "error", err)
}
s.lastCheckTime = time.Now().UTC()
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
}

View File

@@ -424,7 +424,8 @@ func Test_StartRestore_RestoreCompletes_DecrementsActiveTaskCount(t *testing.T)
cache_utils.ClearAllCache()
// Start scheduler so it can handle task completions
schedulerCancel := StartSchedulerForTest(t)
scheduler := CreateTestRestoresScheduler()
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
restorerNode := CreateTestRestorerNode()
@@ -485,7 +486,7 @@ func Test_StartRestore_RestoreCompletes_DecrementsActiveTaskCount(t *testing.T)
err = restoreRepository.Save(restore)
assert.NoError(t, err)
err = GetRestoresScheduler().StartRestore(restore.ID, nil)
err = scheduler.StartRestore(restore.ID, nil)
assert.NoError(t, err)
// Wait for restore to complete
@@ -524,7 +525,8 @@ func Test_StartRestore_RestoreFails_DecrementsActiveTaskCount(t *testing.T) {
cache_utils.ClearAllCache()
// Start scheduler so it can handle task completions
schedulerCancel := StartSchedulerForTest(t)
scheduler := CreateTestRestoresScheduler()
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
restorerNode := CreateTestRestorerNode()
@@ -585,7 +587,7 @@ func Test_StartRestore_RestoreFails_DecrementsActiveTaskCount(t *testing.T) {
err = restoreRepository.Save(restore)
assert.NoError(t, err)
err = GetRestoresScheduler().StartRestore(restore.ID, nil)
err = scheduler.StartRestore(restore.ID, nil)
assert.NoError(t, err)
// Wait for restore to fail
@@ -729,7 +731,8 @@ func Test_StartRestore_CredentialsRemovedAfterRestoreStarts(t *testing.T) {
cache_utils.ClearAllCache()
// Start scheduler so it can handle task assignments
schedulerCancel := StartSchedulerForTest(t)
scheduler := CreateTestRestoresScheduler()
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
// Create mock restorer node with credential capture usecase
@@ -810,7 +813,7 @@ func Test_StartRestore_CredentialsRemovedAfterRestoreStarts(t *testing.T) {
}
// Call StartRestore to cache credentials and trigger restore
err = GetRestoresScheduler().StartRestore(restore.ID, dbCache)
err = scheduler.StartRestore(restore.ID, dbCache)
assert.NoError(t, err)
// Wait for mock usecase to be called (with timeout)

View File

@@ -3,6 +3,8 @@ package restoring
import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
@@ -17,6 +19,7 @@ import (
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/restores/usecases"
"databasus-backend/internal/features/storages"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/util/encryption"
@@ -36,18 +39,59 @@ func CreateTestRouter() *gin.Engine {
func CreateTestRestorerNode() *RestorerNode {
return &RestorerNode{
uuid.New(),
databases.GetDatabaseService(),
backups.GetBackupService(),
encryption.GetFieldEncryptor(),
nodeID: uuid.New(),
databaseService: databases.GetDatabaseService(),
backupService: backups.GetBackupService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
restoreRepository: restoreRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
restoreNodesRegistry: restoreNodesRegistry,
logger: logger.GetLogger(),
restoreBackupUsecase: usecases.GetRestoreBackupUsecase(),
cacheUtil: restoreDatabaseCache,
restoreCancelManager: tasks_cancellation.GetTaskCancelManager(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}
func CreateTestRestorerNodeWithUsecase(usecase restores_core.RestoreBackupUsecase) *RestorerNode {
return &RestorerNode{
nodeID: uuid.New(),
databaseService: databases.GetDatabaseService(),
backupService: backups.GetBackupService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
restoreRepository: restoreRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
restoreNodesRegistry: restoreNodesRegistry,
logger: logger.GetLogger(),
restoreBackupUsecase: usecase,
cacheUtil: restoreDatabaseCache,
restoreCancelManager: tasks_cancellation.GetTaskCancelManager(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}
func CreateTestRestoresScheduler() *RestoresScheduler {
return &RestoresScheduler{
restoreRepository,
backups_config.GetBackupConfigService(),
backups.GetBackupService(),
storages.GetStorageService(),
backups_config.GetBackupConfigService(),
restoreNodesRegistry,
time.Now().UTC(),
logger.GetLogger(),
usecases.GetRestoreBackupUsecase(),
make(map[uuid.UUID]RestoreToNodeRelation),
restorerNode,
restoreDatabaseCache,
time.Time{},
uuid.Nil,
sync.Once{},
atomic.Bool{},
}
}
@@ -128,13 +172,13 @@ func StartRestorerNodeForTest(t *testing.T, restorerNode *RestorerNode) context.
// StartSchedulerForTest starts the RestoresScheduler in a goroutine for testing.
// The scheduler subscribes to task completions and manages restore lifecycle.
// Returns a context cancel function that should be deferred to stop the scheduler.
func StartSchedulerForTest(t *testing.T) context.CancelFunc {
func StartSchedulerForTest(t *testing.T, scheduler *RestoresScheduler) context.CancelFunc {
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
GetRestoresScheduler().Run(ctx)
scheduler.Run(ctx)
close(done)
}()

View File

@@ -11,6 +11,7 @@ import (
"databasus-backend/internal/features/restores/restoring"
"databasus-backend/internal/features/restores/usecases"
"databasus-backend/internal/features/storages"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
users_models "databasus-backend/internal/features/users/models"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/encryption"
@@ -35,6 +36,7 @@ type RestoreService struct {
auditLogService *audit_logs.AuditLogService
fieldEncryptor encryption.FieldEncryptor
diskService *disk.DiskService
taskCancelManager *tasks_cancellation.TaskCancelManager
}
func (s *RestoreService) OnBeforeBackupRemove(backup *backups_core.Backup) error {
@@ -340,3 +342,55 @@ func (s *RestoreService) validateDiskSpace(
return nil
}
func (s *RestoreService) CancelRestore(
user *users_models.User,
restoreID uuid.UUID,
) error {
restore, err := s.restoreRepository.FindByID(restoreID)
if err != nil {
return err
}
backup, err := s.backupService.GetBackup(restore.BackupID)
if err != nil {
return err
}
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return err
}
if database.WorkspaceID == nil {
return errors.New("cannot cancel restore for database without workspace")
}
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
if err != nil {
return err
}
if !canManage {
return errors.New("insufficient permissions to cancel restore for this database")
}
if restore.Status != restores_core.RestoreStatusInProgress {
return errors.New("restore is not in progress")
}
if err := s.taskCancelManager.CancelTask(restoreID); err != nil {
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Restore cancelled for database: %s (ID: %s)",
database.Name,
restoreID.String(),
),
&user.ID,
database.WorkspaceID,
)
return nil
}

View File

@@ -36,6 +36,7 @@ type RestoreMariadbBackupUsecase struct {
}
func (uc *RestoreMariadbBackupUsecase) Execute(
parentCtx context.Context,
originalDB *databases.Database,
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
@@ -79,6 +80,7 @@ func (uc *RestoreMariadbBackupUsecase) Execute(
}
return uc.restoreFromStorage(
parentCtx,
originalDB,
tools.GetMariadbExecutable(
tools.MariadbExecutableMariadb,
@@ -95,6 +97,7 @@ func (uc *RestoreMariadbBackupUsecase) Execute(
}
func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
parentCtx context.Context,
database *databases.Database,
mariadbBin string,
args []string,
@@ -103,7 +106,7 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
storage *storages.Storage,
mdbConfig *mariadbtypes.MariadbDatabase,
) error {
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
defer cancel()
go func() {
@@ -113,6 +116,9 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
select {
case <-ctx.Done():
return
case <-parentCtx.Done():
cancel()
return
case <-ticker.C:
if config.IsShouldShutdown() {
cancel()
@@ -213,6 +219,15 @@ func (uc *RestoreMariadbBackupUsecase) executeMariadbRestore(
waitErr := cmd.Wait()
stderrOutput := <-stderrCh
// Check for cancellation
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
return fmt.Errorf("restore cancelled")
}
default:
}
if config.IsShouldShutdown() {
return fmt.Errorf("restore cancelled due to shutdown")
}

View File

@@ -36,6 +36,7 @@ type RestoreMongodbBackupUsecase struct {
}
func (uc *RestoreMongodbBackupUsecase) Execute(
parentCtx context.Context,
originalDB *databases.Database,
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
@@ -76,6 +77,7 @@ func (uc *RestoreMongodbBackupUsecase) Execute(
args := uc.buildMongorestoreArgs(mdb, decryptedPassword, sourceDatabase)
return uc.restoreFromStorage(
parentCtx,
tools.GetMongodbExecutable(
tools.MongodbExecutableMongorestore,
config.GetEnv().EnvMode,
@@ -122,12 +124,13 @@ func (uc *RestoreMongodbBackupUsecase) buildMongorestoreArgs(
}
func (uc *RestoreMongodbBackupUsecase) restoreFromStorage(
parentCtx context.Context,
mongorestoreBin string,
args []string,
backup *backups_core.Backup,
storage *storages.Storage,
) error {
ctx, cancel := context.WithTimeout(context.Background(), restoreTimeout)
ctx, cancel := context.WithTimeout(parentCtx, restoreTimeout)
defer cancel()
go func() {
@@ -137,6 +140,9 @@ func (uc *RestoreMongodbBackupUsecase) restoreFromStorage(
select {
case <-ctx.Done():
return
case <-parentCtx.Done():
cancel()
return
case <-ticker.C:
if config.IsShouldShutdown() {
cancel()
@@ -218,6 +224,15 @@ func (uc *RestoreMongodbBackupUsecase) executeMongoRestore(
waitErr := cmd.Wait()
stderrOutput := <-stderrCh
// Check for cancellation
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
return fmt.Errorf("restore cancelled")
}
default:
}
if config.IsShouldShutdown() {
return fmt.Errorf("restore cancelled due to shutdown")
}

View File

@@ -36,6 +36,7 @@ type RestoreMysqlBackupUsecase struct {
}
func (uc *RestoreMysqlBackupUsecase) Execute(
parentCtx context.Context,
originalDB *databases.Database,
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
@@ -78,6 +79,7 @@ func (uc *RestoreMysqlBackupUsecase) Execute(
}
return uc.restoreFromStorage(
parentCtx,
originalDB,
tools.GetMysqlExecutable(
my.Version,
@@ -94,6 +96,7 @@ func (uc *RestoreMysqlBackupUsecase) Execute(
}
func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
parentCtx context.Context,
database *databases.Database,
mysqlBin string,
args []string,
@@ -102,7 +105,7 @@ func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
storage *storages.Storage,
myConfig *mysqltypes.MysqlDatabase,
) error {
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
defer cancel()
go func() {
@@ -112,6 +115,9 @@ func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
select {
case <-ctx.Done():
return
case <-parentCtx.Done():
cancel()
return
case <-ticker.C:
if config.IsShouldShutdown() {
cancel()
@@ -204,6 +210,15 @@ func (uc *RestoreMysqlBackupUsecase) executeMysqlRestore(
waitErr := cmd.Wait()
stderrOutput := <-stderrCh
// Check for cancellation
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
return fmt.Errorf("restore cancelled")
}
default:
}
if config.IsShouldShutdown() {
return fmt.Errorf("restore cancelled due to shutdown")
}

View File

@@ -35,6 +35,7 @@ type RestorePostgresqlBackupUsecase struct {
}
func (uc *RestorePostgresqlBackupUsecase) Execute(
parentCtx context.Context,
originalDB *databases.Database,
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
@@ -73,6 +74,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
// All PostgreSQL backups are now custom format (-Fc)
return uc.restoreCustomType(
parentCtx,
originalDB,
pgBin,
backup,
@@ -84,6 +86,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
// restoreCustomType restores a backup in custom type (-Fc)
func (uc *RestorePostgresqlBackupUsecase) restoreCustomType(
parentCtx context.Context,
originalDB *databases.Database,
pgBin string,
backup *backups_core.Backup,
@@ -102,15 +105,24 @@ func (uc *RestorePostgresqlBackupUsecase) restoreCustomType(
// If excluding extensions, we must use file-based restore (requires TOC file generation)
// Also use file-based restore for parallel jobs (multiple CPUs)
if isExcludeExtensions || pg.CpuCount > 1 {
return uc.restoreViaFile(originalDB, pgBin, backup, storage, pg, isExcludeExtensions)
return uc.restoreViaFile(
parentCtx,
originalDB,
pgBin,
backup,
storage,
pg,
isExcludeExtensions,
)
}
// Single CPU without extension exclusion: stream directly via stdin
return uc.restoreViaStdin(originalDB, pgBin, backup, storage, pg)
return uc.restoreViaStdin(parentCtx, originalDB, pgBin, backup, storage, pg)
}
// restoreViaStdin streams backup via stdin for single CPU restore
func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
parentCtx context.Context,
originalDB *databases.Database,
pgBin string,
backup *backups_core.Backup,
@@ -133,10 +145,10 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
"--no-acl",
}
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
defer cancel()
// Monitor for shutdown and cancel context if needed
// Monitor for shutdown and parent cancellation
go func() {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
@@ -145,6 +157,9 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
select {
case <-ctx.Done():
return
case <-parentCtx.Done():
cancel()
return
case <-ticker.C:
if config.IsShouldShutdown() {
cancel()
@@ -296,6 +311,15 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
stderrOutput := <-stderrCh
copyErr := <-copyErrCh
// Check for cancellation
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
return fmt.Errorf("restore cancelled")
}
default:
}
// Check for shutdown before finalizing
if config.IsShouldShutdown() {
return fmt.Errorf("restore cancelled due to shutdown")
@@ -307,6 +331,15 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
}
if waitErr != nil {
// Check for cancellation again
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
return fmt.Errorf("restore cancelled")
}
default:
}
if config.IsShouldShutdown() {
return fmt.Errorf("restore cancelled due to shutdown")
}
@@ -319,6 +352,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
// restoreViaFile downloads backup and uses parallel jobs for multi-CPU restore
func (uc *RestorePostgresqlBackupUsecase) restoreViaFile(
parentCtx context.Context,
originalDB *databases.Database,
pgBin string,
backup *backups_core.Backup,
@@ -354,6 +388,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaFile(
}
return uc.restoreFromStorage(
parentCtx,
originalDB,
pgBin,
args,
@@ -367,6 +402,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaFile(
// restoreFromStorage restores backup data from storage using pg_restore
func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
parentCtx context.Context,
database *databases.Database,
pgBin string,
args []string,
@@ -386,10 +422,10 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
isExcludeExtensions,
)
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
defer cancel()
// Monitor for shutdown and cancel context if needed
// Monitor for shutdown and parent cancellation
go func() {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
@@ -398,6 +434,9 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
select {
case <-ctx.Done():
return
case <-parentCtx.Done():
cancel()
return
case <-ticker.C:
if config.IsShouldShutdown() {
cancel()
@@ -624,12 +663,30 @@ func (uc *RestorePostgresqlBackupUsecase) executePgRestore(
waitErr := cmd.Wait()
stderrOutput := <-stderrCh
// Check for cancellation
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
return fmt.Errorf("restore cancelled")
}
default:
}
// Check for shutdown before finalizing
if config.IsShouldShutdown() {
return fmt.Errorf("restore cancelled due to shutdown")
}
if waitErr != nil {
// Check for cancellation again
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
return fmt.Errorf("restore cancelled")
}
default:
}
if config.IsShouldShutdown() {
return fmt.Errorf("restore cancelled due to shutdown")
}

View File

@@ -1,6 +1,7 @@
package usecases
import (
"context"
"errors"
backups_core "databasus-backend/internal/features/backups/backups/core"
@@ -22,6 +23,7 @@ type RestoreBackupUsecase struct {
}
func (uc *RestoreBackupUsecase) Execute(
ctx context.Context,
backupConfig *backups_config.BackupConfig,
restore restores_core.Restore,
originalDB *databases.Database,
@@ -33,6 +35,7 @@ func (uc *RestoreBackupUsecase) Execute(
switch originalDB.Type {
case databases.DatabaseTypePostgres:
return uc.restorePostgresqlBackupUsecase.Execute(
ctx,
originalDB,
restoringToDB,
backupConfig,
@@ -43,6 +46,7 @@ func (uc *RestoreBackupUsecase) Execute(
)
case databases.DatabaseTypeMysql:
return uc.restoreMysqlBackupUsecase.Execute(
ctx,
originalDB,
restoringToDB,
backupConfig,
@@ -52,6 +56,7 @@ func (uc *RestoreBackupUsecase) Execute(
)
case databases.DatabaseTypeMariadb:
return uc.restoreMariadbBackupUsecase.Execute(
ctx,
originalDB,
restoringToDB,
backupConfig,
@@ -61,6 +66,7 @@ func (uc *RestoreBackupUsecase) Execute(
)
case databases.DatabaseTypeMongodb:
return uc.restoreMongodbBackupUsecase.Execute(
ctx,
originalDB,
restoringToDB,
backupConfig,

View File

@@ -1,9 +1,13 @@
package storages
import (
"sync"
"sync/atomic"
audit_logs "databasus-backend/internal/features/audit_logs"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
)
var storageRepository = &StorageRepository{}
@@ -27,6 +31,21 @@ func GetStorageController() *StorageController {
return storageController
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(storageService)
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(storageService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}

View File

@@ -2,9 +2,11 @@ package task_cancellation
import (
"context"
"sync"
"sync/atomic"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/logger"
"sync"
"github.com/google/uuid"
)
@@ -20,6 +22,21 @@ func GetTaskCancelManager() *TaskCancelManager {
return taskCancelManager
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
taskCancelManager.StartSubscription()
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
taskCancelManager.StartSubscription()
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}

View File

@@ -0,0 +1,159 @@
package features
import (
"context"
"fmt"
"sync"
"testing"
"time"
"databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/features/backups/backups/backuping"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/restores"
"databasus-backend/internal/features/restores/restoring"
"databasus-backend/internal/features/storages"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
)
// Test_SetupDependencies_CalledTwice_LogsWarning verifies SetupDependencies is idempotent
func Test_SetupDependencies_CalledTwice_LogsWarning(t *testing.T) {
// Call each SetupDependencies twice - should not panic, only log warnings
audit_logs.SetupDependencies()
audit_logs.SetupDependencies()
backups.SetupDependencies()
backups.SetupDependencies()
backups_config.SetupDependencies()
backups_config.SetupDependencies()
databases.SetupDependencies()
databases.SetupDependencies()
healthcheck_config.SetupDependencies()
healthcheck_config.SetupDependencies()
notifiers.SetupDependencies()
notifiers.SetupDependencies()
restores.SetupDependencies()
restores.SetupDependencies()
storages.SetupDependencies()
storages.SetupDependencies()
task_cancellation.SetupDependencies()
task_cancellation.SetupDependencies()
// If we reach here without panic, test passes
t.Log("All SetupDependencies calls completed successfully (idempotent)")
}
// Test_SetupDependencies_ConcurrentCalls_Safe verifies thread safety
func Test_SetupDependencies_ConcurrentCalls_Safe(t *testing.T) {
var wg sync.WaitGroup
// Call SetupDependencies concurrently from 10 goroutines
for range 10 {
wg.Add(1)
go func() {
defer wg.Done()
audit_logs.SetupDependencies()
}()
}
wg.Wait()
t.Log("Concurrent SetupDependencies calls completed successfully")
}
// Test_BackgroundService_Run_CalledTwice_Panics verifies Run() panics on duplicate calls
func Test_BackgroundService_Run_CalledTwice_Panics(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Create a test background service
backgroundService := audit_logs.GetAuditLogBackgroundService()
// Start first Run() in goroutine
go func() {
backgroundService.Run(ctx)
}()
// Give first call time to initialize
time.Sleep(100 * time.Millisecond)
// Second call should panic
defer func() {
if r := recover(); r != nil {
expectedMsg := "*audit_logs.AuditLogBackgroundService.Run() called multiple times"
panicMsg := fmt.Sprintf("%v", r)
if panicMsg == expectedMsg {
t.Logf("Successfully caught panic: %v", r)
} else {
t.Errorf("Expected panic message '%s', got '%s'", expectedMsg, panicMsg)
}
} else {
t.Error("Expected panic on second Run() call, but did not panic")
}
}()
backgroundService.Run(ctx)
}
// Test_BackupsScheduler_Run_CalledTwice_Panics verifies scheduler panics on duplicate calls
func Test_BackupsScheduler_Run_CalledTwice_Panics(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
scheduler := backuping.GetBackupsScheduler()
// Start first Run() in goroutine
go func() {
scheduler.Run(ctx)
}()
// Give first call time to initialize
time.Sleep(100 * time.Millisecond)
// Second call should panic
defer func() {
if r := recover(); r != nil {
t.Logf("Successfully caught panic: %v", r)
} else {
t.Error("Expected panic on second Run() call, but did not panic")
}
}()
scheduler.Run(ctx)
}
// Test_RestoresScheduler_Run_CalledTwice_Panics verifies restore scheduler panics on duplicate calls
func Test_RestoresScheduler_Run_CalledTwice_Panics(t *testing.T) {
ctx := t.Context()
scheduler := restoring.GetRestoresScheduler()
// Start first Run() in goroutine
go func() {
scheduler.Run(ctx)
}()
// Give first call time to initialize
time.Sleep(100 * time.Millisecond)
// Second call should panic
defer func() {
if r := recover(); r != nil {
t.Logf("Successfully caught panic: %v", r)
} else {
t.Error("Expected panic on second Run() call, but did not panic")
}
}()
scheduler.Run(ctx)
}

View File

@@ -46,4 +46,8 @@ export const restoreApi = {
requestOptions,
);
},
async cancelRestore(restoreId: string) {
return apiHelper.fetchPostRaw(`${getApplicationServer()}/api/v1/restores/cancel/${restoreId}`);
},
};

View File

@@ -2,4 +2,5 @@ export enum RestoreStatus {
IN_PROGRESS = 'IN_PROGRESS',
COMPLETED = 'COMPLETED',
FAILED = 'FAILED',
CANCELED = 'CANCELED',
}

View File

@@ -1,5 +1,10 @@
import { CopyOutlined, ExclamationCircleOutlined, SyncOutlined } from '@ant-design/icons';
import { CheckCircleOutlined } from '@ant-design/icons';
import {
CheckCircleOutlined,
CloseCircleOutlined,
CopyOutlined,
ExclamationCircleOutlined,
SyncOutlined,
} from '@ant-design/icons';
import { App, Button, Modal, Spin, Tooltip } from 'antd';
import dayjs from 'dayjs';
import { useEffect, useRef, useState } from 'react';
@@ -8,6 +13,7 @@ import type { Backup } from '../../../entity/backups';
import { type Database, DatabaseType } from '../../../entity/databases';
import { type Restore, RestoreStatus, restoreApi } from '../../../entity/restores';
import { getUserTimeFormat } from '../../../shared/time';
import { ConfirmationComponent } from '../../../shared/ui';
import { EditDatabaseSpecificDataComponent } from '../../databases/ui/edit/EditDatabaseSpecificDataComponent';
interface Props {
@@ -70,6 +76,10 @@ export const RestoresComponent = ({ database, backup }: Props) => {
const [isShowRestore, setIsShowRestore] = useState(false);
const [cancellingRestoreId, setCancellingRestoreId] = useState<string | undefined>();
const [showCancelConfirmation, setShowCancelConfirmation] = useState(false);
const [restoreToCancelId, setRestoreToCancelId] = useState<string | undefined>();
const isReloadInProgress = useRef(false);
const loadRestores = async () => {
@@ -103,6 +113,18 @@ export const RestoresComponent = ({ database, backup }: Props) => {
}
};
const cancelRestore = async (restoreId: string) => {
setCancellingRestoreId(restoreId);
try {
await restoreApi.cancelRestore(restoreId);
await loadRestores();
} catch (e) {
alert((e as Error).message);
} finally {
setCancellingRestoreId(undefined);
}
};
useEffect(() => {
setIsLoading(true);
loadRestores().finally(() => setIsLoading(false));
@@ -190,40 +212,77 @@ export const RestoresComponent = ({ database, backup }: Props) => {
return (
<div key={restore.id} className="mb-1 rounded border border-gray-200 p-3 text-sm">
<div className="mb-1 flex">
<div className="w-[75px] min-w-[75px]">Status</div>
<div className="mb-1 flex items-center justify-between">
<div className="flex flex-1">
<div className="w-[75px] min-w-[75px]">Status</div>
{restore.status === RestoreStatus.FAILED && (
<Tooltip title="Click to see error details">
<div
className="flex cursor-pointer items-center text-red-600 underline"
onClick={() => setShowingRestoreError(restore)}
>
<ExclamationCircleOutlined
{restore.status === RestoreStatus.FAILED && (
<Tooltip title="Click to see error details">
<div
className="flex cursor-pointer items-center text-red-600 underline"
onClick={() => setShowingRestoreError(restore)}
>
<ExclamationCircleOutlined
className="mr-2"
style={{ fontSize: 16, color: '#ff0000' }}
/>
<div>Failed</div>
</div>
</Tooltip>
)}
{restore.status === RestoreStatus.COMPLETED && (
<div className="flex items-center">
<CheckCircleOutlined
className="mr-2"
style={{ fontSize: 16, color: '#ff0000' }}
style={{ fontSize: 16, color: '#008000' }}
/>
<div>Failed</div>
<div>Successful</div>
</div>
</Tooltip>
)}
)}
{restore.status === RestoreStatus.COMPLETED && (
<div className="flex items-center">
<CheckCircleOutlined
className="mr-2"
style={{ fontSize: 16, color: '#008000' }}
/>
{restore.status === RestoreStatus.CANCELED && (
<div className="flex items-center text-gray-500">
<CloseCircleOutlined
className="mr-2"
style={{ fontSize: 16, color: '#808080' }}
/>
<div>Successful</div>
</div>
)}
<div>Canceled</div>
</div>
)}
{restore.status === RestoreStatus.IN_PROGRESS && (
<div className="flex items-center font-bold text-blue-600">
<SyncOutlined spin />
<span className="ml-2">In progress</span>
</div>
)}
</div>
{restore.status === RestoreStatus.IN_PROGRESS && (
<div className="flex items-center font-bold text-blue-600">
<SyncOutlined spin />
<span className="ml-2">In progress</span>
<div className="ml-2">
{cancellingRestoreId === restore.id ? (
<SyncOutlined spin style={{ fontSize: 16 }} />
) : (
<Tooltip title="Cancel restore">
<CloseCircleOutlined
className="cursor-pointer"
onClick={() => {
if (cancellingRestoreId) return;
setRestoreToCancelId(restore.id);
setShowCancelConfirmation(true);
}}
style={{
color: '#ff0000',
fontSize: 16,
opacity: cancellingRestoreId ? 0.2 : 1,
}}
/>
</Tooltip>
)}
</div>
)}
</div>
@@ -289,6 +348,25 @@ export const RestoresComponent = ({ database, backup }: Props) => {
</div>
</Modal>
)}
{showCancelConfirmation && (
<ConfirmationComponent
onConfirm={() => {
setShowCancelConfirmation(false);
if (restoreToCancelId) {
cancelRestore(restoreToCancelId);
}
setRestoreToCancelId(undefined);
}}
onDecline={() => {
setShowCancelConfirmation(false);
setRestoreToCancelId(undefined);
}}
description="<strong>⚠️ Warning:</strong> Cancelling this restore will likely leave your database in a corrupted or incomplete state. You will need to recreate the database before attempting another restore.<br/><br/>Are you sure you want to cancel?"
actionText="Yes, cancel restore"
actionButtonColor="red"
/>
)}
</div>
);
};