mirror of
https://github.com/databasus/databasus.git
synced 2026-04-06 00:32:03 +02:00
Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dd1072e230 | ||
|
|
a495e5317a | ||
|
|
7eed647038 | ||
|
|
6973241e25 | ||
|
|
ab181f5b81 | ||
|
|
b60a0cc170 | ||
|
|
f319a497b3 | ||
|
|
bc870b3f8e | ||
|
|
15383c59eb | ||
|
|
d14c223a65 | ||
|
|
2c0a294027 | ||
|
|
5d851d73bd | ||
|
|
699913c251 | ||
|
|
a2e3f30a6d | ||
|
|
80f1174ecd |
@@ -27,3 +27,10 @@ repos:
|
||||
language: system
|
||||
files: ^backend/.*\.go$
|
||||
pass_filenames: false
|
||||
|
||||
- id: backend-go-mod-tidy
|
||||
name: Backend Go Mod Tidy
|
||||
entry: bash -c "cd backend && go mod tidy"
|
||||
language: system
|
||||
files: ^backend/.*\.go$
|
||||
pass_filenames: false
|
||||
|
||||
@@ -2,10 +2,10 @@ run:
|
||||
go run cmd/main.go
|
||||
|
||||
test:
|
||||
go test -p=1 -count=1 -failfast -timeout 10m ./internal/...
|
||||
go test -p=1 -count=1 -failfast -timeout 15m ./internal/...
|
||||
|
||||
lint:
|
||||
golangci-lint fmt && golangci-lint run
|
||||
golangci-lint fmt ./cmd/... ./internal/... && golangci-lint run ./cmd/... ./internal/...
|
||||
|
||||
migration-create:
|
||||
goose create $(name) sql
|
||||
|
||||
@@ -15,6 +15,9 @@ import (
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/disk"
|
||||
@@ -54,16 +57,23 @@ func main() {
|
||||
log := logger.GetLogger()
|
||||
|
||||
cache_utils.TestCacheConnection()
|
||||
err := cache_utils.ClearAllCache()
|
||||
if err != nil {
|
||||
log.Error("Failed to clear cache", "error", err)
|
||||
os.Exit(1)
|
||||
|
||||
if config.GetEnv().IsPrimaryNode {
|
||||
err := cache_utils.ClearAllCache()
|
||||
if err != nil {
|
||||
log.Error("Failed to clear cache", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
runMigrations(log)
|
||||
if config.GetEnv().IsPrimaryNode {
|
||||
runMigrations(log)
|
||||
} else {
|
||||
log.Info("Skipping migrations (IS_PRIMARY_NODE is false)")
|
||||
}
|
||||
|
||||
// create directories that used for backups and restore
|
||||
err = files_utils.EnsureDirectories([]string{
|
||||
err := files_utils.EnsureDirectories([]string{
|
||||
config.GetEnv().TempFolder,
|
||||
config.GetEnv().DataFolder,
|
||||
})
|
||||
@@ -104,7 +114,9 @@ func main() {
|
||||
enableCors(ginApp)
|
||||
setUpRoutes(ginApp)
|
||||
setUpDependencies()
|
||||
|
||||
runBackgroundTasks(log)
|
||||
|
||||
mountFrontend(ginApp)
|
||||
|
||||
startServerWithGracefulShutdown(log, ginApp)
|
||||
@@ -227,35 +239,64 @@ func setUpDependencies() {
|
||||
notifiers.SetupDependencies()
|
||||
storages.SetupDependencies()
|
||||
backups_config.SetupDependencies()
|
||||
backups_cancellation.SetupDependencies()
|
||||
}
|
||||
|
||||
func runBackgroundTasks(log *slog.Logger) {
|
||||
log.Info("Preparing to run background tasks...")
|
||||
|
||||
err := files_utils.CleanFolder(config.GetEnv().TempFolder)
|
||||
if err != nil {
|
||||
log.Error("Failed to clean temp folder", "error", err)
|
||||
// Create context that will be cancelled on shutdown
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Set up signal handling for graceful shutdown
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-quit
|
||||
log.Info("Shutdown signal received, cancelling all background tasks")
|
||||
cancel()
|
||||
}()
|
||||
|
||||
if config.GetEnv().IsPrimaryNode {
|
||||
log.Info("Starting primary node background tasks...")
|
||||
|
||||
err := files_utils.CleanFolder(config.GetEnv().TempFolder)
|
||||
if err != nil {
|
||||
log.Error("Failed to clean temp folder", "error", err)
|
||||
}
|
||||
|
||||
go runWithPanicLogging(log, "backup background service", func() {
|
||||
backuping.GetBackupsScheduler().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "restore background service", func() {
|
||||
restores.GetRestoreBackgroundService().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "healthcheck attempt background service", func() {
|
||||
healthcheck_attempt.GetHealthcheckAttemptBackgroundService().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "audit log cleanup background service", func() {
|
||||
audit_logs.GetAuditLogBackgroundService().Run(ctx)
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "download token cleanup background service", func() {
|
||||
backups_download.GetDownloadTokenBackgroundService().Run(ctx)
|
||||
})
|
||||
} else {
|
||||
log.Info("Skipping primary node tasks as not primary node")
|
||||
}
|
||||
|
||||
go runWithPanicLogging(log, "backup background service", func() {
|
||||
backups.GetBackupBackgroundService().Run()
|
||||
})
|
||||
if config.GetEnv().IsBackupNode {
|
||||
log.Info("Starting backup node background tasks...")
|
||||
|
||||
go runWithPanicLogging(log, "restore background service", func() {
|
||||
restores.GetRestoreBackgroundService().Run()
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "healthcheck attempt background service", func() {
|
||||
healthcheck_attempt.GetHealthcheckAttemptBackgroundService().Run()
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "audit log cleanup background service", func() {
|
||||
audit_logs.GetAuditLogBackgroundService().Run()
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "download token cleanup background service", func() {
|
||||
backups.GetDownloadTokenBackgroundService().Run()
|
||||
})
|
||||
go runWithPanicLogging(log, "backup node", func() {
|
||||
backuping.GetBackuperNode().Run(ctx)
|
||||
})
|
||||
} else {
|
||||
log.Info("Skipping backup node tasks as not backup node")
|
||||
}
|
||||
}
|
||||
|
||||
func runWithPanicLogging(log *slog.Logger, serviceName string, fn func()) {
|
||||
|
||||
@@ -28,7 +28,6 @@ require (
|
||||
github.com/valkey-io/valkey-go v1.0.70
|
||||
go.mongodb.org/mongo-driver v1.17.6
|
||||
golang.org/x/crypto v0.46.0
|
||||
golang.org/x/time v0.14.0
|
||||
gorm.io/driver/postgres v1.5.11
|
||||
gorm.io/gorm v1.26.1
|
||||
)
|
||||
@@ -186,6 +185,7 @@ require (
|
||||
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
golang.org/x/term v0.38.0 // indirect
|
||||
golang.org/x/time v0.14.0 // indirect
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
|
||||
gopkg.in/validator.v2 v2.0.1 // indirect
|
||||
moul.io/http2curl/v2 v2.3.0 // indirect
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/ilyakaznacheev/cleanenv"
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
@@ -29,6 +30,12 @@ type EnvVariables struct {
|
||||
MariadbInstallDir string `env:"MARIADB_INSTALL_DIR"`
|
||||
MongodbInstallDir string `env:"MONGODB_INSTALL_DIR"`
|
||||
|
||||
NodeID string
|
||||
IsManyNodesMode bool `env:"IS_MANY_NODES_MODE"`
|
||||
IsPrimaryNode bool `env:"IS_PRIMARY_NODE"`
|
||||
IsBackupNode bool `env:"IS_BACKUP_NODE"`
|
||||
NodeNetworkThroughputMBs int `env:"NODE_NETWORK_THROUGHPUT_MBPS"`
|
||||
|
||||
DataFolder string
|
||||
TempFolder string
|
||||
SecretKeyPath string
|
||||
@@ -196,6 +203,16 @@ func loadEnvVariables() {
|
||||
env.MongodbInstallDir = filepath.Join(backendRoot, "tools", "mongodb")
|
||||
tools.VerifyMongodbInstallation(log, env.EnvMode, env.MongodbInstallDir)
|
||||
|
||||
env.NodeID = uuid.New().String()
|
||||
if env.NodeNetworkThroughputMBs == 0 {
|
||||
env.NodeNetworkThroughputMBs = 125 // 1 Gbit/s
|
||||
}
|
||||
|
||||
if !env.IsManyNodesMode {
|
||||
env.IsPrimaryNode = true
|
||||
env.IsBackupNode = true
|
||||
}
|
||||
|
||||
// Valkey
|
||||
if env.ValkeyHost == "" {
|
||||
log.Error("VALKEY_HOST is empty")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
)
|
||||
@@ -11,23 +11,25 @@ type AuditLogBackgroundService struct {
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *AuditLogBackgroundService) Run() {
|
||||
func (s *AuditLogBackgroundService) Run(ctx context.Context) {
|
||||
s.logger.Info("Starting audit log cleanup background service")
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
if config.IsShouldShutdown() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.cleanOldAuditLogs(); err != nil {
|
||||
s.logger.Error("Failed to clean old audit logs", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.cleanOldAuditLogs(); err != nil {
|
||||
s.logger.Error("Failed to clean old audit logs", "error", err)
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,254 +0,0 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/storages"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/period"
|
||||
"log/slog"
|
||||
"time"
|
||||
)
|
||||
|
||||
type BackupBackgroundService struct {
|
||||
backupService *BackupService
|
||||
backupRepository *BackupRepository
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
storageService *storages.StorageService
|
||||
|
||||
lastBackupTime time.Time
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) Run() {
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
|
||||
if err := s.failBackupsInProgress(); err != nil {
|
||||
s.logger.Error("Failed to fail backups in progress", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
if config.IsShouldShutdown() {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.cleanOldBackups(); err != nil {
|
||||
s.logger.Error("Failed to clean old backups", "error", err)
|
||||
}
|
||||
|
||||
if err := s.runPendingBackups(); err != nil {
|
||||
s.logger.Error("Failed to run pending backups", "error", err)
|
||||
}
|
||||
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
time.Sleep(1 * time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) IsBackupsWorkerRunning() bool {
|
||||
// if last backup time is more than 5 minutes ago, return false
|
||||
return s.lastBackupTime.After(time.Now().UTC().Add(-5 * time.Minute))
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) failBackupsInProgress() error {
|
||||
backupsInProgress, err := s.backupRepository.FindByStatus(BackupStatusInProgress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backup := range backupsInProgress {
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(backup.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
failMessage := "Backup failed due to application restart"
|
||||
backup.FailMessage = &failMessage
|
||||
backup.Status = BackupStatusFailed
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
s.backupService.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupFailed,
|
||||
&failMessage,
|
||||
)
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) cleanOldBackups() error {
|
||||
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
backupStorePeriod := backupConfig.StorePeriod
|
||||
|
||||
if backupStorePeriod == period.PeriodForever {
|
||||
continue
|
||||
}
|
||||
|
||||
storeDuration := backupStorePeriod.ToDuration()
|
||||
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
|
||||
|
||||
oldBackups, err := s.backupRepository.FindBackupsBeforeDate(
|
||||
backupConfig.DatabaseID,
|
||||
dateBeforeBackupsShouldBeDeleted,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to find old backups for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, backup := range oldBackups {
|
||||
storage, err := s.storageService.GetStorageByID(backup.StorageID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to get storage by ID",
|
||||
"storageId",
|
||||
backup.StorageID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
err = storage.DeleteFile(encryptor, backup.ID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
|
||||
}
|
||||
|
||||
if err := s.backupRepository.DeleteByID(backup.ID); err != nil {
|
||||
s.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
s.logger.Info(
|
||||
"Deleted old backup",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupBackgroundService) runPendingBackups() error {
|
||||
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
if backupConfig.BackupInterval == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
lastBackup, err := s.backupRepository.FindLastByDatabaseID(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to get last backup for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
var lastBackupTime *time.Time
|
||||
if lastBackup != nil {
|
||||
lastBackupTime = &lastBackup.CreatedAt
|
||||
}
|
||||
|
||||
remainedBackupTryCount := s.GetRemainedBackupTryCount(lastBackup)
|
||||
|
||||
if backupConfig.BackupInterval.ShouldTriggerBackup(time.Now().UTC(), lastBackupTime) ||
|
||||
remainedBackupTryCount > 0 {
|
||||
s.logger.Info(
|
||||
"Triggering scheduled backup",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"intervalType",
|
||||
backupConfig.BackupInterval.Interval,
|
||||
)
|
||||
|
||||
go s.backupService.MakeBackup(backupConfig.DatabaseID, remainedBackupTryCount == 1)
|
||||
s.logger.Info(
|
||||
"Successfully triggered scheduled backup",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRemainedBackupTryCount returns the number of remaining backup tries for a given backup.
|
||||
// If the backup is not failed or the backup config does not allow retries, it returns 0.
|
||||
// If the backup is failed and the backup config allows retries, it returns the number of remaining tries.
|
||||
// If the backup is failed and the backup config does not allow retries, it returns 0.
|
||||
func (s *BackupBackgroundService) GetRemainedBackupTryCount(lastBackup *Backup) int {
|
||||
if lastBackup == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
if lastBackup.Status != BackupStatusFailed {
|
||||
return 0
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(lastBackup.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
if !backupConfig.IsRetryIfFailed {
|
||||
return 0
|
||||
}
|
||||
|
||||
maxFailedTriesCount := backupConfig.MaxFailedTriesCount
|
||||
|
||||
lastBackups, err := s.backupRepository.FindByDatabaseIDWithLimit(
|
||||
lastBackup.DatabaseID,
|
||||
maxFailedTriesCount,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to find last backups by database ID", "error", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
lastFailedBackups := make([]*Backup, 0)
|
||||
|
||||
for _, backup := range lastBackups {
|
||||
if backup.Status == BackupStatusFailed {
|
||||
lastFailedBackups = append(lastFailedBackups, backup)
|
||||
}
|
||||
}
|
||||
|
||||
return maxFailedTriesCount - len(lastFailedBackups)
|
||||
}
|
||||
344
backend/internal/features/backups/backups/backuping/backuper.go
Normal file
344
backend/internal/features/backups/backups/backuping/backuper.go
Normal file
@@ -0,0 +1,344 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"databasus-backend/internal/config"
|
||||
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/storages"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type BackuperNode struct {
|
||||
databaseService *databases.DatabaseService
|
||||
fieldEncryptor util_encryption.FieldEncryptor
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
backupRepository *backups_core.BackupRepository
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
storageService *storages.StorageService
|
||||
notificationSender backups_core.NotificationSender
|
||||
backupCancelManager *backups_cancellation.BackupCancelManager
|
||||
nodesRegistry *BackupNodesRegistry
|
||||
logger *slog.Logger
|
||||
createBackupUseCase backups_core.CreateBackupUsecase
|
||||
nodeID uuid.UUID
|
||||
|
||||
lastHeartbeat time.Time
|
||||
}
|
||||
|
||||
func (n *BackuperNode) Run(ctx context.Context) {
|
||||
n.lastHeartbeat = time.Now().UTC()
|
||||
|
||||
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
|
||||
|
||||
backupNode := BackupNode{
|
||||
ID: n.nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: time.Now().UTC(),
|
||||
}
|
||||
|
||||
if err := n.nodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
|
||||
n.logger.Error("Failed to register node in registry", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
|
||||
n.MakeBackup(backupID, isCallNotifier)
|
||||
if err := n.nodesRegistry.PublishBackupCompletion(n.nodeID.String(), backupID); err != nil {
|
||||
n.logger.Error(
|
||||
"Failed to publish backup completion",
|
||||
"error",
|
||||
err,
|
||||
"backupID",
|
||||
backupID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if err := n.nodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID.String(), backupHandler); err != nil {
|
||||
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := n.nodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
|
||||
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ticker := time.NewTicker(15 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
|
||||
|
||||
if err := n.nodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
|
||||
n.logger.Error("Failed to unregister node from registry", "error", err)
|
||||
}
|
||||
|
||||
return
|
||||
case <-ticker.C:
|
||||
n.sendHeartbeat(&backupNode)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (n *BackuperNode) IsBackuperRunning() bool {
|
||||
return n.lastHeartbeat.After(time.Now().UTC().Add(-5 * time.Minute))
|
||||
}
|
||||
|
||||
func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
|
||||
backup, err := n.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get backup by ID", "backupId", backupID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
databaseID := backup.DatabaseID
|
||||
|
||||
database, err := n.databaseService.GetDatabaseByID(databaseID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get database by ID", "databaseId", databaseID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
backupConfig, err := n.backupConfigService.GetBackupConfigByDbId(databaseID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if backupConfig.StorageID == nil {
|
||||
n.logger.Error("Backup config storage ID is not defined")
|
||||
return
|
||||
}
|
||||
|
||||
storage, err := n.storageService.GetStorageByID(*backupConfig.StorageID)
|
||||
if err != nil {
|
||||
n.logger.Error("Failed to get storage by ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now().UTC()
|
||||
|
||||
backupProgressListener := func(
|
||||
completedMBs float64,
|
||||
) {
|
||||
backup.BackupSizeMb = completedMBs
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to update backup progress", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
n.backupCancelManager.RegisterBackup(backup.ID, cancel)
|
||||
defer n.backupCancelManager.UnregisterBackup(backup.ID)
|
||||
|
||||
backupMetadata, err := n.createBackupUseCase.Execute(
|
||||
ctx,
|
||||
backup.ID,
|
||||
backupConfig,
|
||||
database,
|
||||
storage,
|
||||
backupProgressListener,
|
||||
)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
|
||||
// Check if backup was cancelled (not due to shutdown)
|
||||
isCancelled := strings.Contains(errMsg, "backup cancelled") ||
|
||||
strings.Contains(errMsg, "context canceled") ||
|
||||
errors.Is(err, context.Canceled)
|
||||
isShutdown := strings.Contains(errMsg, "shutdown")
|
||||
|
||||
if isCancelled && !isShutdown {
|
||||
backup.Status = backups_core.BackupStatusCanceled
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save cancelled backup", "error", err)
|
||||
}
|
||||
|
||||
// Delete partial backup from storage
|
||||
storage, storageErr := n.storageService.GetStorageByID(backup.StorageID)
|
||||
if storageErr == nil {
|
||||
if deleteErr := storage.DeleteFile(n.fieldEncryptor, backup.ID); deleteErr != nil {
|
||||
n.logger.Error(
|
||||
"Failed to delete partial backup file",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
deleteErr,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
backup.FailMessage = &errMsg
|
||||
backup.Status = backups_core.BackupStatusFailed
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
if updateErr := n.databaseService.SetBackupError(databaseID, errMsg); updateErr != nil {
|
||||
n.logger.Error(
|
||||
"Failed to update database last backup time",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"error",
|
||||
updateErr,
|
||||
)
|
||||
}
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save backup", "error", err)
|
||||
}
|
||||
|
||||
n.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupFailed,
|
||||
&errMsg,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
backup.Status = backups_core.BackupStatusCompleted
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
// Update backup with encryption metadata if provided
|
||||
if backupMetadata != nil {
|
||||
backup.EncryptionSalt = backupMetadata.EncryptionSalt
|
||||
backup.EncryptionIV = backupMetadata.EncryptionIV
|
||||
backup.Encryption = backupMetadata.Encryption
|
||||
}
|
||||
|
||||
if err := n.backupRepository.Save(backup); err != nil {
|
||||
n.logger.Error("Failed to save backup", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Update database last backup time
|
||||
now := time.Now().UTC()
|
||||
if updateErr := n.databaseService.SetLastBackupTime(databaseID, now); updateErr != nil {
|
||||
n.logger.Error(
|
||||
"Failed to update database last backup time",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"error",
|
||||
updateErr,
|
||||
)
|
||||
}
|
||||
|
||||
if backup.Status != backups_core.BackupStatusCompleted && !isCallNotifier {
|
||||
return
|
||||
}
|
||||
|
||||
n.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupSuccess,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
func (n *BackuperNode) SendBackupNotification(
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
backup *backups_core.Backup,
|
||||
notificationType backups_config.BackupNotificationType,
|
||||
errorMessage *string,
|
||||
) {
|
||||
database, err := n.databaseService.GetDatabaseByID(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
workspace, err := n.workspaceService.GetWorkspaceByID(*database.WorkspaceID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, notifier := range database.Notifiers {
|
||||
if !slices.Contains(
|
||||
backupConfig.SendNotificationsOn,
|
||||
notificationType,
|
||||
) {
|
||||
continue
|
||||
}
|
||||
|
||||
title := ""
|
||||
switch notificationType {
|
||||
case backups_config.NotificationBackupFailed:
|
||||
title = fmt.Sprintf(
|
||||
"❌ Backup failed for database \"%s\" (workspace \"%s\")",
|
||||
database.Name,
|
||||
workspace.Name,
|
||||
)
|
||||
case backups_config.NotificationBackupSuccess:
|
||||
title = fmt.Sprintf(
|
||||
"✅ Backup completed for database \"%s\" (workspace \"%s\")",
|
||||
database.Name,
|
||||
workspace.Name,
|
||||
)
|
||||
}
|
||||
|
||||
message := ""
|
||||
if errorMessage != nil {
|
||||
message = *errorMessage
|
||||
} else {
|
||||
// Format size conditionally
|
||||
var sizeStr string
|
||||
if backup.BackupSizeMb < 1024 {
|
||||
sizeStr = fmt.Sprintf("%.2f MB", backup.BackupSizeMb)
|
||||
} else {
|
||||
sizeGB := backup.BackupSizeMb / 1024
|
||||
sizeStr = fmt.Sprintf("%.2f GB", sizeGB)
|
||||
}
|
||||
|
||||
// Format duration as "0m 0s 0ms"
|
||||
totalMs := backup.BackupDurationMs
|
||||
minutes := totalMs / (1000 * 60)
|
||||
seconds := (totalMs % (1000 * 60)) / 1000
|
||||
durationStr := fmt.Sprintf("%dm %ds", minutes, seconds)
|
||||
|
||||
message = fmt.Sprintf(
|
||||
"Backup completed successfully in %s.\nCompressed backup size: %s",
|
||||
durationStr,
|
||||
sizeStr,
|
||||
)
|
||||
}
|
||||
|
||||
n.notificationSender.SendNotification(
|
||||
¬ifier,
|
||||
title,
|
||||
message,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (n *BackuperNode) sendHeartbeat(backupNode *BackupNode) {
|
||||
n.lastHeartbeat = time.Now().UTC()
|
||||
backupNode.LastHeartbeat = time.Now().UTC()
|
||||
if err := n.nodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *backupNode); err != nil {
|
||||
n.logger.Error("Failed to send heartbeat", "error", err)
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package backups
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -8,17 +8,15 @@ import (
|
||||
"time"
|
||||
|
||||
common "databasus-backend/internal/features/backups/backups/common"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -26,6 +24,7 @@ import (
|
||||
)
|
||||
|
||||
func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
@@ -50,23 +49,19 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
|
||||
t.Run("BackupFailed_FailNotificationSent", func(t *testing.T) {
|
||||
mockNotificationSender := &MockNotificationSender{}
|
||||
backupService := &BackupService{
|
||||
databases.GetDatabaseService(),
|
||||
storages.GetStorageService(),
|
||||
backupRepository,
|
||||
notifiers.GetNotifierService(),
|
||||
mockNotificationSender,
|
||||
backups_config.GetBackupConfigService(),
|
||||
encryption_secrets.GetSecretKeyService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
&CreateFailedBackupUsecase{},
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
nil,
|
||||
NewBackupContextManager(),
|
||||
nil,
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.notificationSender = mockNotificationSender
|
||||
backuperNode.createBackupUseCase = &CreateFailedBackupUsecase{}
|
||||
|
||||
// Create a backup record directly that will be looked up by MakeBackup
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err := backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Set up expectations
|
||||
mockNotificationSender.On("SendNotification",
|
||||
@@ -79,7 +74,7 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
}),
|
||||
).Once()
|
||||
|
||||
backupService.MakeBackup(database.ID, true)
|
||||
backuperNode.MakeBackup(backup.ID, true)
|
||||
|
||||
// Verify all expectations were met
|
||||
mockNotificationSender.AssertExpectations(t)
|
||||
@@ -87,6 +82,19 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
|
||||
t.Run("BackupSuccess_SuccessNotificationSent", func(t *testing.T) {
|
||||
mockNotificationSender := &MockNotificationSender{}
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.notificationSender = mockNotificationSender
|
||||
backuperNode.createBackupUseCase = &CreateSuccessBackupUsecase{}
|
||||
|
||||
// Create a backup record directly that will be looked up by MakeBackup
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err := backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Set up expectations
|
||||
mockNotificationSender.On("SendNotification",
|
||||
@@ -99,25 +107,7 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
}),
|
||||
).Once()
|
||||
|
||||
backupService := &BackupService{
|
||||
databases.GetDatabaseService(),
|
||||
storages.GetStorageService(),
|
||||
backupRepository,
|
||||
notifiers.GetNotifierService(),
|
||||
mockNotificationSender,
|
||||
backups_config.GetBackupConfigService(),
|
||||
encryption_secrets.GetSecretKeyService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
&CreateSuccessBackupUsecase{},
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
nil,
|
||||
NewBackupContextManager(),
|
||||
nil,
|
||||
}
|
||||
|
||||
backupService.MakeBackup(database.ID, true)
|
||||
backuperNode.MakeBackup(backup.ID, true)
|
||||
|
||||
// Verify all expectations were met
|
||||
mockNotificationSender.AssertExpectations(t)
|
||||
@@ -125,23 +115,19 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
|
||||
t.Run("BackupSuccess_VerifyNotificationContent", func(t *testing.T) {
|
||||
mockNotificationSender := &MockNotificationSender{}
|
||||
backupService := &BackupService{
|
||||
databases.GetDatabaseService(),
|
||||
storages.GetStorageService(),
|
||||
backupRepository,
|
||||
notifiers.GetNotifierService(),
|
||||
mockNotificationSender,
|
||||
backups_config.GetBackupConfigService(),
|
||||
encryption_secrets.GetSecretKeyService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
&CreateSuccessBackupUsecase{},
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
nil,
|
||||
NewBackupContextManager(),
|
||||
nil,
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
backuperNode.notificationSender = mockNotificationSender
|
||||
backuperNode.createBackupUseCase = &CreateSuccessBackupUsecase{}
|
||||
|
||||
// Create a backup record directly that will be looked up by MakeBackup
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err := backupRepository.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// capture arguments
|
||||
var capturedNotifier *notifiers.Notifier
|
||||
@@ -158,7 +144,7 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
capturedMessage = args.Get(2).(string)
|
||||
}).Once()
|
||||
|
||||
backupService.MakeBackup(database.ID, true)
|
||||
backuperNode.MakeBackup(backup.ID, true)
|
||||
|
||||
// Verify expectations were met
|
||||
mockNotificationSender.AssertExpectations(t)
|
||||
77
backend/internal/features/backups/backups/backuping/di.go
Normal file
77
backend/internal/features/backups/backups/backuping/di.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/usecases"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var backupRepository = &backups_core.BackupRepository{}
|
||||
|
||||
var backupCancelManager = backups_cancellation.GetBackupCancelManager()
|
||||
|
||||
var nodesRegistry = &BackupNodesRegistry{
|
||||
cache_utils.GetValkeyClient(),
|
||||
logger.GetLogger(),
|
||||
cache_utils.DefaultCacheTimeout,
|
||||
cache_utils.NewPubSubManager(),
|
||||
cache_utils.NewPubSubManager(),
|
||||
}
|
||||
|
||||
func getNodeID() uuid.UUID {
|
||||
nodeIDStr := config.GetEnv().NodeID
|
||||
nodeID, err := uuid.Parse(nodeIDStr)
|
||||
if err != nil {
|
||||
logger.GetLogger().Error("Failed to parse node ID from config", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
return nodeID
|
||||
}
|
||||
|
||||
var backuperNode = &BackuperNode{
|
||||
databases.GetDatabaseService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
notifiers.GetNotifierService(),
|
||||
backupCancelManager,
|
||||
nodesRegistry,
|
||||
logger.GetLogger(),
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
getNodeID(),
|
||||
time.Time{},
|
||||
}
|
||||
|
||||
var backupsScheduler = &BackupsScheduler{
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
backupCancelManager,
|
||||
nodesRegistry,
|
||||
time.Now().UTC(),
|
||||
logger.GetLogger(),
|
||||
make(map[uuid.UUID]BackupToNodeRelation),
|
||||
backuperNode,
|
||||
}
|
||||
|
||||
func GetBackupsScheduler() *BackupsScheduler {
|
||||
return backupsScheduler
|
||||
}
|
||||
|
||||
func GetBackuperNode() *BackuperNode {
|
||||
return backuperNode
|
||||
}
|
||||
34
backend/internal/features/backups/backups/backuping/dto.go
Normal file
34
backend/internal/features/backups/backups/backuping/dto.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type BackupNode struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
ThroughputMBs int `json:"throughputMBs"`
|
||||
LastHeartbeat time.Time `json:"lastHeartbeat"`
|
||||
}
|
||||
|
||||
type BackupNodeStats struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
ActiveBackups int `json:"activeBackups"`
|
||||
}
|
||||
|
||||
type BackupSubmitMessage struct {
|
||||
NodeID string `json:"nodeId"`
|
||||
BackupID string `json:"backupId"`
|
||||
IsCallNotifier bool `json:"isCallNotifier"`
|
||||
}
|
||||
|
||||
type BackupCompletionMessage struct {
|
||||
NodeID string `json:"nodeId"`
|
||||
BackupID string `json:"backupId"`
|
||||
}
|
||||
|
||||
type BackupToNodeRelation struct {
|
||||
NodeID uuid.UUID `json:"nodeId"`
|
||||
BackupsIDs []uuid.UUID `json:"backupsIds"`
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package backups
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
448
backend/internal/features/backups/backups/backuping/registry.go
Normal file
448
backend/internal/features/backups/backups/backuping/registry.go
Normal file
@@ -0,0 +1,448 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/valkey-io/valkey-go"
|
||||
)
|
||||
|
||||
const (
|
||||
nodeInfoKeyPrefix = "node:"
|
||||
nodeInfoKeySuffix = ":info"
|
||||
nodeActiveBackupsPrefix = "node:"
|
||||
nodeActiveBackupsSuffix = ":active_backups"
|
||||
backupSubmitChannel = "backup:submit"
|
||||
backupCompletionChannel = "backup:completion"
|
||||
)
|
||||
|
||||
// BackupNodesRegistry helps to sync backups scheduler and backup nodes.
|
||||
// Features:
|
||||
// - Track node availability and load level
|
||||
// - Assign from scheduler to node backups needed to be processed
|
||||
// - Notify scheduler from node about backup completion
|
||||
type BackupNodesRegistry struct {
|
||||
client valkey.Client
|
||||
logger *slog.Logger
|
||||
timeout time.Duration
|
||||
pubsubBackups *cache_utils.PubSubManager
|
||||
pubsubCompletions *cache_utils.PubSubManager
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) GetAvailableNodes() ([]BackupNode, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
var allKeys []string
|
||||
cursor := uint64(0)
|
||||
pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix
|
||||
|
||||
for {
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return nil, fmt.Errorf("failed to scan node keys: %w", result.Error())
|
||||
}
|
||||
|
||||
scanResult, err := result.AsScanEntry()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse scan result: %w", err)
|
||||
}
|
||||
|
||||
allKeys = append(allKeys, scanResult.Elements...)
|
||||
|
||||
cursor = scanResult.Cursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(allKeys) == 0 {
|
||||
return []BackupNode{}, nil
|
||||
}
|
||||
|
||||
keyDataMap, err := r.pipelineGetKeys(allKeys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pipeline get node keys: %w", err)
|
||||
}
|
||||
|
||||
var nodes []BackupNode
|
||||
for key, data := range keyDataMap {
|
||||
var node BackupNode
|
||||
if err := json.Unmarshal(data, &node); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal node data", "key", key, "error", err)
|
||||
continue
|
||||
}
|
||||
nodes = append(nodes, node)
|
||||
}
|
||||
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) GetBackupNodesStats() ([]BackupNodeStats, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
var allKeys []string
|
||||
cursor := uint64(0)
|
||||
pattern := nodeActiveBackupsPrefix + "*" + nodeActiveBackupsSuffix
|
||||
|
||||
for {
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(100).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return nil, fmt.Errorf("failed to scan active backups keys: %w", result.Error())
|
||||
}
|
||||
|
||||
scanResult, err := result.AsScanEntry()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse scan result: %w", err)
|
||||
}
|
||||
|
||||
allKeys = append(allKeys, scanResult.Elements...)
|
||||
|
||||
cursor = scanResult.Cursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(allKeys) == 0 {
|
||||
return []BackupNodeStats{}, nil
|
||||
}
|
||||
|
||||
keyDataMap, err := r.pipelineGetKeys(allKeys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pipeline get active backups keys: %w", err)
|
||||
}
|
||||
|
||||
var stats []BackupNodeStats
|
||||
for key, data := range keyDataMap {
|
||||
nodeID := r.extractNodeIDFromKey(key, nodeActiveBackupsPrefix, nodeActiveBackupsSuffix)
|
||||
|
||||
count, err := r.parseIntFromBytes(data)
|
||||
if err != nil {
|
||||
r.logger.Warn("Failed to parse active backups count", "key", key, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
stat := BackupNodeStats{
|
||||
ID: nodeID,
|
||||
ActiveBackups: int(count),
|
||||
}
|
||||
stats = append(stats, stat)
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) IncrementBackupsInProgress(nodeID string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID, nodeActiveBackupsSuffix)
|
||||
result := r.client.Do(ctx, r.client.B().Incr().Key(key).Build())
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to increment backups in progress for node %s: %w",
|
||||
nodeID,
|
||||
result.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) DecrementBackupsInProgress(nodeID string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID, nodeActiveBackupsSuffix)
|
||||
result := r.client.Do(ctx, r.client.B().Decr().Key(key).Build())
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to decrement backups in progress for node %s: %w",
|
||||
nodeID,
|
||||
result.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
newValue, err := result.AsInt64()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse decremented value for node %s: %w", nodeID, err)
|
||||
}
|
||||
|
||||
if newValue < 0 {
|
||||
setCtx, setCancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
r.client.Do(setCtx, r.client.B().Set().Key(key).Value("0").Build())
|
||||
setCancel()
|
||||
r.logger.Warn("Active backups counter went below 0, reset to 0", "nodeID", nodeID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) HearthbeatNodeInRegistry(now time.Time, backupNode BackupNode) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
backupNode.LastHeartbeat = now
|
||||
|
||||
data, err := json.Marshal(backupNode)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal backup node: %w", err)
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, backupNode.ID.String(), nodeInfoKeySuffix)
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Set().Key(key).Value(string(data)).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf("failed to register node %s: %w", backupNode.ID, result.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) UnregisterNodeFromRegistry(backupNode BackupNode) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, backupNode.ID.String(), nodeInfoKeySuffix)
|
||||
counterKey := fmt.Sprintf(
|
||||
"%s%s%s",
|
||||
nodeActiveBackupsPrefix,
|
||||
backupNode.ID.String(),
|
||||
nodeActiveBackupsSuffix,
|
||||
)
|
||||
|
||||
result := r.client.Do(
|
||||
ctx,
|
||||
r.client.B().Del().Key(infoKey, counterKey).Build(),
|
||||
)
|
||||
|
||||
if result.Error() != nil {
|
||||
return fmt.Errorf("failed to unregister node %s: %w", backupNode.ID, result.Error())
|
||||
}
|
||||
|
||||
r.logger.Info("Unregistered node from registry", "nodeID", backupNode.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) AssignBackupToNode(
|
||||
targetNodeID string,
|
||||
backupID uuid.UUID,
|
||||
isCallNotifier bool,
|
||||
) error {
|
||||
ctx := context.Background()
|
||||
|
||||
message := BackupSubmitMessage{
|
||||
NodeID: targetNodeID,
|
||||
BackupID: backupID.String(),
|
||||
IsCallNotifier: isCallNotifier,
|
||||
}
|
||||
|
||||
messageJSON, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal backup submit message: %w", err)
|
||||
}
|
||||
|
||||
err = r.pubsubBackups.Publish(ctx, backupSubmitChannel, string(messageJSON))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to publish backup submit message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) SubscribeNodeForBackupsAssignment(
|
||||
nodeID string,
|
||||
handler func(backupID uuid.UUID, isCallNotifier bool),
|
||||
) error {
|
||||
ctx := context.Background()
|
||||
|
||||
wrappedHandler := func(message string) {
|
||||
var msg BackupSubmitMessage
|
||||
if err := json.Unmarshal([]byte(message), &msg); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal backup submit message", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if msg.NodeID != nodeID {
|
||||
return
|
||||
}
|
||||
|
||||
backupID, err := uuid.Parse(msg.BackupID)
|
||||
if err != nil {
|
||||
r.logger.Warn(
|
||||
"Failed to parse backup ID from message",
|
||||
"backupId",
|
||||
msg.BackupID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
handler(backupID, msg.IsCallNotifier)
|
||||
}
|
||||
|
||||
err := r.pubsubBackups.Subscribe(ctx, backupSubmitChannel, wrappedHandler)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to subscribe to backup submit channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Subscribed to backup submit channel", "nodeID", nodeID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) UnsubscribeNodeForBackupsAssignments() error {
|
||||
err := r.pubsubBackups.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unsubscribe from backup submit channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Unsubscribed from backup submit channel")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) PublishBackupCompletion(nodeID string, backupID uuid.UUID) error {
|
||||
ctx := context.Background()
|
||||
|
||||
message := BackupCompletionMessage{
|
||||
NodeID: nodeID,
|
||||
BackupID: backupID.String(),
|
||||
}
|
||||
|
||||
messageJSON, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal backup completion message: %w", err)
|
||||
}
|
||||
|
||||
err = r.pubsubCompletions.Publish(ctx, backupCompletionChannel, string(messageJSON))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to publish backup completion message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) SubscribeForBackupsCompletions(
|
||||
handler func(nodeID string, backupID uuid.UUID),
|
||||
) error {
|
||||
ctx := context.Background()
|
||||
|
||||
wrappedHandler := func(message string) {
|
||||
var msg BackupCompletionMessage
|
||||
if err := json.Unmarshal([]byte(message), &msg); err != nil {
|
||||
r.logger.Warn("Failed to unmarshal backup completion message", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
backupID, err := uuid.Parse(msg.BackupID)
|
||||
if err != nil {
|
||||
r.logger.Warn(
|
||||
"Failed to parse backup ID from completion message",
|
||||
"backupId",
|
||||
msg.BackupID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
handler(msg.NodeID, backupID)
|
||||
}
|
||||
|
||||
err := r.pubsubCompletions.Subscribe(ctx, backupCompletionChannel, wrappedHandler)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to subscribe to backup completion channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Subscribed to backup completion channel")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) UnsubscribeForBackupsCompletions() error {
|
||||
err := r.pubsubCompletions.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unsubscribe from backup completion channel: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Unsubscribed from backup completion channel")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) extractNodeIDFromKey(key, prefix, suffix string) uuid.UUID {
|
||||
nodeIDStr := strings.TrimPrefix(key, prefix)
|
||||
nodeIDStr = strings.TrimSuffix(nodeIDStr, suffix)
|
||||
|
||||
nodeID, err := uuid.Parse(nodeIDStr)
|
||||
if err != nil {
|
||||
r.logger.Warn("Failed to parse node ID from key", "key", key, "error", err)
|
||||
return uuid.Nil
|
||||
}
|
||||
|
||||
return nodeID
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) pipelineGetKeys(keys []string) (map[string][]byte, error) {
|
||||
if len(keys) == 0 {
|
||||
return make(map[string][]byte), nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
commands := make([]valkey.Completed, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
commands = append(commands, r.client.B().Get().Key(key).Build())
|
||||
}
|
||||
|
||||
results := r.client.DoMulti(ctx, commands...)
|
||||
|
||||
keyDataMap := make(map[string][]byte, len(keys))
|
||||
for i, result := range results {
|
||||
if result.Error() != nil {
|
||||
r.logger.Warn("Failed to get key in pipeline", "key", keys[i], "error", result.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
data, err := result.AsBytes()
|
||||
if err != nil {
|
||||
r.logger.Warn("Failed to parse key data in pipeline", "key", keys[i], "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
keyDataMap[keys[i]] = data
|
||||
}
|
||||
|
||||
return keyDataMap, nil
|
||||
}
|
||||
|
||||
func (r *BackupNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
|
||||
str := string(data)
|
||||
var count int64
|
||||
_, err := fmt.Sscanf(str, "%d", &count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to parse integer from bytes: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
@@ -0,0 +1,904 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/logger"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_HearthbeatNodeInRegistry_RegistersNodeWithTTL(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
defer cleanupTestNode(registry, node)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
nodes, err := registry.GetAvailableNodes()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, node.ID, nodes[0].ID)
|
||||
assert.Equal(t, node.ThroughputMBs, nodes[0].ThroughputMBs)
|
||||
}
|
||||
|
||||
func Test_UnregisterNodeFromRegistry_RemovesNodeAndCounter(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.UnregisterNodeFromRegistry(node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
nodes, err := registry.GetAvailableNodes()
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, nodes)
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, stats)
|
||||
}
|
||||
|
||||
func Test_GetAvailableNodes_ReturnsAllRegisteredNodes(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node1 := createTestBackupNode()
|
||||
node2 := createTestBackupNode()
|
||||
node3 := createTestBackupNode()
|
||||
defer cleanupTestNode(registry, node1)
|
||||
defer cleanupTestNode(registry, node2)
|
||||
defer cleanupTestNode(registry, node3)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
|
||||
assert.NoError(t, err)
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
|
||||
assert.NoError(t, err)
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
nodes, err := registry.GetAvailableNodes()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, nodes, 3)
|
||||
|
||||
nodeIDs := make(map[uuid.UUID]bool)
|
||||
for _, node := range nodes {
|
||||
nodeIDs[node.ID] = true
|
||||
}
|
||||
assert.True(t, nodeIDs[node1.ID])
|
||||
assert.True(t, nodeIDs[node2.ID])
|
||||
assert.True(t, nodeIDs[node3.ID])
|
||||
}
|
||||
|
||||
func Test_GetAvailableNodes_WhenNoNodesExist_ReturnsEmptySlice(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
|
||||
nodes, err := registry.GetAvailableNodes()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, nodes)
|
||||
assert.Empty(t, nodes)
|
||||
}
|
||||
|
||||
func Test_IncrementBackupsInProgress_IncrementsCounter(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
defer cleanupTestNode(registry, node)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, stats, 1)
|
||||
assert.Equal(t, node.ID, stats[0].ID)
|
||||
assert.Equal(t, 1, stats[0].ActiveBackups)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err = registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, stats, 1)
|
||||
assert.Equal(t, 2, stats[0].ActiveBackups)
|
||||
}
|
||||
|
||||
func Test_DecrementBackupsInProgress_DecrementsCounter(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
defer cleanupTestNode(registry, node)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node.ID.String())
|
||||
assert.NoError(t, err)
|
||||
err = registry.IncrementBackupsInProgress(node.ID.String())
|
||||
assert.NoError(t, err)
|
||||
err = registry.IncrementBackupsInProgress(node.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, stats[0].ActiveBackups)
|
||||
|
||||
err = registry.DecrementBackupsInProgress(node.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err = registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, stats[0].ActiveBackups)
|
||||
|
||||
err = registry.DecrementBackupsInProgress(node.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err = registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, stats[0].ActiveBackups)
|
||||
}
|
||||
|
||||
func Test_DecrementBackupsInProgress_WhenNegative_ResetsToZero(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
defer cleanupTestNode(registry, node)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.DecrementBackupsInProgress(node.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, stats, 1)
|
||||
assert.Equal(t, 0, stats[0].ActiveBackups)
|
||||
}
|
||||
|
||||
func Test_GetBackupNodesStats_ReturnsStatsForAllNodes(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node1 := createTestBackupNode()
|
||||
node2 := createTestBackupNode()
|
||||
node3 := createTestBackupNode()
|
||||
defer cleanupTestNode(registry, node1)
|
||||
defer cleanupTestNode(registry, node2)
|
||||
defer cleanupTestNode(registry, node3)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
|
||||
assert.NoError(t, err)
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
|
||||
assert.NoError(t, err)
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node1.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node2.ID.String())
|
||||
assert.NoError(t, err)
|
||||
err = registry.IncrementBackupsInProgress(node2.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node3.ID.String())
|
||||
assert.NoError(t, err)
|
||||
err = registry.IncrementBackupsInProgress(node3.ID.String())
|
||||
assert.NoError(t, err)
|
||||
err = registry.IncrementBackupsInProgress(node3.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, stats, 3)
|
||||
|
||||
statsMap := make(map[uuid.UUID]int)
|
||||
for _, stat := range stats {
|
||||
statsMap[stat.ID] = stat.ActiveBackups
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, statsMap[node1.ID])
|
||||
assert.Equal(t, 2, statsMap[node2.ID])
|
||||
assert.Equal(t, 3, statsMap[node3.ID])
|
||||
}
|
||||
|
||||
func Test_GetBackupNodesStats_WhenNoStats_ReturnsEmptySlice(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, stats)
|
||||
assert.Empty(t, stats)
|
||||
}
|
||||
|
||||
func Test_MultipleNodes_RegisteredAndQueriedCorrectly(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node1 := createTestBackupNode()
|
||||
node1.ThroughputMBs = 50
|
||||
node2 := createTestBackupNode()
|
||||
node2.ThroughputMBs = 100
|
||||
node3 := createTestBackupNode()
|
||||
node3.ThroughputMBs = 150
|
||||
defer cleanupTestNode(registry, node1)
|
||||
defer cleanupTestNode(registry, node2)
|
||||
defer cleanupTestNode(registry, node3)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
|
||||
assert.NoError(t, err)
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
|
||||
assert.NoError(t, err)
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
nodes, err := registry.GetAvailableNodes()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, nodes, 3)
|
||||
|
||||
nodeMap := make(map[uuid.UUID]BackupNode)
|
||||
for _, node := range nodes {
|
||||
nodeMap[node.ID] = node
|
||||
}
|
||||
|
||||
assert.Equal(t, 50, nodeMap[node1.ID].ThroughputMBs)
|
||||
assert.Equal(t, 100, nodeMap[node2.ID].ThroughputMBs)
|
||||
assert.Equal(t, 150, nodeMap[node3.ID].ThroughputMBs)
|
||||
}
|
||||
|
||||
func Test_BackupCounters_TrackedSeparatelyPerNode(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node1 := createTestBackupNode()
|
||||
node2 := createTestBackupNode()
|
||||
defer cleanupTestNode(registry, node1)
|
||||
defer cleanupTestNode(registry, node2)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
|
||||
assert.NoError(t, err)
|
||||
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node1.ID.String())
|
||||
assert.NoError(t, err)
|
||||
err = registry.IncrementBackupsInProgress(node1.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.IncrementBackupsInProgress(node2.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err := registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, stats, 2)
|
||||
|
||||
statsMap := make(map[uuid.UUID]int)
|
||||
for _, stat := range stats {
|
||||
statsMap[stat.ID] = stat.ActiveBackups
|
||||
}
|
||||
|
||||
assert.Equal(t, 2, statsMap[node1.ID])
|
||||
assert.Equal(t, 1, statsMap[node2.ID])
|
||||
|
||||
err = registry.DecrementBackupsInProgress(node1.ID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
stats, err = registry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
|
||||
statsMap = make(map[uuid.UUID]int)
|
||||
for _, stat := range stats {
|
||||
statsMap[stat.ID] = stat.ActiveBackups
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, statsMap[node1.ID])
|
||||
assert.Equal(t, 1, statsMap[node2.ID])
|
||||
}
|
||||
|
||||
func Test_GetAvailableNodes_SkipsInvalidJsonData(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
defer cleanupTestNode(registry, node)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
defer cancel()
|
||||
|
||||
invalidKey := nodeInfoKeyPrefix + uuid.New().String() + nodeInfoKeySuffix
|
||||
registry.client.Do(
|
||||
ctx,
|
||||
registry.client.B().Set().Key(invalidKey).Value("invalid json data").Build(),
|
||||
)
|
||||
defer func() {
|
||||
cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), registry.timeout)
|
||||
defer cleanupCancel()
|
||||
registry.client.Do(cleanupCtx, registry.client.B().Del().Key(invalidKey).Build())
|
||||
}()
|
||||
|
||||
nodes, err := registry.GetAvailableNodes()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, node.ID, nodes[0].ID)
|
||||
}
|
||||
|
||||
func Test_PipelineGetKeys_HandlesEmptyKeysList(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
|
||||
keyDataMap, err := registry.pipelineGetKeys([]string{})
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, keyDataMap)
|
||||
assert.Empty(t, keyDataMap)
|
||||
}
|
||||
|
||||
func Test_HearthbeatNodeInRegistry_UpdatesLastHeartbeat(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
originalHeartbeat := node.LastHeartbeat
|
||||
defer cleanupTestNode(registry, node)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
|
||||
assert.NoError(t, err)
|
||||
|
||||
nodes, err := registry.GetAvailableNodes()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.True(t, nodes[0].LastHeartbeat.After(originalHeartbeat))
|
||||
}
|
||||
|
||||
func createTestRegistry() *BackupNodesRegistry {
|
||||
return &BackupNodesRegistry{
|
||||
cache_utils.GetValkeyClient(),
|
||||
logger.GetLogger(),
|
||||
cache_utils.DefaultCacheTimeout,
|
||||
cache_utils.NewPubSubManager(),
|
||||
cache_utils.NewPubSubManager(),
|
||||
}
|
||||
}
|
||||
|
||||
func createTestBackupNode() BackupNode {
|
||||
return BackupNode{
|
||||
ID: uuid.New(),
|
||||
ThroughputMBs: 100,
|
||||
LastHeartbeat: time.Now().UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
func cleanupTestNode(registry *BackupNodesRegistry, node BackupNode) {
|
||||
registry.UnregisterNodeFromRegistry(node)
|
||||
}
|
||||
|
||||
func Test_AssignBackupTonode_PublishesJsonMessageToChannel(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
backupID := uuid.New()
|
||||
|
||||
err := registry.AssignBackupToNode(node.ID.String(), backupID, true)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_SubscribeNodeForBackupsAssignment_ReceivesSubmittedBackupsForMatchingNode(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
backupID := uuid.New()
|
||||
defer registry.UnsubscribeNodeForBackupsAssignments()
|
||||
|
||||
receivedBackupID := make(chan uuid.UUID, 1)
|
||||
handler := func(id uuid.UUID, isCallNotifier bool) {
|
||||
receivedBackupID <- id
|
||||
}
|
||||
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.AssignBackupToNode(node.ID.String(), backupID, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case received := <-receivedBackupID:
|
||||
assert.Equal(t, backupID, received)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Timeout waiting for backup message")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SubscribeNodeForBackupsAssignment_FiltersOutBackupsForDifferentNode(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node1 := createTestBackupNode()
|
||||
node2 := createTestBackupNode()
|
||||
backupID := uuid.New()
|
||||
defer registry.UnsubscribeNodeForBackupsAssignments()
|
||||
|
||||
receivedBackupID := make(chan uuid.UUID, 1)
|
||||
handler := func(id uuid.UUID, isCallNotifier bool) {
|
||||
receivedBackupID <- id
|
||||
}
|
||||
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node1.ID.String(), handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.AssignBackupToNode(node2.ID.String(), backupID, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-receivedBackupID:
|
||||
t.Fatal("Should not receive backup for different node")
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SubscribeNodeForBackupsAssignment_ParsesJsonAndBackupIdCorrectly(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
backupID1 := uuid.New()
|
||||
backupID2 := uuid.New()
|
||||
defer registry.UnsubscribeNodeForBackupsAssignments()
|
||||
|
||||
receivedBackups := make(chan uuid.UUID, 2)
|
||||
handler := func(id uuid.UUID, isCallNotifier bool) {
|
||||
receivedBackups <- id
|
||||
}
|
||||
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.AssignBackupToNode(node.ID.String(), backupID1, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.AssignBackupToNode(node.ID.String(), backupID2, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
received1 := <-receivedBackups
|
||||
received2 := <-receivedBackups
|
||||
|
||||
receivedIDs := []uuid.UUID{received1, received2}
|
||||
assert.Contains(t, receivedIDs, backupID1)
|
||||
assert.Contains(t, receivedIDs, backupID2)
|
||||
}
|
||||
|
||||
func Test_SubscribeNodeForBackupsAssignment_HandlesInvalidJson(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
defer registry.UnsubscribeNodeForBackupsAssignments()
|
||||
|
||||
receivedBackupID := make(chan uuid.UUID, 1)
|
||||
handler := func(id uuid.UUID, isCallNotifier bool) {
|
||||
receivedBackupID <- id
|
||||
}
|
||||
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
ctx := context.Background()
|
||||
err = registry.pubsubBackups.Publish(ctx, "backup:submit", "invalid json")
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-receivedBackupID:
|
||||
t.Fatal("Should not receive backup for invalid JSON")
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func Test_UnsubscribeNodeForBackupsAssignments_StopsReceivingMessages(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
backupID1 := uuid.New()
|
||||
backupID2 := uuid.New()
|
||||
|
||||
receivedBackupID := make(chan uuid.UUID, 2)
|
||||
handler := func(id uuid.UUID, isCallNotifier bool) {
|
||||
receivedBackupID <- id
|
||||
}
|
||||
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.AssignBackupToNode(node.ID.String(), backupID1, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
received := <-receivedBackupID
|
||||
assert.Equal(t, backupID1, received)
|
||||
|
||||
err = registry.UnsubscribeNodeForBackupsAssignments()
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.AssignBackupToNode(node.ID.String(), backupID2, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-receivedBackupID:
|
||||
t.Fatal("Should not receive backup after unsubscribe")
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SubscribeNodeForBackupsAssignment_WhenAlreadySubscribed_ReturnsError(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
defer registry.UnsubscribeNodeForBackupsAssignments()
|
||||
|
||||
handler := func(id uuid.UUID, isCallNotifier bool) {}
|
||||
|
||||
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "already subscribed")
|
||||
}
|
||||
|
||||
func Test_MultipleNodes_EachReceivesOnlyTheirBackups(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry1 := createTestRegistry()
|
||||
registry2 := createTestRegistry()
|
||||
registry3 := createTestRegistry()
|
||||
|
||||
node1 := createTestBackupNode()
|
||||
node2 := createTestBackupNode()
|
||||
node3 := createTestBackupNode()
|
||||
|
||||
backupID1 := uuid.New()
|
||||
backupID2 := uuid.New()
|
||||
backupID3 := uuid.New()
|
||||
|
||||
defer registry1.UnsubscribeNodeForBackupsAssignments()
|
||||
defer registry2.UnsubscribeNodeForBackupsAssignments()
|
||||
defer registry3.UnsubscribeNodeForBackupsAssignments()
|
||||
|
||||
receivedBackups1 := make(chan uuid.UUID, 3)
|
||||
receivedBackups2 := make(chan uuid.UUID, 3)
|
||||
receivedBackups3 := make(chan uuid.UUID, 3)
|
||||
|
||||
handler1 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups1 <- id }
|
||||
handler2 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups2 <- id }
|
||||
handler3 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups3 <- id }
|
||||
|
||||
err := registry1.SubscribeNodeForBackupsAssignment(node1.ID.String(), handler1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry2.SubscribeNodeForBackupsAssignment(node2.ID.String(), handler2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry3.SubscribeNodeForBackupsAssignment(node3.ID.String(), handler3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
submitRegistry := createTestRegistry()
|
||||
err = submitRegistry.AssignBackupToNode(node1.ID.String(), backupID1, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = submitRegistry.AssignBackupToNode(node2.ID.String(), backupID2, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = submitRegistry.AssignBackupToNode(node3.ID.String(), backupID3, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case received := <-receivedBackups1:
|
||||
assert.Equal(t, backupID1, received)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Node 1 timeout waiting for backup message")
|
||||
}
|
||||
|
||||
select {
|
||||
case received := <-receivedBackups2:
|
||||
assert.Equal(t, backupID2, received)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Node 2 timeout waiting for backup message")
|
||||
}
|
||||
|
||||
select {
|
||||
case received := <-receivedBackups3:
|
||||
assert.Equal(t, backupID3, received)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Node 3 timeout waiting for backup message")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-receivedBackups1:
|
||||
t.Fatal("Node 1 should not receive additional backups")
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
}
|
||||
|
||||
select {
|
||||
case <-receivedBackups2:
|
||||
t.Fatal("Node 2 should not receive additional backups")
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
}
|
||||
|
||||
select {
|
||||
case <-receivedBackups3:
|
||||
t.Fatal("Node 3 should not receive additional backups")
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func Test_PublishBackupCompletion_PublishesMessageToChannel(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
backupID := uuid.New()
|
||||
|
||||
err := registry.PublishBackupCompletion(node.ID.String(), backupID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_SubscribeForBackupsCompletions_ReceivesCompletedBackups(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
backupID := uuid.New()
|
||||
defer registry.UnsubscribeForBackupsCompletions()
|
||||
|
||||
receivedBackupID := make(chan uuid.UUID, 1)
|
||||
receivedNodeID := make(chan string, 1)
|
||||
handler := func(nodeID string, backupID uuid.UUID) {
|
||||
receivedNodeID <- nodeID
|
||||
receivedBackupID <- backupID
|
||||
}
|
||||
|
||||
err := registry.SubscribeForBackupsCompletions(handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.PublishBackupCompletion(node.ID.String(), backupID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case receivedNode := <-receivedNodeID:
|
||||
assert.Equal(t, node.ID.String(), receivedNode)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Timeout waiting for node ID")
|
||||
}
|
||||
|
||||
select {
|
||||
case received := <-receivedBackupID:
|
||||
assert.Equal(t, backupID, received)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Timeout waiting for backup completion message")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SubscribeForBackupsCompletions_ParsesJsonCorrectly(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
backupID1 := uuid.New()
|
||||
backupID2 := uuid.New()
|
||||
defer registry.UnsubscribeForBackupsCompletions()
|
||||
|
||||
receivedBackups := make(chan uuid.UUID, 2)
|
||||
handler := func(nodeID string, backupID uuid.UUID) {
|
||||
receivedBackups <- backupID
|
||||
}
|
||||
|
||||
err := registry.SubscribeForBackupsCompletions(handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.PublishBackupCompletion(node.ID.String(), backupID1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.PublishBackupCompletion(node.ID.String(), backupID2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
received1 := <-receivedBackups
|
||||
received2 := <-receivedBackups
|
||||
|
||||
receivedIDs := []uuid.UUID{received1, received2}
|
||||
assert.Contains(t, receivedIDs, backupID1)
|
||||
assert.Contains(t, receivedIDs, backupID2)
|
||||
}
|
||||
|
||||
func Test_SubscribeForBackupsCompletions_HandlesInvalidJson(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
defer registry.UnsubscribeForBackupsCompletions()
|
||||
|
||||
receivedBackupID := make(chan uuid.UUID, 1)
|
||||
handler := func(nodeID string, backupID uuid.UUID) {
|
||||
receivedBackupID <- backupID
|
||||
}
|
||||
|
||||
err := registry.SubscribeForBackupsCompletions(handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
ctx := context.Background()
|
||||
err = registry.pubsubCompletions.Publish(ctx, "backup:completion", "invalid json")
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-receivedBackupID:
|
||||
t.Fatal("Should not receive backup for invalid JSON")
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func Test_UnsubscribeForBackupsCompletions_StopsReceivingMessages(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
node := createTestBackupNode()
|
||||
backupID1 := uuid.New()
|
||||
backupID2 := uuid.New()
|
||||
|
||||
receivedBackupID := make(chan uuid.UUID, 2)
|
||||
handler := func(nodeID string, backupID uuid.UUID) {
|
||||
receivedBackupID <- backupID
|
||||
}
|
||||
|
||||
err := registry.SubscribeForBackupsCompletions(handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.PublishBackupCompletion(node.ID.String(), backupID1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
received := <-receivedBackupID
|
||||
assert.Equal(t, backupID1, received)
|
||||
|
||||
err = registry.UnsubscribeForBackupsCompletions()
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = registry.PublishBackupCompletion(node.ID.String(), backupID2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-receivedBackupID:
|
||||
t.Fatal("Should not receive backup after unsubscribe")
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SubscribeForBackupsCompletions_WhenAlreadySubscribed_ReturnsError(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry := createTestRegistry()
|
||||
defer registry.UnsubscribeForBackupsCompletions()
|
||||
|
||||
handler := func(nodeID string, backupID uuid.UUID) {}
|
||||
|
||||
err := registry.SubscribeForBackupsCompletions(handler)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry.SubscribeForBackupsCompletions(handler)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "already subscribed")
|
||||
}
|
||||
|
||||
func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
registry1 := createTestRegistry()
|
||||
registry2 := createTestRegistry()
|
||||
registry3 := createTestRegistry()
|
||||
|
||||
node1 := createTestBackupNode()
|
||||
node2 := createTestBackupNode()
|
||||
node3 := createTestBackupNode()
|
||||
|
||||
backupID1 := uuid.New()
|
||||
backupID2 := uuid.New()
|
||||
backupID3 := uuid.New()
|
||||
|
||||
defer registry1.UnsubscribeForBackupsCompletions()
|
||||
defer registry2.UnsubscribeForBackupsCompletions()
|
||||
defer registry3.UnsubscribeForBackupsCompletions()
|
||||
|
||||
receivedBackups1 := make(chan uuid.UUID, 3)
|
||||
receivedBackups2 := make(chan uuid.UUID, 3)
|
||||
receivedBackups3 := make(chan uuid.UUID, 3)
|
||||
|
||||
handler1 := func(nodeID string, backupID uuid.UUID) { receivedBackups1 <- backupID }
|
||||
handler2 := func(nodeID string, backupID uuid.UUID) { receivedBackups2 <- backupID }
|
||||
handler3 := func(nodeID string, backupID uuid.UUID) { receivedBackups3 <- backupID }
|
||||
|
||||
err := registry1.SubscribeForBackupsCompletions(handler1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry2.SubscribeForBackupsCompletions(handler2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = registry3.SubscribeForBackupsCompletions(handler3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
publishRegistry := createTestRegistry()
|
||||
err = publishRegistry.PublishBackupCompletion(node1.ID.String(), backupID1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = publishRegistry.PublishBackupCompletion(node2.ID.String(), backupID2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = publishRegistry.PublishBackupCompletion(node3.ID.String(), backupID3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
receivedAll1 := []uuid.UUID{}
|
||||
receivedAll2 := []uuid.UUID{}
|
||||
receivedAll3 := []uuid.UUID{}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
select {
|
||||
case received := <-receivedBackups1:
|
||||
receivedAll1 = append(receivedAll1, received)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Subscriber 1 timeout waiting for completion message")
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
select {
|
||||
case received := <-receivedBackups2:
|
||||
receivedAll2 = append(receivedAll2, received)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Subscriber 2 timeout waiting for completion message")
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
select {
|
||||
case received := <-receivedBackups3:
|
||||
receivedAll3 = append(receivedAll3, received)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Subscriber 3 timeout waiting for completion message")
|
||||
}
|
||||
}
|
||||
|
||||
assert.Contains(t, receivedAll1, backupID1)
|
||||
assert.Contains(t, receivedAll1, backupID2)
|
||||
assert.Contains(t, receivedAll1, backupID3)
|
||||
|
||||
assert.Contains(t, receivedAll2, backupID1)
|
||||
assert.Contains(t, receivedAll2, backupID2)
|
||||
assert.Contains(t, receivedAll2, backupID3)
|
||||
|
||||
assert.Contains(t, receivedAll3, backupID1)
|
||||
assert.Contains(t, receivedAll3, backupID2)
|
||||
assert.Contains(t, receivedAll3, backupID3)
|
||||
}
|
||||
600
backend/internal/features/backups/backups/backuping/scheduler.go
Normal file
600
backend/internal/features/backups/backups/backuping/scheduler.go
Normal file
@@ -0,0 +1,600 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"databasus-backend/internal/config"
|
||||
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/storages"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/period"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type BackupsScheduler struct {
|
||||
backupRepository *backups_core.BackupRepository
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
storageService *storages.StorageService
|
||||
backupCancelManager *backups_cancellation.BackupCancelManager
|
||||
nodesRegistry *BackupNodesRegistry
|
||||
|
||||
lastBackupTime time.Time
|
||||
logger *slog.Logger
|
||||
|
||||
backupToNodeRelations map[uuid.UUID]BackupToNodeRelation
|
||||
backuperNode *BackuperNode
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) Run(ctx context.Context) {
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
|
||||
if config.GetEnv().IsManyNodesMode {
|
||||
// wait other nodes to start
|
||||
time.Sleep(1 * time.Minute)
|
||||
}
|
||||
|
||||
if err := s.failBackupsInProgress(); err != nil {
|
||||
s.logger.Error("Failed to fail backups in progress", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if err := s.nodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted); err != nil {
|
||||
s.logger.Error("Failed to subscribe to backup completions", "error", err)
|
||||
panic(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := s.nodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
|
||||
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.cleanOldBackups(); err != nil {
|
||||
s.logger.Error("Failed to clean old backups", "error", err)
|
||||
}
|
||||
|
||||
if err := s.checkDeadNodesAndFailBackups(); err != nil {
|
||||
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
|
||||
}
|
||||
|
||||
if err := s.runPendingBackups(); err != nil {
|
||||
s.logger.Error("Failed to run pending backups", "error", err)
|
||||
}
|
||||
|
||||
s.lastBackupTime = time.Now().UTC()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) IsSchedulerRunning() bool {
|
||||
// if last backup time is more than 5 minutes ago, return false
|
||||
return s.lastBackupTime.After(time.Now().UTC().Add(-5 * time.Minute))
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) failBackupsInProgress() error {
|
||||
backupsInProgress, err := s.backupRepository.FindByStatus(backups_core.BackupStatusInProgress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Println("Backups in progress", len(backupsInProgress))
|
||||
|
||||
for _, backup := range backupsInProgress {
|
||||
if err := s.backupCancelManager.CancelBackup(backup.ID); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to cancel backup via context manager",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(backup.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
failMessage := "Backup failed due to application restart"
|
||||
backup.FailMessage = &failMessage
|
||||
backup.Status = backups_core.BackupStatusFailed
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
s.backuperNode.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupFailed,
|
||||
&failMessage,
|
||||
)
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool) {
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(databaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if backupConfig.StorageID == nil {
|
||||
s.logger.Error("Backup config storage ID is nil", "databaseId", databaseID)
|
||||
return
|
||||
}
|
||||
|
||||
leastBusyNodeID, err := s.calculateLeastBusyNode()
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to calculate least busy node",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
backup := &backups_core.Backup{
|
||||
DatabaseID: backupConfig.DatabaseID,
|
||||
StorageID: *backupConfig.StorageID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
BackupSizeMb: 0,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to save backup",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.nodesRegistry.IncrementBackupsInProgress(leastBusyNodeID.String()); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to increment backups in progress",
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.nodesRegistry.AssignBackupToNode(leastBusyNodeID.String(), backup.ID, isCallNotifier); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to submit backup",
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
if decrementErr := s.nodesRegistry.DecrementBackupsInProgress(leastBusyNodeID.String()); decrementErr != nil {
|
||||
s.logger.Error(
|
||||
"Failed to decrement backups in progress after submit failure",
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
"error",
|
||||
decrementErr,
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if relation, exists := s.backupToNodeRelations[*leastBusyNodeID]; exists {
|
||||
relation.BackupsIDs = append(relation.BackupsIDs, backup.ID)
|
||||
s.backupToNodeRelations[*leastBusyNodeID] = relation
|
||||
} else {
|
||||
s.backupToNodeRelations[*leastBusyNodeID] = BackupToNodeRelation{
|
||||
NodeID: *leastBusyNodeID,
|
||||
BackupsIDs: []uuid.UUID{backup.ID},
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Info(
|
||||
"Successfully triggered scheduled backup",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"nodeId",
|
||||
leastBusyNodeID,
|
||||
)
|
||||
}
|
||||
|
||||
// GetRemainedBackupTryCount returns the number of remaining backup tries for a given backup.
|
||||
// If the backup is not failed or the backup config does not allow retries, it returns 0.
|
||||
// If the backup is failed and the backup config allows retries, it returns the number of remaining tries.
|
||||
// If the backup is failed and the backup config does not allow retries, it returns 0.
|
||||
func (s *BackupsScheduler) GetRemainedBackupTryCount(lastBackup *backups_core.Backup) int {
|
||||
if lastBackup == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
if lastBackup.Status != backups_core.BackupStatusFailed {
|
||||
return 0
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(lastBackup.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
if !backupConfig.IsRetryIfFailed {
|
||||
return 0
|
||||
}
|
||||
|
||||
maxFailedTriesCount := backupConfig.MaxFailedTriesCount
|
||||
|
||||
lastBackups, err := s.backupRepository.FindByDatabaseIDWithLimit(
|
||||
lastBackup.DatabaseID,
|
||||
maxFailedTriesCount,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to find last backups by database ID", "error", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
lastFailedBackups := make([]*backups_core.Backup, 0)
|
||||
|
||||
for _, backup := range lastBackups {
|
||||
if backup.Status == backups_core.BackupStatusFailed {
|
||||
lastFailedBackups = append(lastFailedBackups, backup)
|
||||
}
|
||||
}
|
||||
|
||||
return maxFailedTriesCount - len(lastFailedBackups)
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) cleanOldBackups() error {
|
||||
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
backupStorePeriod := backupConfig.StorePeriod
|
||||
|
||||
if backupStorePeriod == period.PeriodForever {
|
||||
continue
|
||||
}
|
||||
|
||||
storeDuration := backupStorePeriod.ToDuration()
|
||||
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
|
||||
|
||||
oldBackups, err := s.backupRepository.FindBackupsBeforeDate(
|
||||
backupConfig.DatabaseID,
|
||||
dateBeforeBackupsShouldBeDeleted,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to find old backups for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, backup := range oldBackups {
|
||||
storage, err := s.storageService.GetStorageByID(backup.StorageID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to get storage by ID",
|
||||
"storageId",
|
||||
backup.StorageID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
err = storage.DeleteFile(encryptor, backup.ID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
|
||||
}
|
||||
|
||||
if err := s.backupRepository.DeleteByID(backup.ID); err != nil {
|
||||
s.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
s.logger.Info(
|
||||
"Deleted old backup",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) runPendingBackups() error {
|
||||
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, backupConfig := range enabledBackupConfigs {
|
||||
if backupConfig.BackupInterval == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
lastBackup, err := s.backupRepository.FindLastByDatabaseID(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to get last backup for database",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
var lastBackupTime *time.Time
|
||||
if lastBackup != nil {
|
||||
lastBackupTime = &lastBackup.CreatedAt
|
||||
}
|
||||
|
||||
remainedBackupTryCount := s.GetRemainedBackupTryCount(lastBackup)
|
||||
|
||||
if backupConfig.BackupInterval.ShouldTriggerBackup(time.Now().UTC(), lastBackupTime) ||
|
||||
remainedBackupTryCount > 0 {
|
||||
s.logger.Info(
|
||||
"Triggering scheduled backup",
|
||||
"databaseId",
|
||||
backupConfig.DatabaseID,
|
||||
"intervalType",
|
||||
backupConfig.BackupInterval.Interval,
|
||||
)
|
||||
|
||||
s.StartBackup(backupConfig.DatabaseID, remainedBackupTryCount == 1)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
|
||||
nodes, err := s.nodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get available nodes: %w", err)
|
||||
}
|
||||
|
||||
if len(nodes) == 0 {
|
||||
return nil, fmt.Errorf("no nodes available")
|
||||
}
|
||||
|
||||
stats, err := s.nodesRegistry.GetBackupNodesStats()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get backup nodes stats: %w", err)
|
||||
}
|
||||
|
||||
statsMap := make(map[uuid.UUID]int)
|
||||
for _, stat := range stats {
|
||||
statsMap[stat.ID] = stat.ActiveBackups
|
||||
}
|
||||
|
||||
var bestNode *BackupNode
|
||||
var bestScore float64 = -1
|
||||
|
||||
now := time.Now().UTC()
|
||||
for i := range nodes {
|
||||
node := &nodes[i]
|
||||
|
||||
if now.Sub(node.LastHeartbeat) > 2*time.Minute {
|
||||
continue
|
||||
}
|
||||
|
||||
activeBackups := statsMap[node.ID]
|
||||
|
||||
var score float64
|
||||
if node.ThroughputMBs > 0 {
|
||||
score = float64(activeBackups) / float64(node.ThroughputMBs)
|
||||
} else {
|
||||
score = float64(activeBackups) * 1000
|
||||
}
|
||||
|
||||
if bestNode == nil || score < bestScore {
|
||||
bestNode = node
|
||||
bestScore = score
|
||||
}
|
||||
}
|
||||
|
||||
if bestNode == nil {
|
||||
return nil, fmt.Errorf("no suitable nodes available")
|
||||
}
|
||||
|
||||
return &bestNode.ID, nil
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUID) {
|
||||
nodeID, err := uuid.Parse(nodeIDStr)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to parse node ID from completion message",
|
||||
"nodeId",
|
||||
nodeIDStr,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
relation, exists := s.backupToNodeRelations[nodeID]
|
||||
if !exists {
|
||||
s.logger.Warn(
|
||||
"Received completion for unknown node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
newBackupIDs := make([]uuid.UUID, 0)
|
||||
found := false
|
||||
for _, id := range relation.BackupsIDs {
|
||||
if id == backupID {
|
||||
found = true
|
||||
continue
|
||||
}
|
||||
newBackupIDs = append(newBackupIDs, id)
|
||||
}
|
||||
|
||||
if !found {
|
||||
s.logger.Warn(
|
||||
"Backup not found in node's backup list",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if len(newBackupIDs) == 0 {
|
||||
delete(s.backupToNodeRelations, nodeID)
|
||||
} else {
|
||||
relation.BackupsIDs = newBackupIDs
|
||||
s.backupToNodeRelations[nodeID] = relation
|
||||
}
|
||||
|
||||
if err := s.nodesRegistry.DecrementBackupsInProgress(nodeIDStr); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to decrement backups in progress",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) checkDeadNodesAndFailBackups() error {
|
||||
nodes, err := s.nodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get available nodes: %w", err)
|
||||
}
|
||||
|
||||
aliveNodeIDs := make(map[uuid.UUID]bool)
|
||||
now := time.Now().UTC()
|
||||
|
||||
for _, node := range nodes {
|
||||
if now.Sub(node.LastHeartbeat) <= 2*time.Minute {
|
||||
aliveNodeIDs[node.ID] = true
|
||||
}
|
||||
}
|
||||
|
||||
for nodeID, relation := range s.backupToNodeRelations {
|
||||
if aliveNodeIDs[nodeID] {
|
||||
continue
|
||||
}
|
||||
|
||||
s.logger.Warn(
|
||||
"Node is dead, failing its backups",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupCount",
|
||||
len(relation.BackupsIDs),
|
||||
)
|
||||
|
||||
for _, backupID := range relation.BackupsIDs {
|
||||
backup, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to find backup for dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
failMessage := "Backup failed due to node unavailability"
|
||||
backup.FailMessage = &failMessage
|
||||
backup.Status = backups_core.BackupStatusFailed
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to save failed backup for dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := s.nodesRegistry.DecrementBackupsInProgress(nodeID.String()); err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to decrement backups in progress for dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
s.logger.Info(
|
||||
"Failed backup due to dead node",
|
||||
"nodeId",
|
||||
nodeID,
|
||||
"backupId",
|
||||
backupID,
|
||||
)
|
||||
}
|
||||
|
||||
delete(s.backupToNodeRelations, nodeID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package backups
|
||||
package backuping
|
||||
|
||||
import (
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
@@ -9,14 +10,21 @@ import (
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/period"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
|
||||
func Test_RunPendingBackups_WhenLastBackupWasYesterday_CreatesNewBackup(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
@@ -57,16 +65,16 @@ func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add old backup
|
||||
backupRepository.Save(&Backup{
|
||||
backupRepository.Save(&backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusCompleted,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
GetBackupsScheduler().runPendingBackups()
|
||||
|
||||
// Wait for backup to complete (runs in goroutine)
|
||||
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
|
||||
@@ -80,7 +88,12 @@ func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
|
||||
func Test_RunPendingBackups_WhenLastBackupWasRecentlyCompleted_SkipsBackup(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
@@ -121,16 +134,16 @@ func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add recent backup (1 hour ago)
|
||||
backupRepository.Save(&Backup{
|
||||
backupRepository.Save(&backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusCompleted,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
GetBackupsScheduler().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
@@ -143,7 +156,12 @@ func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T) {
|
||||
func Test_RunPendingBackups_WhenLastBackupFailedAndRetriesDisabled_SkipsBackup(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
@@ -187,17 +205,17 @@ func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T)
|
||||
|
||||
// add failed backup
|
||||
failMessage := "backup failed"
|
||||
backupRepository.Save(&Backup{
|
||||
backupRepository.Save(&backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusFailed,
|
||||
Status: backups_core.BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
GetBackupsScheduler().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
@@ -210,7 +228,12 @@ func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
|
||||
func Test_RunPendingBackups_WhenLastBackupFailedAndRetriesEnabled_CreatesNewBackup(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
@@ -254,17 +277,17 @@ func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
|
||||
|
||||
// add failed backup
|
||||
failMessage := "backup failed"
|
||||
backupRepository.Save(&Backup{
|
||||
backupRepository.Save(&backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusFailed,
|
||||
Status: backups_core.BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
GetBackupsScheduler().runPendingBackups()
|
||||
|
||||
// Wait for backup to complete (runs in goroutine)
|
||||
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
|
||||
@@ -278,7 +301,12 @@ func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *testing.T) {
|
||||
func Test_RunPendingBackups_WhenFailedBackupsExceedMaxRetries_SkipsBackup(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
// setup data
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
@@ -322,19 +350,19 @@ func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *tes
|
||||
|
||||
failMessage := "backup failed"
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
backupRepository.Save(&Backup{
|
||||
for range 3 {
|
||||
backupRepository.Save(&backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusFailed,
|
||||
Status: backups_core.BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
})
|
||||
}
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
GetBackupsScheduler().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
@@ -347,7 +375,12 @@ func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *tes
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_MakeBackgroundBackupWhenBakupsDisabled_BackupSkipped(t *testing.T) {
|
||||
func Test_RunPendingBackups_WhenBackupsDisabled_SkipsBackup(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
backuperNode := CreateTestBackuperNode()
|
||||
cancel := StartBackuperNodeForTest(t, backuperNode)
|
||||
defer StopBackuperNodeForTest(t, cancel, backuperNode)
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
@@ -385,16 +418,16 @@ func Test_MakeBackgroundBackupWhenBakupsDisabled_BackupSkipped(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add old backup that would trigger new backup if enabled
|
||||
backupRepository.Save(&Backup{
|
||||
backupRepository.Save(&backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusCompleted,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
|
||||
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
|
||||
})
|
||||
|
||||
GetBackupBackgroundService().runPendingBackups()
|
||||
GetBackupsScheduler().runPendingBackups()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
@@ -405,3 +438,272 @@ func Test_MakeBackgroundBackupWhenBakupsDisabled_BackupSkipped(t *testing.T) {
|
||||
// Wait for any cleanup operations to complete before defer cleanup runs
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegistry(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Register mock node without subscribing to backups (simulates node crash after registration)
|
||||
mockNodeID := uuid.New()
|
||||
err = CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Scheduler assigns backup to mock node
|
||||
GetBackupsScheduler().StartBackup(database.ID, false)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1)
|
||||
assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status)
|
||||
|
||||
// Verify Valkey counter was incremented when backup was assigned
|
||||
stats, err := nodesRegistry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
foundStat := false
|
||||
for _, stat := range stats {
|
||||
if stat.ID == mockNodeID {
|
||||
assert.Equal(t, 1, stat.ActiveBackups)
|
||||
foundStat = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, foundStat, "Node stats should be present")
|
||||
|
||||
// Simulate node death by setting heartbeat older than 2-minute threshold
|
||||
oldHeartbeat := time.Now().UTC().Add(-3 * time.Minute)
|
||||
err = UpdateNodeHeartbeatDirectly(mockNodeID, 100, oldHeartbeat)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Trigger dead node detection
|
||||
err = GetBackupsScheduler().checkDeadNodesAndFailBackups()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify backup was failed with appropriate error message
|
||||
backups, err = backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 1)
|
||||
assert.Equal(t, backups_core.BackupStatusFailed, backups[0].Status)
|
||||
assert.NotNil(t, backups[0].FailMessage)
|
||||
assert.Contains(t, *backups[0].FailMessage, "node unavailability")
|
||||
|
||||
// Verify Valkey counter was decremented after backup failed
|
||||
stats, err = nodesRegistry.GetBackupNodesStats()
|
||||
assert.NoError(t, err)
|
||||
for _, stat := range stats {
|
||||
if stat.ID == mockNodeID {
|
||||
assert.Equal(t, 0, stat.ActiveBackups)
|
||||
}
|
||||
}
|
||||
|
||||
// Node info should still exist in registry (not removed by checkDeadNodesAndFailBackups)
|
||||
node, err := GetNodeFromRegistry(mockNodeID)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, node)
|
||||
assert.Equal(t, mockNodeID, node.ID)
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
|
||||
t.Run("Nodes with same throughput", func(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
node1ID := uuid.New()
|
||||
node2ID := uuid.New()
|
||||
node3ID := uuid.New()
|
||||
now := time.Now().UTC()
|
||||
|
||||
err := CreateMockNodeInRegistry(node1ID, 100, now)
|
||||
assert.NoError(t, err)
|
||||
err = CreateMockNodeInRegistry(node2ID, 100, now)
|
||||
assert.NoError(t, err)
|
||||
err = CreateMockNodeInRegistry(node3ID, 100, now)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for range 5 {
|
||||
err = nodesRegistry.IncrementBackupsInProgress(node1ID.String())
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
for range 2 {
|
||||
err = nodesRegistry.IncrementBackupsInProgress(node2ID.String())
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
for range 8 {
|
||||
err = nodesRegistry.IncrementBackupsInProgress(node3ID.String())
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
leastBusyNodeID, err := GetBackupsScheduler().calculateLeastBusyNode()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, leastBusyNodeID)
|
||||
assert.Equal(t, node2ID, *leastBusyNodeID)
|
||||
})
|
||||
|
||||
t.Run("Nodes with different throughput", func(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
node100MBsID := uuid.New()
|
||||
node50MBsID := uuid.New()
|
||||
now := time.Now().UTC()
|
||||
|
||||
err := CreateMockNodeInRegistry(node100MBsID, 100, now)
|
||||
assert.NoError(t, err)
|
||||
err = CreateMockNodeInRegistry(node50MBsID, 50, now)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for range 10 {
|
||||
err = nodesRegistry.IncrementBackupsInProgress(node100MBsID.String())
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
err = nodesRegistry.IncrementBackupsInProgress(node50MBsID.String())
|
||||
assert.NoError(t, err)
|
||||
|
||||
leastBusyNodeID, err := GetBackupsScheduler().calculateLeastBusyNode()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, leastBusyNodeID)
|
||||
assert.Equal(t, node50MBsID, *leastBusyNodeID)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_FailBackupsInProgress_WhenSchedulerStarts_CancelsBackupsAndUpdatesStatus(t *testing.T) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
defer func() {
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
backupConfig.BackupInterval = &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
}
|
||||
backupConfig.IsBackupsEnabled = true
|
||||
backupConfig.StorePeriod = period.PeriodWeek
|
||||
backupConfig.Storage = storage
|
||||
backupConfig.StorageID = &storage.ID
|
||||
|
||||
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create two in-progress backups that should be failed on scheduler restart
|
||||
backup1 := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
BackupSizeMb: 10.5,
|
||||
CreatedAt: time.Now().UTC().Add(-30 * time.Minute),
|
||||
}
|
||||
err = backupRepository.Save(backup1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backup2 := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
BackupSizeMb: 5.2,
|
||||
CreatedAt: time.Now().UTC().Add(-15 * time.Minute),
|
||||
}
|
||||
err = backupRepository.Save(backup2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create a completed backup to verify it's not affected by failBackupsInProgress
|
||||
completedBackup := &backups_core.Backup{
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 20.0,
|
||||
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
|
||||
}
|
||||
err = backupRepository.Save(completedBackup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Trigger the scheduler's failBackupsInProgress logic
|
||||
// This should cancel in-progress backups and mark them as failed
|
||||
err = GetBackupsScheduler().failBackupsInProgress()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify all backups exist and were processed correctly
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, backups, 3)
|
||||
|
||||
var failedCount int
|
||||
var completedCount int
|
||||
for _, backup := range backups {
|
||||
switch backup.Status {
|
||||
case backups_core.BackupStatusFailed:
|
||||
failedCount++
|
||||
// Verify fail message indicates application restart
|
||||
assert.NotNil(t, backup.FailMessage)
|
||||
assert.Equal(t, "Backup failed due to application restart", *backup.FailMessage)
|
||||
// Verify backup size was reset to 0
|
||||
assert.Equal(t, float64(0), backup.BackupSizeMb)
|
||||
case backups_core.BackupStatusCompleted:
|
||||
completedCount++
|
||||
}
|
||||
}
|
||||
|
||||
// Verify correct number of backups in each state
|
||||
assert.Equal(t, 2, failedCount)
|
||||
assert.Equal(t, 1, completedCount)
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
206
backend/internal/features/backups/backups/backuping/testing.go
Normal file
206
backend/internal/features/backups/backups/backuping/testing.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package backuping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/usecases"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
"databasus-backend/internal/util/logger"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func CreateTestRouter() *gin.Engine {
|
||||
router := workspaces_testing.CreateTestRouter(
|
||||
workspaces_controllers.GetWorkspaceController(),
|
||||
workspaces_controllers.GetMembershipController(),
|
||||
databases.GetDatabaseController(),
|
||||
backups_config.GetBackupConfigController(),
|
||||
)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func CreateTestBackuperNode() *BackuperNode {
|
||||
return &BackuperNode{
|
||||
databases.GetDatabaseService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
notifiers.GetNotifierService(),
|
||||
backupCancelManager,
|
||||
nodesRegistry,
|
||||
logger.GetLogger(),
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
uuid.New(),
|
||||
time.Time{},
|
||||
}
|
||||
}
|
||||
|
||||
// WaitForBackupCompletion waits for a new backup to be created and completed (or failed)
|
||||
// for the given database. It checks for backups with count greater than expectedInitialCount.
|
||||
func WaitForBackupCompletion(
|
||||
t *testing.T,
|
||||
databaseID uuid.UUID,
|
||||
expectedInitialCount int,
|
||||
timeout time.Duration,
|
||||
) {
|
||||
deadline := time.Now().UTC().Add(timeout)
|
||||
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
backups, err := backupRepository.FindByDatabaseID(databaseID)
|
||||
if err != nil {
|
||||
t.Logf("WaitForBackupCompletion: error finding backups: %v", err)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
t.Logf(
|
||||
"WaitForBackupCompletion: found %d backups (expected > %d)",
|
||||
len(backups),
|
||||
expectedInitialCount,
|
||||
)
|
||||
|
||||
if len(backups) > expectedInitialCount {
|
||||
// Check if the newest backup has completed or failed
|
||||
newestBackup := backups[0]
|
||||
t.Logf("WaitForBackupCompletion: newest backup status: %s", newestBackup.Status)
|
||||
|
||||
if newestBackup.Status == backups_core.BackupStatusCompleted ||
|
||||
newestBackup.Status == backups_core.BackupStatusFailed ||
|
||||
newestBackup.Status == backups_core.BackupStatusCanceled {
|
||||
t.Logf(
|
||||
"WaitForBackupCompletion: backup finished with status %s",
|
||||
newestBackup.Status,
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Logf("WaitForBackupCompletion: timeout waiting for backup to complete")
|
||||
}
|
||||
|
||||
// StartBackuperNodeForTest starts a BackuperNode in a goroutine for testing.
|
||||
// The node registers itself in the registry and subscribes to backup assignments.
|
||||
// Returns a context cancel function that should be deferred to stop the node.
|
||||
func StartBackuperNodeForTest(t *testing.T, backuperNode *BackuperNode) context.CancelFunc {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
backuperNode.Run(ctx)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Poll registry for node presence instead of fixed sleep
|
||||
deadline := time.Now().UTC().Add(5 * time.Second)
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
nodes, err := nodesRegistry.GetAvailableNodes()
|
||||
if err == nil {
|
||||
for _, node := range nodes {
|
||||
if node.ID == backuperNode.nodeID {
|
||||
t.Logf("BackuperNode registered in registry: %s", backuperNode.nodeID)
|
||||
|
||||
return func() {
|
||||
cancel()
|
||||
select {
|
||||
case <-done:
|
||||
t.Log("BackuperNode stopped gracefully")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Log("BackuperNode stop timeout")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Fatalf("BackuperNode failed to register in registry within timeout")
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopBackuperNodeForTest stops the BackuperNode by canceling its context.
|
||||
// It waits for the node to unregister from the registry.
|
||||
func StopBackuperNodeForTest(t *testing.T, cancel context.CancelFunc, backuperNode *BackuperNode) {
|
||||
cancel()
|
||||
|
||||
// Wait for node to unregister from registry
|
||||
deadline := time.Now().UTC().Add(2 * time.Second)
|
||||
for time.Now().UTC().Before(deadline) {
|
||||
nodes, err := nodesRegistry.GetAvailableNodes()
|
||||
if err == nil {
|
||||
found := false
|
||||
for _, node := range nodes {
|
||||
if node.ID == backuperNode.nodeID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Logf("BackuperNode unregistered from registry: %s", backuperNode.nodeID)
|
||||
return
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
t.Logf("BackuperNode stop completed for %s", backuperNode.nodeID)
|
||||
}
|
||||
|
||||
func CreateMockNodeInRegistry(nodeID uuid.UUID, throughputMBs int, lastHeartbeat time.Time) error {
|
||||
backupNode := BackupNode{
|
||||
ID: nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: lastHeartbeat,
|
||||
}
|
||||
|
||||
return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
|
||||
}
|
||||
|
||||
func UpdateNodeHeartbeatDirectly(
|
||||
nodeID uuid.UUID,
|
||||
throughputMBs int,
|
||||
lastHeartbeat time.Time,
|
||||
) error {
|
||||
backupNode := BackupNode{
|
||||
ID: nodeID,
|
||||
ThroughputMBs: throughputMBs,
|
||||
LastHeartbeat: lastHeartbeat,
|
||||
}
|
||||
|
||||
return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
|
||||
}
|
||||
|
||||
func GetNodeFromRegistry(nodeID uuid.UUID) (*BackupNode, error) {
|
||||
nodes, err := nodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
if node.ID == nodeID {
|
||||
return &node, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("node not found")
|
||||
}
|
||||
@@ -1,9 +1,8 @@
|
||||
package backups
|
||||
package backups_cancellation
|
||||
|
||||
import (
|
||||
"context"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/logger"
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
@@ -12,22 +11,14 @@ import (
|
||||
|
||||
const backupCancelChannel = "backup:cancel"
|
||||
|
||||
type BackupContextManager struct {
|
||||
type BackupCancelManager struct {
|
||||
mu sync.RWMutex
|
||||
cancelFuncs map[uuid.UUID]context.CancelFunc
|
||||
pubsub *cache_utils.PubSubManager
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewBackupContextManager() *BackupContextManager {
|
||||
return &BackupContextManager{
|
||||
cancelFuncs: make(map[uuid.UUID]context.CancelFunc),
|
||||
pubsub: cache_utils.NewPubSubManager(),
|
||||
logger: logger.GetLogger(),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) StartSubscription() {
|
||||
func (m *BackupCancelManager) StartSubscription() {
|
||||
ctx := context.Background()
|
||||
|
||||
handler := func(message string) {
|
||||
@@ -56,14 +47,14 @@ func (m *BackupContextManager) StartSubscription() {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) RegisterBackup(backupID uuid.UUID, cancelFunc context.CancelFunc) {
|
||||
func (m *BackupCancelManager) RegisterBackup(backupID uuid.UUID, cancelFunc context.CancelFunc) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.cancelFuncs[backupID] = cancelFunc
|
||||
m.logger.Debug("Registered backup", "backupID", backupID)
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) CancelBackup(backupID uuid.UUID) error {
|
||||
func (m *BackupCancelManager) CancelBackup(backupID uuid.UUID) error {
|
||||
ctx := context.Background()
|
||||
|
||||
err := m.pubsub.Publish(ctx, backupCancelChannel, backupID.String())
|
||||
@@ -76,7 +67,7 @@ func (m *BackupContextManager) CancelBackup(backupID uuid.UUID) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) UnregisterBackup(backupID uuid.UUID) {
|
||||
func (m *BackupCancelManager) UnregisterBackup(backupID uuid.UUID) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.cancelFuncs, backupID)
|
||||
@@ -1,4 +1,4 @@
|
||||
package backups
|
||||
package backups_cancellation
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
)
|
||||
|
||||
func Test_RegisterBackup_BackupRegisteredSuccessfully(t *testing.T) {
|
||||
manager := NewBackupContextManager()
|
||||
manager := backupCancelManager
|
||||
|
||||
backupID := uuid.New()
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
@@ -26,7 +26,7 @@ func Test_RegisterBackup_BackupRegisteredSuccessfully(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_UnregisterBackup_BackupUnregisteredSuccessfully(t *testing.T) {
|
||||
manager := NewBackupContextManager()
|
||||
manager := backupCancelManager
|
||||
|
||||
backupID := uuid.New()
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
@@ -42,7 +42,7 @@ func Test_UnregisterBackup_BackupUnregisteredSuccessfully(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_CancelBackup_OnSameInstance_BackupCancelledViaPubSub(t *testing.T) {
|
||||
manager := NewBackupContextManager()
|
||||
manager := backupCancelManager
|
||||
|
||||
backupID := uuid.New()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@@ -75,8 +75,8 @@ func Test_CancelBackup_OnSameInstance_BackupCancelledViaPubSub(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_CancelBackup_FromDifferentInstance_BackupCancelledOnRunningInstance(t *testing.T) {
|
||||
manager1 := NewBackupContextManager()
|
||||
manager2 := NewBackupContextManager()
|
||||
manager1 := backupCancelManager
|
||||
manager2 := backupCancelManager
|
||||
|
||||
backupID := uuid.New()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@@ -111,7 +111,7 @@ func Test_CancelBackup_FromDifferentInstance_BackupCancelledOnRunningInstance(t
|
||||
}
|
||||
|
||||
func Test_CancelBackup_WhenBackupDoesNotExist_NoErrorReturned(t *testing.T) {
|
||||
manager := NewBackupContextManager()
|
||||
manager := backupCancelManager
|
||||
|
||||
manager.StartSubscription()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
@@ -122,7 +122,7 @@ func Test_CancelBackup_WhenBackupDoesNotExist_NoErrorReturned(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_CancelBackup_WithMultipleBackups_AllBackupsCancelled(t *testing.T) {
|
||||
manager := NewBackupContextManager()
|
||||
manager := backupCancelManager
|
||||
|
||||
numBackups := 5
|
||||
backupIDs := make([]uuid.UUID, numBackups)
|
||||
@@ -165,7 +165,7 @@ func Test_CancelBackup_WithMultipleBackups_AllBackupsCancelled(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_CancelBackup_AfterUnregister_BackupNotCancelled(t *testing.T) {
|
||||
manager := NewBackupContextManager()
|
||||
manager := backupCancelManager
|
||||
|
||||
backupID := uuid.New()
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
25
backend/internal/features/backups/backups/cancellation/di.go
Normal file
25
backend/internal/features/backups/backups/cancellation/di.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package backups_cancellation
|
||||
|
||||
import (
|
||||
"context"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/logger"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var backupCancelManager = &BackupCancelManager{
|
||||
sync.RWMutex{},
|
||||
make(map[uuid.UUID]context.CancelFunc),
|
||||
cache_utils.NewPubSubManager(),
|
||||
logger.GetLogger(),
|
||||
}
|
||||
|
||||
func GetBackupCancelManager() *BackupCancelManager {
|
||||
return backupCancelManager
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
backupCancelManager.StartSubscription()
|
||||
}
|
||||
@@ -1,11 +1,15 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"context"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
"databasus-backend/internal/features/databases"
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -170,9 +174,10 @@ func (c *BackupController) CancelBackup(ctx *gin.Context) {
|
||||
// @Description Generate a token for downloading a backup file (valid for 5 minutes)
|
||||
// @Tags backups
|
||||
// @Param id path string true "Backup ID"
|
||||
// @Success 200 {object} GenerateDownloadTokenResponse
|
||||
// @Success 200 {object} backups_download.GenerateDownloadTokenResponse
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 409 {object} map[string]string "Download already in progress"
|
||||
// @Router /backups/{id}/download-token [post]
|
||||
func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
@@ -189,6 +194,15 @@ func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
|
||||
|
||||
response, err := c.backupService.GenerateDownloadToken(user, id)
|
||||
if err != nil {
|
||||
if err == backups_download.ErrDownloadAlreadyInProgress {
|
||||
ctx.JSON(
|
||||
http.StatusConflict,
|
||||
gin.H{
|
||||
"error": "Download already in progress for some of backups. Please wait until previous download completed or cancel it",
|
||||
},
|
||||
)
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -198,14 +212,22 @@ func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
|
||||
|
||||
// GetFile
|
||||
// @Summary Download a backup file
|
||||
// @Description Download the backup file for the specified backup using a download token
|
||||
// @Description Download the backup file for the specified backup using a download token.
|
||||
// @Description
|
||||
// @Description **Download Concurrency Control:**
|
||||
// @Description - Only one download per user is allowed at a time
|
||||
// @Description - If a download is already in progress, returns 409 Conflict
|
||||
// @Description - Downloads are tracked using cache with 5-second TTL and 3-second heartbeat
|
||||
// @Description - Browser cancellations automatically release the download lock
|
||||
// @Description - Server crashes are handled via automatic cache expiry (5 seconds)
|
||||
// @Tags backups
|
||||
// @Param id path string true "Backup ID"
|
||||
// @Param token query string true "Download token"
|
||||
// @Success 200 {file} file
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 500
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 409 {object} map[string]string "Download already in progress"
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /backups/{id}/file [get]
|
||||
func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
token := ctx.Query("token")
|
||||
@@ -214,7 +236,6 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Get backup ID from URL
|
||||
backupIDParam := ctx.Param("id")
|
||||
backupID, err := uuid.Parse(backupIDParam)
|
||||
if err != nil {
|
||||
@@ -222,13 +243,22 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
downloadToken, err := c.backupService.ValidateDownloadToken(token)
|
||||
downloadToken, rateLimiter, err := c.backupService.ValidateDownloadToken(token)
|
||||
if err != nil {
|
||||
if err == backups_download.ErrDownloadAlreadyInProgress {
|
||||
ctx.JSON(
|
||||
http.StatusConflict,
|
||||
gin.H{
|
||||
"error": "download already in progress for this user. Please wait until previous download completed or cancel it",
|
||||
},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"})
|
||||
return
|
||||
}
|
||||
|
||||
// Verify token is for the requested backup
|
||||
if downloadToken.BackupID != backupID {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"})
|
||||
return
|
||||
@@ -238,18 +268,28 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
downloadToken.BackupID,
|
||||
)
|
||||
if err != nil {
|
||||
c.backupService.UnregisterDownload(downloadToken.UserID)
|
||||
c.backupService.ReleaseDownloadLock(downloadToken.UserID)
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
rateLimitedReader := backups_download.NewRateLimitedReader(fileReader, rateLimiter)
|
||||
|
||||
heartbeatCtx, cancelHeartbeat := context.WithCancel(context.Background())
|
||||
defer func() {
|
||||
if err := fileReader.Close(); err != nil {
|
||||
cancelHeartbeat()
|
||||
c.backupService.UnregisterDownload(downloadToken.UserID)
|
||||
c.backupService.ReleaseDownloadLock(downloadToken.UserID)
|
||||
if err := rateLimitedReader.Close(); err != nil {
|
||||
fmt.Printf("Error closing file reader: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go c.startDownloadHeartbeat(heartbeatCtx, downloadToken.UserID)
|
||||
|
||||
filename := c.generateBackupFilename(backup, database)
|
||||
|
||||
// Set Content-Length for progress tracking
|
||||
if backup.BackupSizeMb > 0 {
|
||||
sizeBytes := int64(backup.BackupSizeMb * 1024 * 1024)
|
||||
ctx.Header("Content-Length", fmt.Sprintf("%d", sizeBytes))
|
||||
@@ -261,13 +301,12 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
fmt.Sprintf("attachment; filename=\"%s\"", filename),
|
||||
)
|
||||
|
||||
_, err = io.Copy(ctx.Writer, fileReader)
|
||||
_, err = io.Copy(ctx.Writer, rateLimitedReader)
|
||||
if err != nil {
|
||||
fmt.Printf("Error streaming file: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Write audit log after successful download
|
||||
c.backupService.WriteAuditLogForDownload(downloadToken.UserID, backup, database)
|
||||
}
|
||||
|
||||
@@ -276,7 +315,7 @@ type MakeBackupRequest struct {
|
||||
}
|
||||
|
||||
func (c *BackupController) generateBackupFilename(
|
||||
backup *Backup,
|
||||
backup *backups_core.Backup,
|
||||
database *databases.Database,
|
||||
) string {
|
||||
// Format timestamp as YYYY-MM-DD_HH-mm-ss
|
||||
@@ -333,3 +372,17 @@ func sanitizeFilename(name string) string {
|
||||
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func (c *BackupController) startDownloadHeartbeat(ctx context.Context, userID uuid.UUID) {
|
||||
ticker := time.NewTicker(backups_download.GetDownloadHeartbeatInterval())
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
c.backupService.RefreshDownloadLock(userID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,7 +18,8 @@ import (
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups/download_token"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
@@ -478,7 +479,7 @@ func Test_GenerateDownloadToken_PermissionsEnforced(t *testing.T) {
|
||||
)
|
||||
|
||||
if tt.expectSuccess {
|
||||
var response GenerateDownloadTokenResponse
|
||||
var response backups_download.GenerateDownloadTokenResponse
|
||||
err := json.Unmarshal(testResp.Body, &response)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, response.Token)
|
||||
@@ -499,7 +500,7 @@ func Test_DownloadBackup_WithValidToken_Success(t *testing.T) {
|
||||
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
// Generate download token
|
||||
var tokenResponse GenerateDownloadTokenResponse
|
||||
var tokenResponse backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
@@ -620,7 +621,7 @@ func Test_DownloadBackup_TokenUsedOnce_CannotReuseToken(t *testing.T) {
|
||||
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
// Generate download token
|
||||
var tokenResponse GenerateDownloadTokenResponse
|
||||
var tokenResponse backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
@@ -683,7 +684,7 @@ func Test_DownloadBackup_WithDifferentBackupToken_Unauthorized(t *testing.T) {
|
||||
backup2 := createTestBackup(database2, owner)
|
||||
|
||||
// Generate token for backup1
|
||||
var tokenResponse GenerateDownloadTokenResponse
|
||||
var tokenResponse backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
@@ -714,7 +715,7 @@ func Test_DownloadBackup_AuditLogWritten(t *testing.T) {
|
||||
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
// Generate download token
|
||||
var tokenResponse GenerateDownloadTokenResponse
|
||||
var tokenResponse backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
@@ -806,7 +807,7 @@ func Test_DownloadBackup_ProperFilenameForPostgreSQL(t *testing.T) {
|
||||
backup := createTestBackup(database, owner)
|
||||
|
||||
// Generate download token
|
||||
var tokenResponse GenerateDownloadTokenResponse
|
||||
var tokenResponse backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
@@ -897,22 +898,22 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
|
||||
_, err = configService.SaveBackupConfig(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backup := &Backup{
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: BackupStatusInProgress,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
BackupSizeMb: 0,
|
||||
BackupDurationMs: 0,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
repo := &BackupRepository{}
|
||||
repo := &backups_core.BackupRepository{}
|
||||
err = repo.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Register a cancellable context for the backup
|
||||
GetBackupService().backupContextManager.RegisterBackup(backup.ID, func() {})
|
||||
GetBackupService().backupCancelManager.RegisterBackup(backup.ID, func() {})
|
||||
|
||||
resp := test_utils.MakePostRequest(
|
||||
t,
|
||||
@@ -949,6 +950,189 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
|
||||
assert.True(t, foundCancelLog, "Cancel audit log should be created")
|
||||
}
|
||||
|
||||
func Test_ConcurrentDownloadPrevention(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
var token1Response backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token1Response,
|
||||
)
|
||||
|
||||
var token2Response backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token2Response,
|
||||
)
|
||||
|
||||
downloadInProgress := make(chan bool, 1)
|
||||
downloadComplete := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf(
|
||||
"/api/v1/backups/%s/file?token=%s",
|
||||
backup.ID.String(),
|
||||
token1Response.Token,
|
||||
),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
downloadComplete <- true
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
service := GetBackupService()
|
||||
if !service.IsDownloadInProgress(owner.UserID) {
|
||||
t.Log("Warning: First download completed before we could test concurrency")
|
||||
<-downloadComplete
|
||||
return
|
||||
}
|
||||
|
||||
downloadInProgress <- true
|
||||
|
||||
resp := test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), token2Response.Token),
|
||||
"",
|
||||
http.StatusConflict,
|
||||
)
|
||||
|
||||
var errorResponse map[string]string
|
||||
err := json.Unmarshal(resp.Body, &errorResponse)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, errorResponse["error"], "download already in progress")
|
||||
|
||||
<-downloadComplete
|
||||
<-downloadInProgress
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
var token3Response backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token3Response,
|
||||
)
|
||||
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), token3Response.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
|
||||
t.Log("Database:", database.Name)
|
||||
t.Log(
|
||||
"Successfully prevented concurrent downloads and allowed subsequent downloads after completion",
|
||||
)
|
||||
}
|
||||
|
||||
func Test_GenerateDownloadToken_BlockedWhenDownloadInProgress(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
var token1Response backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token1Response,
|
||||
)
|
||||
|
||||
downloadComplete := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf(
|
||||
"/api/v1/backups/%s/file?token=%s",
|
||||
backup.ID.String(),
|
||||
token1Response.Token,
|
||||
),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
downloadComplete <- true
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
service := GetBackupService()
|
||||
if !service.IsDownloadInProgress(owner.UserID) {
|
||||
t.Log("Warning: First download completed before we could test token generation blocking")
|
||||
<-downloadComplete
|
||||
return
|
||||
}
|
||||
|
||||
resp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusConflict,
|
||||
)
|
||||
|
||||
var errorResponse map[string]string
|
||||
err := json.Unmarshal(resp.Body, &errorResponse)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, errorResponse["error"], "download already in progress")
|
||||
|
||||
<-downloadComplete
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
var token2Response backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token2Response,
|
||||
)
|
||||
|
||||
assert.NotEmpty(t, token2Response.Token)
|
||||
assert.NotEqual(t, token1Response.Token, token2Response.Token)
|
||||
|
||||
t.Log("Database:", database.Name)
|
||||
t.Log(
|
||||
"Successfully blocked token generation during download and allowed generation after completion",
|
||||
)
|
||||
}
|
||||
|
||||
func createTestRouter() *gin.Engine {
|
||||
return CreateTestRouter()
|
||||
}
|
||||
@@ -1038,7 +1222,7 @@ func createTestDatabaseWithBackups(
|
||||
workspace *workspaces_models.Workspace,
|
||||
owner *users_dto.SignInResponseDTO,
|
||||
router *gin.Engine,
|
||||
) (*databases.Database, *Backup) {
|
||||
) (*databases.Database, *backups_core.Backup) {
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
@@ -1064,7 +1248,7 @@ func createTestDatabaseWithBackups(
|
||||
func createTestBackup(
|
||||
database *databases.Database,
|
||||
owner *users_dto.SignInResponseDTO,
|
||||
) *Backup {
|
||||
) *backups_core.Backup {
|
||||
userService := users_services.GetUserService()
|
||||
user, err := userService.GetUserFromToken(owner.Token)
|
||||
if err != nil {
|
||||
@@ -1076,17 +1260,17 @@ func createTestBackup(
|
||||
panic("No storage found for workspace")
|
||||
}
|
||||
|
||||
backup := &Backup{
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storages[0].ID,
|
||||
Status: BackupStatusCompleted,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10.5,
|
||||
BackupDurationMs: 1000,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
repo := &BackupRepository{}
|
||||
repo := &backups_core.BackupRepository{}
|
||||
if err := repo.Save(backup); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@@ -1116,7 +1300,7 @@ func createExpiredDownloadToken(backupID, userID uuid.UUID) string {
|
||||
}
|
||||
|
||||
// Manually update the token to be expired
|
||||
repo := &download_token.DownloadTokenRepository{}
|
||||
repo := &backups_download.DownloadTokenRepository{}
|
||||
downloadToken, err := repo.FindByToken(token)
|
||||
if err != nil || downloadToken == nil {
|
||||
panic(fmt.Sprintf("Failed to find generated token: %v", err))
|
||||
@@ -1130,3 +1314,267 @@ func createExpiredDownloadToken(backupID, userID uuid.UUID) string {
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
func Test_BandwidthThrottling_SingleDownload_Uses75Percent(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
bandwidthManager := backups_download.GetBandwidthManager()
|
||||
initialCount := bandwidthManager.GetActiveDownloadCount()
|
||||
|
||||
var tokenResponse backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&tokenResponse,
|
||||
)
|
||||
|
||||
downloadStarted := make(chan bool, 1)
|
||||
downloadComplete := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf(
|
||||
"/api/v1/backups/%s/file?token=%s",
|
||||
backup.ID.String(),
|
||||
tokenResponse.Token,
|
||||
),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
downloadComplete <- true
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
activeCount := bandwidthManager.GetActiveDownloadCount()
|
||||
if activeCount > initialCount {
|
||||
downloadStarted <- true
|
||||
assert.Equal(t, initialCount+1, activeCount, "Should have one active download")
|
||||
}
|
||||
|
||||
<-downloadComplete
|
||||
if len(downloadStarted) > 0 {
|
||||
<-downloadStarted
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
finalCount := bandwidthManager.GetActiveDownloadCount()
|
||||
assert.Equal(t, initialCount, finalCount, "Download should be unregistered after completion")
|
||||
}
|
||||
|
||||
func Test_BandwidthThrottling_MultipleDownloads_ShareBandwidth(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
owner3 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner1, router)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
owner2,
|
||||
users_enums.WorkspaceRoleMember,
|
||||
owner1.Token,
|
||||
router,
|
||||
)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
owner3,
|
||||
users_enums.WorkspaceRoleMember,
|
||||
owner1.Token,
|
||||
router,
|
||||
)
|
||||
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner1.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
configService := backups_config.GetBackupConfigService()
|
||||
config, err := configService.GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
config.IsBackupsEnabled = true
|
||||
config.StorageID = &storage.ID
|
||||
config.Storage = storage
|
||||
_, err = configService.SaveBackupConfig(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backup1 := createTestBackup(database, owner1)
|
||||
backup2 := createTestBackup(database, owner2)
|
||||
backup3 := createTestBackup(database, owner3)
|
||||
|
||||
var token1, token2, token3 backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup1.ID.String()),
|
||||
"Bearer "+owner1.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token1,
|
||||
)
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup2.ID.String()),
|
||||
"Bearer "+owner2.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token2,
|
||||
)
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup3.ID.String()),
|
||||
"Bearer "+owner3.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token3,
|
||||
)
|
||||
|
||||
bandwidthManager := backups_download.GetBandwidthManager()
|
||||
initialCount := bandwidthManager.GetActiveDownloadCount()
|
||||
|
||||
complete1 := make(chan bool, 1)
|
||||
complete2 := make(chan bool, 1)
|
||||
complete3 := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup1.ID.String(), token1.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
complete1 <- true
|
||||
}()
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup2.ID.String(), token2.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
complete2 <- true
|
||||
}()
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup3.ID.String(), token3.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
complete3 <- true
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
<-complete1
|
||||
<-complete2
|
||||
<-complete3
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
finalCount := bandwidthManager.GetActiveDownloadCount()
|
||||
assert.Equal(t, initialCount, finalCount, "All downloads should be unregistered")
|
||||
}
|
||||
|
||||
func Test_BandwidthThrottling_DynamicAdjustment(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner1, router)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
owner2,
|
||||
users_enums.WorkspaceRoleMember,
|
||||
owner1.Token,
|
||||
router,
|
||||
)
|
||||
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner1.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
configService := backups_config.GetBackupConfigService()
|
||||
config, err := configService.GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
config.IsBackupsEnabled = true
|
||||
config.StorageID = &storage.ID
|
||||
config.Storage = storage
|
||||
_, err = configService.SaveBackupConfig(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backup1 := createTestBackup(database, owner1)
|
||||
backup2 := createTestBackup(database, owner2)
|
||||
|
||||
var token1, token2 backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup1.ID.String()),
|
||||
"Bearer "+owner1.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token1,
|
||||
)
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup2.ID.String()),
|
||||
"Bearer "+owner2.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token2,
|
||||
)
|
||||
|
||||
bandwidthManager := backups_download.GetBandwidthManager()
|
||||
initialCount := bandwidthManager.GetActiveDownloadCount()
|
||||
|
||||
complete1 := make(chan bool, 1)
|
||||
complete2 := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup1.ID.String(), token1.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
complete1 <- true
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup2.ID.String(), token2.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
complete2 <- true
|
||||
}()
|
||||
|
||||
<-complete1
|
||||
<-complete2
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
finalCount := bandwidthManager.GetActiveDownloadCount()
|
||||
assert.Equal(t, initialCount, finalCount, "All downloads completed and unregistered")
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package backups
|
||||
package backups_core
|
||||
|
||||
type BackupStatus string
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package backups
|
||||
package backups_core
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -1,4 +1,4 @@
|
||||
package backups
|
||||
package backups_core
|
||||
|
||||
import (
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
@@ -1,4 +1,4 @@
|
||||
package backups
|
||||
package backups_core
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/storage"
|
||||
@@ -1,10 +1,11 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups/download_token"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
"databasus-backend/internal/features/backups/backups/usecases"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -16,50 +17,31 @@ import (
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var backupRepository = &BackupRepository{}
|
||||
var backupRepository = &backups_core.BackupRepository{}
|
||||
|
||||
var backupContextManager = NewBackupContextManager()
|
||||
var backupCancelManager = backups_cancellation.GetBackupCancelManager()
|
||||
|
||||
var backupService = &BackupService{
|
||||
databases.GetDatabaseService(),
|
||||
storages.GetStorageService(),
|
||||
backupRepository,
|
||||
notifiers.GetNotifierService(),
|
||||
notifiers.GetNotifierService(),
|
||||
backups_config.GetBackupConfigService(),
|
||||
encryption_secrets.GetSecretKeyService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
audit_logs.GetAuditLogService(),
|
||||
backupContextManager,
|
||||
download_token.GetDownloadTokenService(),
|
||||
}
|
||||
|
||||
var backupBackgroundService = &BackupBackgroundService{
|
||||
backupService,
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
time.Now().UTC(),
|
||||
logger.GetLogger(),
|
||||
databaseService: databases.GetDatabaseService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
backupRepository: backupRepository,
|
||||
notifierService: notifiers.GetNotifierService(),
|
||||
notificationSender: notifiers.GetNotifierService(),
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
secretKeyService: encryption_secrets.GetSecretKeyService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
createBackupUseCase: usecases.GetCreateBackupUsecase(),
|
||||
logger: logger.GetLogger(),
|
||||
backupRemoveListeners: []backups_core.BackupRemoveListener{},
|
||||
workspaceService: workspaces_services.GetWorkspaceService(),
|
||||
auditLogService: audit_logs.GetAuditLogService(),
|
||||
backupCancelManager: backupCancelManager,
|
||||
downloadTokenService: backups_download.GetDownloadTokenService(),
|
||||
backupSchedulerService: backuping.GetBackupsScheduler(),
|
||||
}
|
||||
|
||||
var backupController = &BackupController{
|
||||
backupService,
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
backups_config.
|
||||
GetBackupConfigService().
|
||||
SetDatabaseStorageChangeListener(backupService)
|
||||
|
||||
databases.GetDatabaseService().AddDbRemoveListener(backupService)
|
||||
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
|
||||
|
||||
backupContextManager.StartSubscription()
|
||||
backupService: backupService,
|
||||
}
|
||||
|
||||
func GetBackupService() *BackupService {
|
||||
@@ -70,10 +52,11 @@ func GetBackupController() *BackupController {
|
||||
return backupController
|
||||
}
|
||||
|
||||
func GetBackupBackgroundService() *BackupBackgroundService {
|
||||
return backupBackgroundService
|
||||
}
|
||||
func SetupDependencies() {
|
||||
backups_config.
|
||||
GetBackupConfigService().
|
||||
SetDatabaseStorageChangeListener(backupService)
|
||||
|
||||
func GetDownloadTokenBackgroundService() *download_token.DownloadTokenBackgroundService {
|
||||
return download_token.GetDownloadTokenBackgroundService()
|
||||
databases.GetDatabaseService().AddDbRemoveListener(backupService)
|
||||
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
|
||||
}
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
)
|
||||
|
||||
type DownloadTokenBackgroundService struct {
|
||||
downloadTokenService *DownloadTokenService
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *DownloadTokenBackgroundService) Run(ctx context.Context) {
|
||||
s.logger.Info("Starting download token cleanup background service")
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
|
||||
s.logger.Error("Failed to clean expired download tokens", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type BandwidthManager struct {
|
||||
mu sync.RWMutex
|
||||
activeDownloads map[uuid.UUID]*activeDownload
|
||||
maxTotalBytesPerSecond int64
|
||||
bytesPerSecondPerDownload int64
|
||||
}
|
||||
|
||||
type activeDownload struct {
|
||||
userID uuid.UUID
|
||||
rateLimiter *RateLimiter
|
||||
}
|
||||
|
||||
func NewBandwidthManager(throughputMBs int) *BandwidthManager {
|
||||
// Use 75% of total throughput
|
||||
maxBytes := int64(throughputMBs) * 1024 * 1024 * 75 / 100
|
||||
|
||||
return &BandwidthManager{
|
||||
activeDownloads: make(map[uuid.UUID]*activeDownload),
|
||||
maxTotalBytesPerSecond: maxBytes,
|
||||
bytesPerSecondPerDownload: maxBytes,
|
||||
}
|
||||
}
|
||||
|
||||
func (bm *BandwidthManager) RegisterDownload(userID uuid.UUID) (*RateLimiter, error) {
|
||||
bm.mu.Lock()
|
||||
defer bm.mu.Unlock()
|
||||
|
||||
if _, exists := bm.activeDownloads[userID]; exists {
|
||||
return nil, fmt.Errorf("download already registered for user %s", userID)
|
||||
}
|
||||
|
||||
rateLimiter := NewRateLimiter(bm.bytesPerSecondPerDownload)
|
||||
|
||||
bm.activeDownloads[userID] = &activeDownload{
|
||||
userID: userID,
|
||||
rateLimiter: rateLimiter,
|
||||
}
|
||||
|
||||
bm.recalculateRates()
|
||||
|
||||
return rateLimiter, nil
|
||||
}
|
||||
|
||||
func (bm *BandwidthManager) UnregisterDownload(userID uuid.UUID) {
|
||||
bm.mu.Lock()
|
||||
defer bm.mu.Unlock()
|
||||
|
||||
delete(bm.activeDownloads, userID)
|
||||
bm.recalculateRates()
|
||||
}
|
||||
|
||||
func (bm *BandwidthManager) GetActiveDownloadCount() int {
|
||||
bm.mu.RLock()
|
||||
defer bm.mu.RUnlock()
|
||||
return len(bm.activeDownloads)
|
||||
}
|
||||
|
||||
func (bm *BandwidthManager) recalculateRates() {
|
||||
activeCount := len(bm.activeDownloads)
|
||||
|
||||
if activeCount == 0 {
|
||||
bm.bytesPerSecondPerDownload = bm.maxTotalBytesPerSecond
|
||||
return
|
||||
}
|
||||
|
||||
newRate := bm.maxTotalBytesPerSecond / int64(activeCount)
|
||||
bm.bytesPerSecondPerDownload = newRate
|
||||
|
||||
for _, download := range bm.activeDownloads {
|
||||
download.rateLimiter.UpdateRate(newRate)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,150 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_BandwidthManager_RegisterSingleDownload(t *testing.T) {
|
||||
throughputMBs := 100
|
||||
manager := NewBandwidthManager(throughputMBs)
|
||||
|
||||
expectedBytesPerSec := int64(100 * 1024 * 1024 * 75 / 100)
|
||||
assert.Equal(t, expectedBytesPerSec, manager.maxTotalBytesPerSecond)
|
||||
assert.Equal(t, expectedBytesPerSec, manager.bytesPerSecondPerDownload)
|
||||
|
||||
userID := uuid.New()
|
||||
rateLimiter, err := manager.RegisterDownload(userID)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, rateLimiter)
|
||||
|
||||
assert.Equal(t, 1, manager.GetActiveDownloadCount())
|
||||
assert.Equal(t, expectedBytesPerSec, manager.bytesPerSecondPerDownload)
|
||||
assert.Equal(t, expectedBytesPerSec, rateLimiter.bytesPerSecond)
|
||||
}
|
||||
|
||||
func Test_BandwidthManager_RegisterMultipleDownloads_BandwidthShared(t *testing.T) {
|
||||
throughputMBs := 100
|
||||
manager := NewBandwidthManager(throughputMBs)
|
||||
|
||||
maxBytes := int64(100 * 1024 * 1024 * 75 / 100)
|
||||
|
||||
user1 := uuid.New()
|
||||
rateLimiter1, err := manager.RegisterDownload(user1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, maxBytes, rateLimiter1.bytesPerSecond)
|
||||
|
||||
user2 := uuid.New()
|
||||
rateLimiter2, err := manager.RegisterDownload(user2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
expectedPerDownload := maxBytes / 2
|
||||
assert.Equal(t, expectedPerDownload, manager.bytesPerSecondPerDownload)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter2.bytesPerSecond)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter2.bytesPerSecond)
|
||||
|
||||
user3 := uuid.New()
|
||||
rateLimiter3, err := manager.RegisterDownload(user3)
|
||||
assert.NoError(t, err)
|
||||
|
||||
expectedPerDownload = maxBytes / 3
|
||||
assert.Equal(t, expectedPerDownload, manager.bytesPerSecondPerDownload)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter2.bytesPerSecond)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter3.bytesPerSecond)
|
||||
assert.Equal(t, 3, manager.GetActiveDownloadCount())
|
||||
}
|
||||
|
||||
func Test_BandwidthManager_UnregisterDownload_BandwidthRebalanced(t *testing.T) {
|
||||
throughputMBs := 100
|
||||
manager := NewBandwidthManager(throughputMBs)
|
||||
|
||||
maxBytes := int64(100 * 1024 * 1024 * 75 / 100)
|
||||
|
||||
user1 := uuid.New()
|
||||
rateLimiter1, _ := manager.RegisterDownload(user1)
|
||||
|
||||
user2 := uuid.New()
|
||||
_, _ = manager.RegisterDownload(user2)
|
||||
|
||||
user3 := uuid.New()
|
||||
rateLimiter3, _ := manager.RegisterDownload(user3)
|
||||
|
||||
assert.Equal(t, 3, manager.GetActiveDownloadCount())
|
||||
expectedPerDownload := maxBytes / 3
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
|
||||
|
||||
manager.UnregisterDownload(user2)
|
||||
|
||||
assert.Equal(t, 2, manager.GetActiveDownloadCount())
|
||||
expectedPerDownload = maxBytes / 2
|
||||
assert.Equal(t, expectedPerDownload, manager.bytesPerSecondPerDownload)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
|
||||
assert.Equal(t, expectedPerDownload, rateLimiter3.bytesPerSecond)
|
||||
|
||||
manager.UnregisterDownload(user1)
|
||||
|
||||
assert.Equal(t, 1, manager.GetActiveDownloadCount())
|
||||
assert.Equal(t, maxBytes, manager.bytesPerSecondPerDownload)
|
||||
assert.Equal(t, maxBytes, rateLimiter3.bytesPerSecond)
|
||||
|
||||
manager.UnregisterDownload(user3)
|
||||
assert.Equal(t, 0, manager.GetActiveDownloadCount())
|
||||
assert.Equal(t, maxBytes, manager.bytesPerSecondPerDownload)
|
||||
}
|
||||
|
||||
func Test_BandwidthManager_RegisterDuplicateUser_ReturnsError(t *testing.T) {
|
||||
manager := NewBandwidthManager(100)
|
||||
|
||||
userID := uuid.New()
|
||||
_, err := manager.RegisterDownload(userID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = manager.RegisterDownload(userID)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "download already registered")
|
||||
}
|
||||
|
||||
func Test_RateLimiter_TokenBucketBasic(t *testing.T) {
|
||||
bytesPerSec := int64(1024 * 1024)
|
||||
limiter := NewRateLimiter(bytesPerSec)
|
||||
|
||||
assert.Equal(t, bytesPerSec, limiter.bytesPerSecond)
|
||||
assert.Equal(t, bytesPerSec*2, limiter.bucketSize)
|
||||
|
||||
start := time.Now()
|
||||
limiter.Wait(512 * 1024)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.Less(t, elapsed, 100*time.Millisecond)
|
||||
}
|
||||
|
||||
func Test_RateLimiter_UpdateRate(t *testing.T) {
|
||||
limiter := NewRateLimiter(1024 * 1024)
|
||||
|
||||
assert.Equal(t, int64(1024*1024), limiter.bytesPerSecond)
|
||||
|
||||
newRate := int64(2 * 1024 * 1024)
|
||||
limiter.UpdateRate(newRate)
|
||||
|
||||
assert.Equal(t, newRate, limiter.bytesPerSecond)
|
||||
assert.Equal(t, newRate*2, limiter.bucketSize)
|
||||
}
|
||||
|
||||
func Test_RateLimiter_ThrottlesCorrectly(t *testing.T) {
|
||||
bytesPerSec := int64(1024 * 1024)
|
||||
limiter := NewRateLimiter(bytesPerSec)
|
||||
|
||||
limiter.availableTokens = 0
|
||||
|
||||
start := time.Now()
|
||||
limiter.Wait(bytesPerSec / 2)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.GreaterOrEqual(t, elapsed, 400*time.Millisecond)
|
||||
assert.LessOrEqual(t, elapsed, 700*time.Millisecond)
|
||||
}
|
||||
48
backend/internal/features/backups/backups/download/di.go
Normal file
48
backend/internal/features/backups/backups/download/di.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var downloadTokenRepository = &DownloadTokenRepository{}
|
||||
|
||||
var downloadTracker = NewDownloadTracker(cache_utils.GetValkeyClient())
|
||||
|
||||
var bandwidthManager *BandwidthManager
|
||||
var downloadTokenService *DownloadTokenService
|
||||
var downloadTokenBackgroundService *DownloadTokenBackgroundService
|
||||
|
||||
func init() {
|
||||
env := config.GetEnv()
|
||||
throughputMBs := env.NodeNetworkThroughputMBs
|
||||
if throughputMBs == 0 {
|
||||
throughputMBs = 125
|
||||
}
|
||||
bandwidthManager = NewBandwidthManager(throughputMBs)
|
||||
|
||||
downloadTokenService = &DownloadTokenService{
|
||||
downloadTokenRepository,
|
||||
logger.GetLogger(),
|
||||
downloadTracker,
|
||||
bandwidthManager,
|
||||
}
|
||||
|
||||
downloadTokenBackgroundService = &DownloadTokenBackgroundService{
|
||||
downloadTokenService,
|
||||
logger.GetLogger(),
|
||||
}
|
||||
}
|
||||
|
||||
func GetDownloadTokenService() *DownloadTokenService {
|
||||
return downloadTokenService
|
||||
}
|
||||
|
||||
func GetDownloadTokenBackgroundService() *DownloadTokenBackgroundService {
|
||||
return downloadTokenBackgroundService
|
||||
}
|
||||
|
||||
func GetBandwidthManager() *BandwidthManager {
|
||||
return bandwidthManager
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
package backups_download
|
||||
|
||||
import "github.com/google/uuid"
|
||||
|
||||
type GenerateDownloadTokenResponse struct {
|
||||
Token string `json:"token"`
|
||||
Filename string `json:"filename"`
|
||||
BackupID uuid.UUID `json:"backupId"`
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package download_token
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"time"
|
||||
@@ -0,0 +1,101 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RateLimiter struct {
|
||||
mu sync.Mutex
|
||||
bytesPerSecond int64
|
||||
bucketSize int64
|
||||
availableTokens float64
|
||||
lastRefill time.Time
|
||||
}
|
||||
|
||||
func NewRateLimiter(bytesPerSecond int64) *RateLimiter {
|
||||
if bytesPerSecond <= 0 {
|
||||
bytesPerSecond = 1024 * 1024 * 100
|
||||
}
|
||||
|
||||
return &RateLimiter{
|
||||
bytesPerSecond: bytesPerSecond,
|
||||
bucketSize: bytesPerSecond * 2,
|
||||
availableTokens: float64(bytesPerSecond * 2),
|
||||
lastRefill: time.Now().UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) UpdateRate(bytesPerSecond int64) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
if bytesPerSecond <= 0 {
|
||||
bytesPerSecond = 1024 * 1024 * 100
|
||||
}
|
||||
|
||||
rl.bytesPerSecond = bytesPerSecond
|
||||
rl.bucketSize = bytesPerSecond * 2
|
||||
|
||||
if rl.availableTokens > float64(rl.bucketSize) {
|
||||
rl.availableTokens = float64(rl.bucketSize)
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) Wait(bytes int64) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
for {
|
||||
now := time.Now().UTC()
|
||||
elapsed := now.Sub(rl.lastRefill).Seconds()
|
||||
|
||||
tokensToAdd := elapsed * float64(rl.bytesPerSecond)
|
||||
rl.availableTokens += tokensToAdd
|
||||
if rl.availableTokens > float64(rl.bucketSize) {
|
||||
rl.availableTokens = float64(rl.bucketSize)
|
||||
}
|
||||
rl.lastRefill = now
|
||||
|
||||
if rl.availableTokens >= float64(bytes) {
|
||||
rl.availableTokens -= float64(bytes)
|
||||
return
|
||||
}
|
||||
|
||||
tokensNeeded := float64(bytes) - rl.availableTokens
|
||||
waitTime := time.Duration(tokensNeeded/float64(rl.bytesPerSecond)*1000) * time.Millisecond
|
||||
|
||||
if waitTime < time.Millisecond {
|
||||
waitTime = time.Millisecond
|
||||
}
|
||||
|
||||
rl.mu.Unlock()
|
||||
time.Sleep(waitTime)
|
||||
rl.mu.Lock()
|
||||
}
|
||||
}
|
||||
|
||||
type RateLimitedReader struct {
|
||||
reader io.ReadCloser
|
||||
rateLimiter *RateLimiter
|
||||
}
|
||||
|
||||
func NewRateLimitedReader(reader io.ReadCloser, limiter *RateLimiter) *RateLimitedReader {
|
||||
return &RateLimitedReader{
|
||||
reader: reader,
|
||||
rateLimiter: limiter,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RateLimitedReader) Read(p []byte) (n int, err error) {
|
||||
n, err = r.reader.Read(p)
|
||||
if n > 0 {
|
||||
r.rateLimiter.Wait(int64(n))
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *RateLimitedReader) Close() error {
|
||||
return r.reader.Close()
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package download_token
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
105
backend/internal/features/backups/backups/download/service.go
Normal file
105
backend/internal/features/backups/backups/download/service.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type DownloadTokenService struct {
|
||||
repository *DownloadTokenRepository
|
||||
logger *slog.Logger
|
||||
downloadTracker *DownloadTracker
|
||||
bandwidthManager *BandwidthManager
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) Generate(backupID, userID uuid.UUID) (string, error) {
|
||||
if s.downloadTracker.IsDownloadInProgress(userID) {
|
||||
return "", ErrDownloadAlreadyInProgress
|
||||
}
|
||||
|
||||
token := GenerateSecureToken()
|
||||
|
||||
downloadToken := &DownloadToken{
|
||||
Token: token,
|
||||
BackupID: backupID,
|
||||
UserID: userID,
|
||||
ExpiresAt: time.Now().UTC().Add(5 * time.Minute),
|
||||
Used: false,
|
||||
}
|
||||
|
||||
if err := s.repository.Create(downloadToken); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
s.logger.Info("Generated download token", "backupId", backupID, "userId", userID)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) ValidateAndConsume(
|
||||
token string,
|
||||
) (*DownloadToken, *RateLimiter, error) {
|
||||
dt, err := s.repository.FindByToken(token)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if dt == nil {
|
||||
return nil, nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
if dt.Used {
|
||||
return nil, nil, errors.New("token already used")
|
||||
}
|
||||
|
||||
if time.Now().UTC().After(dt.ExpiresAt) {
|
||||
return nil, nil, errors.New("token expired")
|
||||
}
|
||||
|
||||
if err := s.downloadTracker.AcquireDownloadLock(dt.UserID); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
rateLimiter, err := s.bandwidthManager.RegisterDownload(dt.UserID)
|
||||
if err != nil {
|
||||
s.downloadTracker.ReleaseDownloadLock(dt.UserID)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
dt.Used = true
|
||||
if err := s.repository.Update(dt); err != nil {
|
||||
s.logger.Error("Failed to mark token as used", "error", err)
|
||||
}
|
||||
|
||||
s.logger.Info("Token validated and consumed", "backupId", dt.BackupID, "userId", dt.UserID)
|
||||
return dt, rateLimiter, nil
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) RefreshDownloadLock(userID uuid.UUID) {
|
||||
s.downloadTracker.RefreshDownloadLock(userID)
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) ReleaseDownloadLock(userID uuid.UUID) {
|
||||
s.downloadTracker.ReleaseDownloadLock(userID)
|
||||
s.logger.Info("Released download lock", "userId", userID)
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) IsDownloadInProgress(userID uuid.UUID) bool {
|
||||
return s.downloadTracker.IsDownloadInProgress(userID)
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) UnregisterDownload(userID uuid.UUID) {
|
||||
s.bandwidthManager.UnregisterDownload(userID)
|
||||
s.logger.Info("Unregistered from bandwidth manager", "userId", userID)
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) CleanExpiredTokens() error {
|
||||
now := time.Now().UTC()
|
||||
if err := s.repository.DeleteExpired(now); err != nil {
|
||||
return err
|
||||
}
|
||||
s.logger.Debug("Cleaned expired download tokens")
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/valkey-io/valkey-go"
|
||||
)
|
||||
|
||||
const (
|
||||
downloadLockPrefix = "backup_download_lock:"
|
||||
downloadLockTTL = 5 * time.Second
|
||||
downloadLockValue = "1"
|
||||
downloadHeartbeatDelay = 3 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDownloadAlreadyInProgress = errors.New("download already in progress for this user")
|
||||
)
|
||||
|
||||
type DownloadTracker struct {
|
||||
cache *cache_utils.CacheUtil[string]
|
||||
}
|
||||
|
||||
func NewDownloadTracker(client valkey.Client) *DownloadTracker {
|
||||
return &DownloadTracker{
|
||||
cache: cache_utils.NewCacheUtil[string](client, downloadLockPrefix),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *DownloadTracker) AcquireDownloadLock(userID uuid.UUID) error {
|
||||
key := userID.String()
|
||||
|
||||
existingLock := t.cache.Get(key)
|
||||
if existingLock != nil {
|
||||
return ErrDownloadAlreadyInProgress
|
||||
}
|
||||
|
||||
value := downloadLockValue
|
||||
t.cache.Set(key, &value)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *DownloadTracker) RefreshDownloadLock(userID uuid.UUID) {
|
||||
key := userID.String()
|
||||
value := downloadLockValue
|
||||
t.cache.Set(key, &value)
|
||||
}
|
||||
|
||||
func (t *DownloadTracker) ReleaseDownloadLock(userID uuid.UUID) {
|
||||
key := userID.String()
|
||||
t.cache.Invalidate(key)
|
||||
}
|
||||
|
||||
func (t *DownloadTracker) IsDownloadInProgress(userID uuid.UUID) bool {
|
||||
key := userID.String()
|
||||
existingLock := t.cache.Get(key)
|
||||
return existingLock != nil
|
||||
}
|
||||
|
||||
func GetDownloadHeartbeatInterval() time.Duration {
|
||||
return downloadHeartbeatDelay
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
package download_token
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
"log/slog"
|
||||
"time"
|
||||
)
|
||||
|
||||
type DownloadTokenBackgroundService struct {
|
||||
downloadTokenService *DownloadTokenService
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *DownloadTokenBackgroundService) Run() {
|
||||
s.logger.Info("Starting download token cleanup background service")
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
if config.IsShouldShutdown() {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
|
||||
s.logger.Error("Failed to clean expired download tokens", "error", err)
|
||||
}
|
||||
|
||||
time.Sleep(1 * time.Minute)
|
||||
}
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package download_token
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var downloadTokenRepository = &DownloadTokenRepository{}
|
||||
|
||||
var downloadTokenService = &DownloadTokenService{
|
||||
downloadTokenRepository,
|
||||
logger.GetLogger(),
|
||||
}
|
||||
|
||||
var downloadTokenBackgroundService = &DownloadTokenBackgroundService{
|
||||
downloadTokenService,
|
||||
logger.GetLogger(),
|
||||
}
|
||||
|
||||
func GetDownloadTokenService() *DownloadTokenService {
|
||||
return downloadTokenService
|
||||
}
|
||||
|
||||
func GetDownloadTokenBackgroundService() *DownloadTokenBackgroundService {
|
||||
return downloadTokenBackgroundService
|
||||
}
|
||||
@@ -1,69 +0,0 @@
|
||||
package download_token
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type DownloadTokenService struct {
|
||||
repository *DownloadTokenRepository
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) Generate(backupID, userID uuid.UUID) (string, error) {
|
||||
token := GenerateSecureToken()
|
||||
|
||||
downloadToken := &DownloadToken{
|
||||
Token: token,
|
||||
BackupID: backupID,
|
||||
UserID: userID,
|
||||
ExpiresAt: time.Now().UTC().Add(5 * time.Minute),
|
||||
Used: false,
|
||||
}
|
||||
|
||||
if err := s.repository.Create(downloadToken); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
s.logger.Info("Generated download token", "backupId", backupID, "userId", userID)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) ValidateAndConsume(token string) (*DownloadToken, error) {
|
||||
dt, err := s.repository.FindByToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if dt == nil {
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
if dt.Used {
|
||||
return nil, errors.New("token already used")
|
||||
}
|
||||
|
||||
if time.Now().UTC().After(dt.ExpiresAt) {
|
||||
return nil, errors.New("token expired")
|
||||
}
|
||||
|
||||
dt.Used = true
|
||||
if err := s.repository.Update(dt); err != nil {
|
||||
s.logger.Error("Failed to mark token as used", "error", err)
|
||||
}
|
||||
|
||||
s.logger.Info("Token validated and consumed", "backupId", dt.BackupID)
|
||||
return dt, nil
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) CleanExpiredTokens() error {
|
||||
now := time.Now().UTC()
|
||||
if err := s.repository.DeleteExpired(now); err != nil {
|
||||
return err
|
||||
}
|
||||
s.logger.Debug("Cleaned expired download tokens")
|
||||
return nil
|
||||
}
|
||||
@@ -1,10 +1,9 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/encryption"
|
||||
"io"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type GetBackupsRequest struct {
|
||||
@@ -14,23 +13,17 @@ type GetBackupsRequest struct {
|
||||
}
|
||||
|
||||
type GetBackupsResponse struct {
|
||||
Backups []*Backup `json:"backups"`
|
||||
Total int64 `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
Backups []*backups_core.Backup `json:"backups"`
|
||||
Total int64 `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
}
|
||||
|
||||
type GenerateDownloadTokenResponse struct {
|
||||
Token string `json:"token"`
|
||||
Filename string `json:"filename"`
|
||||
BackupID uuid.UUID `json:"backupId"`
|
||||
}
|
||||
|
||||
type decryptionReaderCloser struct {
|
||||
type DecryptionReaderCloser struct {
|
||||
*encryption.DecryptionReader
|
||||
baseReader io.ReadCloser
|
||||
BaseReader io.ReadCloser
|
||||
}
|
||||
|
||||
func (r *decryptionReaderCloser) Close() error {
|
||||
return r.baseReader.Close()
|
||||
func (r *DecryptionReaderCloser) Close() error {
|
||||
return r.BaseReader.Close()
|
||||
}
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups/download_token"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
"databasus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -29,26 +28,27 @@ import (
|
||||
type BackupService struct {
|
||||
databaseService *databases.DatabaseService
|
||||
storageService *storages.StorageService
|
||||
backupRepository *BackupRepository
|
||||
backupRepository *backups_core.BackupRepository
|
||||
notifierService *notifiers.NotifierService
|
||||
notificationSender NotificationSender
|
||||
notificationSender backups_core.NotificationSender
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
secretKeyService *encryption_secrets.SecretKeyService
|
||||
fieldEncryptor util_encryption.FieldEncryptor
|
||||
|
||||
createBackupUseCase CreateBackupUsecase
|
||||
createBackupUseCase backups_core.CreateBackupUsecase
|
||||
|
||||
logger *slog.Logger
|
||||
|
||||
backupRemoveListeners []BackupRemoveListener
|
||||
backupRemoveListeners []backups_core.BackupRemoveListener
|
||||
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
auditLogService *audit_logs.AuditLogService
|
||||
backupContextManager *BackupContextManager
|
||||
downloadTokenService *download_token.DownloadTokenService
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
auditLogService *audit_logs.AuditLogService
|
||||
backupCancelManager *backups_cancellation.BackupCancelManager
|
||||
downloadTokenService *backups_download.DownloadTokenService
|
||||
backupSchedulerService *backuping.BackupsScheduler
|
||||
}
|
||||
|
||||
func (s *BackupService) AddBackupRemoveListener(listener BackupRemoveListener) {
|
||||
func (s *BackupService) AddBackupRemoveListener(listener backups_core.BackupRemoveListener) {
|
||||
s.backupRemoveListeners = append(s.backupRemoveListeners, listener)
|
||||
}
|
||||
|
||||
@@ -91,7 +91,7 @@ func (s *BackupService) MakeBackupWithAuth(
|
||||
return errors.New("insufficient permissions to create backup for this database")
|
||||
}
|
||||
|
||||
go s.MakeBackup(databaseID, true)
|
||||
s.backupSchedulerService.StartBackup(databaseID, true)
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Backup manually initiated for database: %s", database.Name),
|
||||
@@ -175,7 +175,7 @@ func (s *BackupService) DeleteBackup(
|
||||
return errors.New("insufficient permissions to delete backup for this database")
|
||||
}
|
||||
|
||||
if backup.Status == BackupStatusInProgress {
|
||||
if backup.Status == backups_core.BackupStatusInProgress {
|
||||
return errors.New("backup is in progress")
|
||||
}
|
||||
|
||||
@@ -192,260 +192,7 @@ func (s *BackupService) DeleteBackup(
|
||||
return s.deleteBackup(backup)
|
||||
}
|
||||
|
||||
func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
|
||||
database, err := s.databaseService.GetDatabaseByID(databaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get database by ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
lastBackup, err := s.backupRepository.FindLastByDatabaseID(databaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to find last backup by database ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if lastBackup != nil && lastBackup.Status == BackupStatusInProgress {
|
||||
s.logger.Error("Backup is in progress")
|
||||
return
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(databaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if backupConfig.StorageID == nil {
|
||||
s.logger.Error("Backup config storage ID is not defined")
|
||||
return
|
||||
}
|
||||
|
||||
storage, err := s.storageService.GetStorageByID(*backupConfig.StorageID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get storage by ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
backup := &Backup{
|
||||
DatabaseID: databaseID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusInProgress,
|
||||
|
||||
BackupSizeMb: 0,
|
||||
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error("Failed to save backup", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now().UTC()
|
||||
|
||||
backupProgressListener := func(
|
||||
completedMBs float64,
|
||||
) {
|
||||
backup.BackupSizeMb = completedMBs
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error("Failed to update backup progress", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s.backupContextManager.RegisterBackup(backup.ID, cancel)
|
||||
defer s.backupContextManager.UnregisterBackup(backup.ID)
|
||||
|
||||
backupMetadata, err := s.createBackupUseCase.Execute(
|
||||
ctx,
|
||||
backup.ID,
|
||||
backupConfig,
|
||||
database,
|
||||
storage,
|
||||
backupProgressListener,
|
||||
)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
|
||||
// Check if backup was cancelled (not due to shutdown)
|
||||
isCancelled := strings.Contains(errMsg, "backup cancelled") ||
|
||||
strings.Contains(errMsg, "context canceled") ||
|
||||
errors.Is(err, context.Canceled)
|
||||
isShutdown := strings.Contains(errMsg, "shutdown")
|
||||
|
||||
if isCancelled && !isShutdown {
|
||||
backup.Status = BackupStatusCanceled
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error("Failed to save cancelled backup", "error", err)
|
||||
}
|
||||
|
||||
// Delete partial backup from storage
|
||||
storage, storageErr := s.storageService.GetStorageByID(backup.StorageID)
|
||||
if storageErr == nil {
|
||||
if deleteErr := storage.DeleteFile(s.fieldEncryptor, backup.ID); deleteErr != nil {
|
||||
s.logger.Error(
|
||||
"Failed to delete partial backup file",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
deleteErr,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
backup.FailMessage = &errMsg
|
||||
backup.Status = BackupStatusFailed
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
if updateErr := s.databaseService.SetBackupError(databaseID, errMsg); updateErr != nil {
|
||||
s.logger.Error(
|
||||
"Failed to update database last backup time",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"error",
|
||||
updateErr,
|
||||
)
|
||||
}
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error("Failed to save backup", "error", err)
|
||||
}
|
||||
|
||||
s.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupFailed,
|
||||
&errMsg,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
backup.Status = BackupStatusCompleted
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
// Update backup with encryption metadata if provided
|
||||
if backupMetadata != nil {
|
||||
backup.EncryptionSalt = backupMetadata.EncryptionSalt
|
||||
backup.EncryptionIV = backupMetadata.EncryptionIV
|
||||
backup.Encryption = backupMetadata.Encryption
|
||||
}
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error("Failed to save backup", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Update database last backup time
|
||||
now := time.Now().UTC()
|
||||
if updateErr := s.databaseService.SetLastBackupTime(databaseID, now); updateErr != nil {
|
||||
s.logger.Error(
|
||||
"Failed to update database last backup time",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
"error",
|
||||
updateErr,
|
||||
)
|
||||
}
|
||||
|
||||
if backup.Status != BackupStatusCompleted && !isLastTry {
|
||||
return
|
||||
}
|
||||
|
||||
s.SendBackupNotification(
|
||||
backupConfig,
|
||||
backup,
|
||||
backups_config.NotificationBackupSuccess,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
func (s *BackupService) SendBackupNotification(
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
backup *Backup,
|
||||
notificationType backups_config.BackupNotificationType,
|
||||
errorMessage *string,
|
||||
) {
|
||||
database, err := s.databaseService.GetDatabaseByID(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
workspace, err := s.workspaceService.GetWorkspaceByID(*database.WorkspaceID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, notifier := range database.Notifiers {
|
||||
if !slices.Contains(
|
||||
backupConfig.SendNotificationsOn,
|
||||
notificationType,
|
||||
) {
|
||||
continue
|
||||
}
|
||||
|
||||
title := ""
|
||||
switch notificationType {
|
||||
case backups_config.NotificationBackupFailed:
|
||||
title = fmt.Sprintf(
|
||||
"❌ Backup failed for database \"%s\" (workspace \"%s\")",
|
||||
database.Name,
|
||||
workspace.Name,
|
||||
)
|
||||
case backups_config.NotificationBackupSuccess:
|
||||
title = fmt.Sprintf(
|
||||
"✅ Backup completed for database \"%s\" (workspace \"%s\")",
|
||||
database.Name,
|
||||
workspace.Name,
|
||||
)
|
||||
}
|
||||
|
||||
message := ""
|
||||
if errorMessage != nil {
|
||||
message = *errorMessage
|
||||
} else {
|
||||
// Format size conditionally
|
||||
var sizeStr string
|
||||
if backup.BackupSizeMb < 1024 {
|
||||
sizeStr = fmt.Sprintf("%.2f MB", backup.BackupSizeMb)
|
||||
} else {
|
||||
sizeGB := backup.BackupSizeMb / 1024
|
||||
sizeStr = fmt.Sprintf("%.2f GB", sizeGB)
|
||||
}
|
||||
|
||||
// Format duration as "0m 0s 0ms"
|
||||
totalMs := backup.BackupDurationMs
|
||||
minutes := totalMs / (1000 * 60)
|
||||
seconds := (totalMs % (1000 * 60)) / 1000
|
||||
durationStr := fmt.Sprintf("%dm %ds", minutes, seconds)
|
||||
|
||||
message = fmt.Sprintf(
|
||||
"Backup completed successfully in %s.\nCompressed backup size: %s",
|
||||
durationStr,
|
||||
sizeStr,
|
||||
)
|
||||
}
|
||||
|
||||
s.notificationSender.SendNotification(
|
||||
¬ifier,
|
||||
title,
|
||||
message,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupService) GetBackup(backupID uuid.UUID) (*Backup, error) {
|
||||
func (s *BackupService) GetBackup(backupID uuid.UUID) (*backups_core.Backup, error) {
|
||||
return s.backupRepository.FindByID(backupID)
|
||||
}
|
||||
|
||||
@@ -475,11 +222,11 @@ func (s *BackupService) CancelBackup(
|
||||
return errors.New("insufficient permissions to cancel backup for this database")
|
||||
}
|
||||
|
||||
if backup.Status != BackupStatusInProgress {
|
||||
if backup.Status != backups_core.BackupStatusInProgress {
|
||||
return errors.New("backup is not in progress")
|
||||
}
|
||||
|
||||
if err := s.backupContextManager.CancelBackup(backupID); err != nil {
|
||||
if err := s.backupCancelManager.CancelBackup(backupID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -499,7 +246,7 @@ func (s *BackupService) CancelBackup(
|
||||
func (s *BackupService) GetBackupFile(
|
||||
user *users_models.User,
|
||||
backupID uuid.UUID,
|
||||
) (io.ReadCloser, *Backup, *databases.Database, error) {
|
||||
) (io.ReadCloser, *backups_core.Backup, *databases.Database, error) {
|
||||
backup, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
@@ -545,7 +292,7 @@ func (s *BackupService) GetBackupFile(
|
||||
return reader, backup, database, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) deleteBackup(backup *Backup) error {
|
||||
func (s *BackupService) deleteBackup(backup *backups_core.Backup) error {
|
||||
for _, listener := range s.backupRemoveListeners {
|
||||
if err := listener.OnBeforeBackupRemove(backup); err != nil {
|
||||
return err
|
||||
@@ -571,7 +318,7 @@ func (s *BackupService) deleteBackup(backup *Backup) error {
|
||||
func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
|
||||
dbBackupsInProgress, err := s.backupRepository.FindByDatabaseIdAndStatus(
|
||||
databaseID,
|
||||
BackupStatusInProgress,
|
||||
backups_core.BackupStatusInProgress,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -680,16 +427,16 @@ func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, erro
|
||||
|
||||
s.logger.Info("Returning encrypted backup with decryption", "backupId", backupID)
|
||||
|
||||
return &decryptionReaderCloser{
|
||||
decryptionReader,
|
||||
fileReader,
|
||||
return &DecryptionReaderCloser{
|
||||
DecryptionReader: decryptionReader,
|
||||
BaseReader: fileReader,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) GenerateDownloadToken(
|
||||
user *users_models.User,
|
||||
backupID uuid.UUID,
|
||||
) (*GenerateDownloadTokenResponse, error) {
|
||||
) (*backups_download.GenerateDownloadTokenResponse, error) {
|
||||
backup, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -725,20 +472,22 @@ func (s *BackupService) GenerateDownloadToken(
|
||||
database.WorkspaceID,
|
||||
)
|
||||
|
||||
return &GenerateDownloadTokenResponse{
|
||||
return &backups_download.GenerateDownloadTokenResponse{
|
||||
Token: token,
|
||||
Filename: filename,
|
||||
BackupID: backupID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) ValidateDownloadToken(token string) (*download_token.DownloadToken, error) {
|
||||
func (s *BackupService) ValidateDownloadToken(
|
||||
token string,
|
||||
) (*backups_download.DownloadToken, *backups_download.RateLimiter, error) {
|
||||
return s.downloadTokenService.ValidateAndConsume(token)
|
||||
}
|
||||
|
||||
func (s *BackupService) GetBackupFileWithoutAuth(
|
||||
backupID uuid.UUID,
|
||||
) (io.ReadCloser, *Backup, *databases.Database, error) {
|
||||
) (io.ReadCloser, *backups_core.Backup, *databases.Database, error) {
|
||||
backup, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
@@ -759,7 +508,7 @@ func (s *BackupService) GetBackupFileWithoutAuth(
|
||||
|
||||
func (s *BackupService) WriteAuditLogForDownload(
|
||||
userID uuid.UUID,
|
||||
backup *Backup,
|
||||
backup *backups_core.Backup,
|
||||
database *databases.Database,
|
||||
) {
|
||||
s.auditLogService.WriteAuditLog(
|
||||
@@ -773,8 +522,24 @@ func (s *BackupService) WriteAuditLogForDownload(
|
||||
)
|
||||
}
|
||||
|
||||
func (s *BackupService) RefreshDownloadLock(userID uuid.UUID) {
|
||||
s.downloadTokenService.RefreshDownloadLock(userID)
|
||||
}
|
||||
|
||||
func (s *BackupService) ReleaseDownloadLock(userID uuid.UUID) {
|
||||
s.downloadTokenService.ReleaseDownloadLock(userID)
|
||||
}
|
||||
|
||||
func (s *BackupService) IsDownloadInProgress(userID uuid.UUID) bool {
|
||||
return s.downloadTokenService.IsDownloadInProgress(userID)
|
||||
}
|
||||
|
||||
func (s *BackupService) UnregisterDownload(userID uuid.UUID) {
|
||||
s.downloadTokenService.UnregisterDownload(userID)
|
||||
}
|
||||
|
||||
func (s *BackupService) generateBackupFilename(
|
||||
backup *Backup,
|
||||
backup *backups_core.Backup,
|
||||
database *databases.Database,
|
||||
) string {
|
||||
timestamp := backup.CreatedAt.Format("2006-01-02_15-04-05")
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
@@ -58,9 +59,9 @@ func WaitForBackupCompletion(
|
||||
newestBackup := backups[0]
|
||||
t.Logf("WaitForBackupCompletion: newest backup status: %s", newestBackup.Status)
|
||||
|
||||
if newestBackup.Status == BackupStatusCompleted ||
|
||||
newestBackup.Status == BackupStatusFailed ||
|
||||
newestBackup.Status == BackupStatusCanceled {
|
||||
if newestBackup.Status == backups_core.BackupStatusCompleted ||
|
||||
newestBackup.Status == backups_core.BackupStatusFailed ||
|
||||
newestBackup.Status == backups_core.BackupStatusCanceled {
|
||||
t.Logf(
|
||||
"WaitForBackupCompletion: backup finished with status %s",
|
||||
newestBackup.Status,
|
||||
|
||||
@@ -515,11 +515,13 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
|
||||
hasProcess := false
|
||||
hasAllPrivileges := false
|
||||
|
||||
escapedDB := strings.ReplaceAll(database, "_", "\\_")
|
||||
dbPattern := regexp.MustCompile(
|
||||
fmt.Sprintf("(?i)ON\\s+[`'\"]?(%s|\\*)[`'\"]?\\.\\*", regexp.QuoteMeta(escapedDB)),
|
||||
dbPatternStr := fmt.Sprintf(
|
||||
`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
|
||||
regexp.QuoteMeta(database),
|
||||
)
|
||||
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\.\*`)
|
||||
dbPattern := regexp.MustCompile(dbPatternStr)
|
||||
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)
|
||||
allPrivilegesPattern := regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`)
|
||||
|
||||
for rows.Next() {
|
||||
var grant string
|
||||
@@ -527,23 +529,26 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
|
||||
return "", fmt.Errorf("failed to scan grant: %w", err)
|
||||
}
|
||||
|
||||
if regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`).MatchString(grant) {
|
||||
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
|
||||
hasAllPrivileges = true
|
||||
}
|
||||
isRelevantGrant := globalPattern.MatchString(grant) || dbPattern.MatchString(grant)
|
||||
|
||||
if allPrivilegesPattern.MatchString(grant) && isRelevantGrant {
|
||||
hasAllPrivileges = true
|
||||
}
|
||||
|
||||
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
|
||||
if isRelevantGrant {
|
||||
for _, priv := range backupPrivileges {
|
||||
if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
|
||||
privPattern := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(priv) + `\b`)
|
||||
if privPattern.MatchString(grant) {
|
||||
detectedPrivileges[priv] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if globalPattern.MatchString(grant) &&
|
||||
regexp.MustCompile(`(?i)\bPROCESS\b`).MatchString(grant) {
|
||||
hasProcess = true
|
||||
if globalPattern.MatchString(grant) {
|
||||
processPattern := regexp.MustCompile(`(?i)\bPROCESS\b`)
|
||||
if processPattern.MatchString(grant) {
|
||||
hasProcess = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -537,6 +537,163 @@ func Test_ReadOnlyUser_CannotDropOrAlterTables(t *testing.T) {
|
||||
dropUserSafe(container.DB, username)
|
||||
}
|
||||
|
||||
func Test_TestConnection_DatabaseSpecificPrivilegesWithGlobalProcess_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MariadbVersion
|
||||
port string
|
||||
}{
|
||||
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
|
||||
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
|
||||
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
|
||||
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
|
||||
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
|
||||
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
|
||||
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
|
||||
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
|
||||
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
|
||||
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
|
||||
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMariadbContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS privilege_test`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`CREATE TABLE privilege_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`INSERT INTO privilege_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
specificUsername := fmt.Sprintf("spec_%s", uuid.New().String()[:8])
|
||||
specificPassword := "specificpass123"
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
specificUsername,
|
||||
specificPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT SELECT, SHOW VIEW ON %s.* TO '%s'@'%%'",
|
||||
container.Database,
|
||||
specificUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT PROCESS ON *.* TO '%s'@'%%'",
|
||||
specificUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer dropUserSafe(container.DB, specificUsername)
|
||||
|
||||
mariadbModel := &MariadbDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: specificUsername,
|
||||
Password: specificPassword,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mariadbModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMariadbContainer(t, env.TestMariadb1011Port, tools.MariadbVersion1011)
|
||||
defer container.DB.Close()
|
||||
|
||||
underscoreDbName := "test_db_name"
|
||||
|
||||
_, err := container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
|
||||
}()
|
||||
|
||||
underscoreDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
container.Username, container.Password, container.Host, container.Port, underscoreDbName)
|
||||
underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
|
||||
assert.NoError(t, err)
|
||||
defer underscoreDB.Close()
|
||||
|
||||
_, err = underscoreDB.Exec(`
|
||||
CREATE TABLE underscore_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(`INSERT INTO underscore_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
underscoreUsername := fmt.Sprintf("under%s", uuid.New().String()[:8])
|
||||
underscorePassword := "underscorepass123"
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
underscoreUsername,
|
||||
underscorePassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"GRANT SELECT, SHOW VIEW ON `%s`.* TO '%s'@'%%'",
|
||||
underscoreDbName,
|
||||
underscoreUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer dropUserSafe(underscoreDB, underscoreUsername)
|
||||
|
||||
mariadbModel := &MariadbDatabase{
|
||||
Version: tools.MariadbVersion1011,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: underscoreUsername,
|
||||
Password: underscorePassword,
|
||||
Database: &underscoreDbName,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mariadbModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
type MariadbContainer struct {
|
||||
Host string
|
||||
Port int
|
||||
|
||||
@@ -486,11 +486,13 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
|
||||
hasProcess := false
|
||||
hasAllPrivileges := false
|
||||
|
||||
escapedDB := strings.ReplaceAll(database, "_", "\\_")
|
||||
dbPattern := regexp.MustCompile(
|
||||
fmt.Sprintf("(?i)ON\\s+[`'\"]?(%s|\\*)[`'\"]?\\.\\*", regexp.QuoteMeta(escapedDB)),
|
||||
dbPatternStr := fmt.Sprintf(
|
||||
`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
|
||||
regexp.QuoteMeta(database),
|
||||
)
|
||||
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\.\*`)
|
||||
dbPattern := regexp.MustCompile(dbPatternStr)
|
||||
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)
|
||||
allPrivilegesPattern := regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`)
|
||||
|
||||
for rows.Next() {
|
||||
var grant string
|
||||
@@ -498,23 +500,26 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
|
||||
return "", fmt.Errorf("failed to scan grant: %w", err)
|
||||
}
|
||||
|
||||
if regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`).MatchString(grant) {
|
||||
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
|
||||
hasAllPrivileges = true
|
||||
}
|
||||
isRelevantGrant := globalPattern.MatchString(grant) || dbPattern.MatchString(grant)
|
||||
|
||||
if allPrivilegesPattern.MatchString(grant) && isRelevantGrant {
|
||||
hasAllPrivileges = true
|
||||
}
|
||||
|
||||
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
|
||||
if isRelevantGrant {
|
||||
for _, priv := range backupPrivileges {
|
||||
if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
|
||||
privPattern := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(priv) + `\b`)
|
||||
if privPattern.MatchString(grant) {
|
||||
detectedPrivileges[priv] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if globalPattern.MatchString(grant) &&
|
||||
regexp.MustCompile(`(?i)\bPROCESS\b`).MatchString(grant) {
|
||||
hasProcess = true
|
||||
if globalPattern.MatchString(grant) {
|
||||
processPattern := regexp.MustCompile(`(?i)\bPROCESS\b`)
|
||||
if processPattern.MatchString(grant) {
|
||||
hasProcess = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -518,6 +518,162 @@ func Test_ReadOnlyUser_CannotDropOrAlterTables(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_TestConnection_DatabaseSpecificPrivilegesWithGlobalProcess_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MysqlVersion
|
||||
port string
|
||||
}{
|
||||
{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
|
||||
{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
|
||||
{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
|
||||
{"MySQL 9", tools.MysqlVersion9, env.TestMysql90Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMysqlContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS privilege_test`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`CREATE TABLE privilege_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`INSERT INTO privilege_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
specificUsername := fmt.Sprintf("specific_%s", uuid.New().String()[:8])
|
||||
specificPassword := "specificpass123"
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
specificUsername,
|
||||
specificPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT SELECT, SHOW VIEW ON %s.* TO '%s'@'%%'",
|
||||
container.Database,
|
||||
specificUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT PROCESS ON *.* TO '%s'@'%%'",
|
||||
specificUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(
|
||||
fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", specificUsername),
|
||||
)
|
||||
}()
|
||||
|
||||
mysqlModel := &MysqlDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: specificUsername,
|
||||
Password: specificPassword,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mysqlModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMysqlContainer(t, env.TestMysql80Port, tools.MysqlVersion80)
|
||||
defer container.DB.Close()
|
||||
|
||||
underscoreDbName := "test_db_name"
|
||||
|
||||
_, err := container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName))
|
||||
}()
|
||||
|
||||
underscoreDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
container.Username, container.Password, container.Host, container.Port, underscoreDbName)
|
||||
underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
|
||||
assert.NoError(t, err)
|
||||
defer underscoreDB.Close()
|
||||
|
||||
_, err = underscoreDB.Exec(`
|
||||
CREATE TABLE underscore_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(`INSERT INTO underscore_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
underscoreUsername := fmt.Sprintf("under_%s", uuid.New().String()[:8])
|
||||
underscorePassword := "underscorepass123"
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
underscoreUsername,
|
||||
underscorePassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec(fmt.Sprintf(
|
||||
"GRANT SELECT, SHOW VIEW ON `%s`.* TO '%s'@'%%'",
|
||||
underscoreDbName,
|
||||
underscoreUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = underscoreDB.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", underscoreUsername))
|
||||
}()
|
||||
|
||||
mysqlModel := &MysqlDatabase{
|
||||
Version: tools.MysqlVersion80,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: underscoreUsername,
|
||||
Password: underscorePassword,
|
||||
Database: &underscoreDbName,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mysqlModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
type MysqlContainer struct {
|
||||
Host string
|
||||
Port int
|
||||
|
||||
@@ -761,12 +761,15 @@ func checkBackupPermissions(
|
||||
// Check SELECT privilege on at least one table (if tables exist)
|
||||
// Use pg_tables from pg_catalog which shows all tables regardless of user privileges
|
||||
var tableCount int
|
||||
|
||||
if len(includeSchemas) > 0 {
|
||||
// Check only tables in the specified schemas
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_catalog.pg_tables t
|
||||
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
|
||||
AND t.schemaname NOT LIKE 'pg_temp_%'
|
||||
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
|
||||
AND t.schemaname = ANY($1::text[])
|
||||
`, includeSchemas).Scan(&tableCount)
|
||||
} else {
|
||||
@@ -775,6 +778,8 @@ func checkBackupPermissions(
|
||||
SELECT COUNT(*)
|
||||
FROM pg_catalog.pg_tables t
|
||||
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
|
||||
AND t.schemaname NOT LIKE 'pg_temp_%'
|
||||
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
|
||||
`).Scan(&tableCount)
|
||||
}
|
||||
|
||||
@@ -785,12 +790,15 @@ func checkBackupPermissions(
|
||||
if tableCount > 0 {
|
||||
// Check if user has SELECT on at least one of these tables
|
||||
var selectableTableCount int
|
||||
|
||||
if len(includeSchemas) > 0 {
|
||||
// Check only tables in the specified schemas
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_catalog.pg_tables t
|
||||
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
|
||||
AND t.schemaname NOT LIKE 'pg_temp_%'
|
||||
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
|
||||
AND t.schemaname = ANY($1::text[])
|
||||
AND has_table_privilege(current_user, quote_ident(t.schemaname) || '.' || quote_ident(t.tablename), 'SELECT')
|
||||
`, includeSchemas).Scan(&selectableTableCount)
|
||||
@@ -800,6 +808,8 @@ func checkBackupPermissions(
|
||||
SELECT COUNT(*)
|
||||
FROM pg_catalog.pg_tables t
|
||||
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
|
||||
AND t.schemaname NOT LIKE 'pg_temp_%'
|
||||
AND t.schemaname NOT LIKE 'pg_toast_temp_%'
|
||||
AND has_table_privilege(current_user, quote_ident(t.schemaname) || '.' || quote_ident(t.tablename), 'SELECT')
|
||||
`).Scan(&selectableTableCount)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package healthcheck_attempt
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
"context"
|
||||
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
|
||||
"log/slog"
|
||||
"time"
|
||||
@@ -13,18 +13,19 @@ type HealthcheckAttemptBackgroundService struct {
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *HealthcheckAttemptBackgroundService) Run() {
|
||||
func (s *HealthcheckAttemptBackgroundService) Run(ctx context.Context) {
|
||||
// first healthcheck immediately
|
||||
s.checkDatabases()
|
||||
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
if config.IsShouldShutdown() {
|
||||
break
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.checkDatabases()
|
||||
}
|
||||
|
||||
s.checkDatabases()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -675,6 +675,10 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
WebhookNotifier: &webhook_notifier.WebhookNotifier{
|
||||
WebhookURL: "https://webhook.example.com/test",
|
||||
WebhookMethod: webhook_notifier.WebhookMethodPOST,
|
||||
Headers: []webhook_notifier.WebhookHeader{
|
||||
{Key: "Authorization", Value: "Bearer my-secret-token"},
|
||||
{Key: "X-Custom-Header", Value: "custom-value"},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
@@ -687,14 +691,40 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
WebhookNotifier: &webhook_notifier.WebhookNotifier{
|
||||
WebhookURL: "https://webhook.example.com/updated",
|
||||
WebhookMethod: webhook_notifier.WebhookMethodGET,
|
||||
Headers: []webhook_notifier.WebhookHeader{
|
||||
{Key: "Authorization", Value: "Bearer updated-token"},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
|
||||
// No sensitive data to verify for webhook
|
||||
assert.NotEmpty(
|
||||
t,
|
||||
notifier.WebhookNotifier.WebhookURL,
|
||||
"WebhookURL should be visible",
|
||||
)
|
||||
// Verify header values are encrypted in DB
|
||||
assert.True(
|
||||
t,
|
||||
isEncrypted(notifier.WebhookNotifier.Headers[0].Value),
|
||||
"Header value should be encrypted in DB",
|
||||
)
|
||||
decrypted := decryptField(
|
||||
t,
|
||||
notifier.ID,
|
||||
notifier.WebhookNotifier.Headers[0].Value,
|
||||
)
|
||||
assert.Equal(t, "Bearer updated-token", decrypted)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
|
||||
// No sensitive data to hide for webhook
|
||||
assert.NotEmpty(
|
||||
t,
|
||||
notifier.WebhookNotifier.WebhookURL,
|
||||
"WebhookURL should be visible",
|
||||
)
|
||||
for _, header := range notifier.WebhookNotifier.Headers {
|
||||
assert.Empty(t, header.Value, "Header value should be hidden")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -905,7 +935,7 @@ func Test_CreateNotifier_AllSensitiveFieldsEncryptedInDB(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Webhook Notifier - WebhookURL encrypted",
|
||||
name: "Webhook Notifier - Header values encrypted, URL not encrypted",
|
||||
createNotifier: func(workspaceID uuid.UUID) *Notifier {
|
||||
return &Notifier{
|
||||
WorkspaceID: workspaceID,
|
||||
@@ -914,17 +944,48 @@ func Test_CreateNotifier_AllSensitiveFieldsEncryptedInDB(t *testing.T) {
|
||||
WebhookNotifier: &webhook_notifier.WebhookNotifier{
|
||||
WebhookURL: "https://webhook.example.com/test456",
|
||||
WebhookMethod: webhook_notifier.WebhookMethodPOST,
|
||||
Headers: []webhook_notifier.WebhookHeader{
|
||||
{Key: "Authorization", Value: "Bearer secret-token-12345"},
|
||||
{Key: "X-API-Key", Value: "api-key-67890"},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
|
||||
assert.True(
|
||||
assert.False(
|
||||
t,
|
||||
isEncrypted(notifier.WebhookNotifier.WebhookURL),
|
||||
"WebhookURL should be encrypted",
|
||||
"WebhookURL should NOT be encrypted",
|
||||
)
|
||||
decrypted := decryptField(t, notifier.ID, notifier.WebhookNotifier.WebhookURL)
|
||||
assert.Equal(t, "https://webhook.example.com/test456", decrypted)
|
||||
assert.Equal(
|
||||
t,
|
||||
"https://webhook.example.com/test456",
|
||||
notifier.WebhookNotifier.WebhookURL,
|
||||
)
|
||||
|
||||
assert.True(
|
||||
t,
|
||||
isEncrypted(notifier.WebhookNotifier.Headers[0].Value),
|
||||
"Header value should be encrypted",
|
||||
)
|
||||
decrypted1 := decryptField(
|
||||
t,
|
||||
notifier.ID,
|
||||
notifier.WebhookNotifier.Headers[0].Value,
|
||||
)
|
||||
assert.Equal(t, "Bearer secret-token-12345", decrypted1)
|
||||
|
||||
assert.True(
|
||||
t,
|
||||
isEncrypted(notifier.WebhookNotifier.Headers[1].Value),
|
||||
"Header value should be encrypted",
|
||||
)
|
||||
decrypted2 := decryptField(
|
||||
t,
|
||||
notifier.ID,
|
||||
notifier.WebhookNotifier.Headers[1].Value,
|
||||
)
|
||||
assert.Equal(t, "api-key-67890", decrypted2)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -21,6 +21,10 @@ type WebhookHeader struct {
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
// Before both WebhookURL, BodyTemplate and HeadersJSON were considered
|
||||
// as sensetive data and it was causing issues. Now only headers values
|
||||
// considered as sensetive data, but we try to decrypt webhook URL and
|
||||
// body template for backward combability
|
||||
type WebhookNotifier struct {
|
||||
NotifierID uuid.UUID `json:"notifierId" gorm:"primaryKey;column:notifier_id"`
|
||||
WebhookURL string `json:"webhookUrl" gorm:"not null;column:webhook_url"`
|
||||
@@ -58,6 +62,20 @@ func (t *WebhookNotifier) AfterFind(_ *gorm.DB) error {
|
||||
}
|
||||
}
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
|
||||
if t.WebhookURL != "" {
|
||||
if decrypted, err := encryptor.Decrypt(t.NotifierID, t.WebhookURL); err == nil {
|
||||
t.WebhookURL = decrypted
|
||||
}
|
||||
}
|
||||
|
||||
if t.BodyTemplate != nil && *t.BodyTemplate != "" {
|
||||
if decrypted, err := encryptor.Decrypt(t.NotifierID, *t.BodyTemplate); err == nil {
|
||||
t.BodyTemplate = &decrypted
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -79,22 +97,24 @@ func (t *WebhookNotifier) Send(
|
||||
heading string,
|
||||
message string,
|
||||
) error {
|
||||
webhookURL, err := encryptor.Decrypt(t.NotifierID, t.WebhookURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
|
||||
if err := t.decryptHeadersForSending(encryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch t.WebhookMethod {
|
||||
case WebhookMethodGET:
|
||||
return t.sendGET(webhookURL, heading, message, logger)
|
||||
return t.sendGET(t.WebhookURL, heading, message, logger)
|
||||
case WebhookMethodPOST:
|
||||
return t.sendPOST(webhookURL, heading, message, logger)
|
||||
return t.sendPOST(t.WebhookURL, heading, message, logger)
|
||||
default:
|
||||
return fmt.Errorf("unsupported webhook method: %s", t.WebhookMethod)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WebhookNotifier) HideSensitiveData() {
|
||||
for i := range t.Headers {
|
||||
t.Headers[i].Value = ""
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WebhookNotifier) Update(incoming *WebhookNotifier) {
|
||||
@@ -105,14 +125,15 @@ func (t *WebhookNotifier) Update(incoming *WebhookNotifier) {
|
||||
}
|
||||
|
||||
func (t *WebhookNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
if t.WebhookURL != "" {
|
||||
encrypted, err := encryptor.Encrypt(t.NotifierID, t.WebhookURL)
|
||||
for i := range t.Headers {
|
||||
if t.Headers[i].Value != "" {
|
||||
encrypted, err := encryptor.Encrypt(t.NotifierID, t.Headers[i].Value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt header value: %w", err)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt webhook URL: %w", err)
|
||||
t.Headers[i].Value = encrypted
|
||||
}
|
||||
|
||||
t.WebhookURL = encrypted
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -241,3 +262,15 @@ func escapeJSONString(s string) string {
|
||||
|
||||
return string(b[1 : len(b)-1])
|
||||
}
|
||||
|
||||
func (t *WebhookNotifier) decryptHeadersForSending(encryptor encryption.FieldEncryptor) error {
|
||||
for i := range t.Headers {
|
||||
if t.Headers[i].Value != "" {
|
||||
if decrypted, err := encryptor.Decrypt(t.NotifierID, t.Headers[i].Value); err == nil {
|
||||
t.Headers[i].Value = decrypted
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package restores
|
||||
|
||||
import (
|
||||
"context"
|
||||
"databasus-backend/internal/features/restores/enums"
|
||||
"log/slog"
|
||||
)
|
||||
@@ -10,7 +11,7 @@ type RestoreBackgroundService struct {
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *RestoreBackgroundService) Run() {
|
||||
func (s *RestoreBackgroundService) Run(ctx context.Context) {
|
||||
if err := s.failRestoresInProgress(); err != nil {
|
||||
s.logger.Error("Failed to fail restores in progress", "error", err)
|
||||
panic(err)
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"databasus-backend/internal/config"
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/databases/databases/mysql"
|
||||
@@ -274,7 +275,7 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
var backup *backups.Backup
|
||||
var backup *backups_core.Backup
|
||||
var request RestoreBackupRequest
|
||||
|
||||
if tc.dbType == databases.DatabaseTypePostgres {
|
||||
@@ -321,7 +322,7 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
|
||||
}
|
||||
|
||||
// Set huge backup size (10 TB) that would fail disk validation if checked
|
||||
repo := &backups.BackupRepository{}
|
||||
repo := &backups_core.BackupRepository{}
|
||||
backup.BackupSizeMb = 10485760.0
|
||||
err := repo.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
@@ -368,7 +369,7 @@ func createTestDatabaseWithBackupForRestore(
|
||||
workspace *workspaces_models.Workspace,
|
||||
owner *users_dto.SignInResponseDTO,
|
||||
router *gin.Engine,
|
||||
) (*databases.Database, *backups.Backup) {
|
||||
) (*databases.Database, *backups_core.Backup) {
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
@@ -504,7 +505,7 @@ func createTestStorage(workspaceID uuid.UUID) *storages.Storage {
|
||||
func createTestBackup(
|
||||
database *databases.Database,
|
||||
owner *users_dto.SignInResponseDTO,
|
||||
) *backups.Backup {
|
||||
) *backups_core.Backup {
|
||||
fieldEncryptor := util_encryption.GetFieldEncryptor()
|
||||
userService := users_services.GetUserService()
|
||||
user, err := userService.GetUserFromToken(owner.Token)
|
||||
@@ -517,17 +518,17 @@ func createTestBackup(
|
||||
panic("No storage found for workspace")
|
||||
}
|
||||
|
||||
backup := &backups.Backup{
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storages[0].ID,
|
||||
Status: backups.BackupStatusCompleted,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10.5,
|
||||
BackupDurationMs: 1000,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
repo := &backups.BackupRepository{}
|
||||
repo := &backups_core.BackupRepository{}
|
||||
if err := repo.Save(backup); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/restores/enums"
|
||||
"time"
|
||||
|
||||
@@ -13,7 +13,7 @@ type Restore struct {
|
||||
Status enums.RestoreStatus `json:"status" gorm:"column:status;type:text;not null"`
|
||||
|
||||
BackupID uuid.UUID `json:"backupId" gorm:"column:backup_id;type:uuid;not null"`
|
||||
Backup *backups.Backup
|
||||
Backup *backups_core.Backup
|
||||
|
||||
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package restores
|
||||
import (
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/disk"
|
||||
@@ -36,7 +37,7 @@ type RestoreService struct {
|
||||
diskService *disk.DiskService
|
||||
}
|
||||
|
||||
func (s *RestoreService) OnBeforeBackupRemove(backup *backups.Backup) error {
|
||||
func (s *RestoreService) OnBeforeBackupRemove(backup *backups_core.Backup) error {
|
||||
restores, err := s.restoreRepository.FindByBackupID(backup.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -153,10 +154,10 @@ func (s *RestoreService) RestoreBackupWithAuth(
|
||||
}
|
||||
|
||||
func (s *RestoreService) RestoreBackup(
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
requestDTO RestoreBackupRequest,
|
||||
) error {
|
||||
if backup.Status != backups.BackupStatusCompleted {
|
||||
if backup.Status != backups_core.BackupStatusCompleted {
|
||||
return errors.New("backup is not completed")
|
||||
}
|
||||
|
||||
@@ -370,7 +371,7 @@ func (s *RestoreService) validateVersionCompatibility(
|
||||
}
|
||||
|
||||
func (s *RestoreService) validateDiskSpace(
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
requestDTO RestoreBackupRequest,
|
||||
) error {
|
||||
// Only validate disk space for PostgreSQL when file-based restore is needed:
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -40,7 +40,7 @@ func (uc *RestoreMariadbBackupUsecase) Execute(
|
||||
restoringToDB *databases.Database,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore models.Restore,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
) error {
|
||||
if originalDB.Type != databases.DatabaseTypeMariadb {
|
||||
@@ -99,7 +99,7 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
|
||||
mariadbBin string,
|
||||
args []string,
|
||||
password string,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
mdbConfig *mariadbtypes.MariadbDatabase,
|
||||
) error {
|
||||
@@ -163,7 +163,7 @@ func (uc *RestoreMariadbBackupUsecase) executeMariadbRestore(
|
||||
args []string,
|
||||
myCnfFile string,
|
||||
backupReader io.ReadCloser,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
) error {
|
||||
fullArgs := append([]string{"--defaults-file=" + myCnfFile}, args...)
|
||||
|
||||
@@ -226,7 +226,7 @@ func (uc *RestoreMariadbBackupUsecase) executeMariadbRestore(
|
||||
|
||||
func (uc *RestoreMariadbBackupUsecase) setupDecryption(
|
||||
reader io.Reader,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
) (io.Reader, error) {
|
||||
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
|
||||
return nil, fmt.Errorf("backup is encrypted but missing encryption metadata")
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -40,7 +40,7 @@ func (uc *RestoreMongodbBackupUsecase) Execute(
|
||||
restoringToDB *databases.Database,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore models.Restore,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
) error {
|
||||
if originalDB.Type != databases.DatabaseTypeMongodb {
|
||||
@@ -124,7 +124,7 @@ func (uc *RestoreMongodbBackupUsecase) buildMongorestoreArgs(
|
||||
func (uc *RestoreMongodbBackupUsecase) restoreFromStorage(
|
||||
mongorestoreBin string,
|
||||
args []string,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), restoreTimeout)
|
||||
@@ -166,7 +166,7 @@ func (uc *RestoreMongodbBackupUsecase) executeMongoRestore(
|
||||
mongorestoreBin string,
|
||||
args []string,
|
||||
backupReader io.ReadCloser,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
) error {
|
||||
cmd := exec.CommandContext(ctx, mongorestoreBin, args...)
|
||||
|
||||
@@ -231,7 +231,7 @@ func (uc *RestoreMongodbBackupUsecase) executeMongoRestore(
|
||||
|
||||
func (uc *RestoreMongodbBackupUsecase) setupDecryption(
|
||||
reader io.Reader,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
) (io.Reader, error) {
|
||||
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
|
||||
return nil, errors.New("encrypted backup missing salt or IV")
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -40,7 +40,7 @@ func (uc *RestoreMysqlBackupUsecase) Execute(
|
||||
restoringToDB *databases.Database,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore models.Restore,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
) error {
|
||||
if originalDB.Type != databases.DatabaseTypeMysql {
|
||||
@@ -98,7 +98,7 @@ func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
|
||||
mysqlBin string,
|
||||
args []string,
|
||||
password string,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
myConfig *mysqltypes.MysqlDatabase,
|
||||
) error {
|
||||
@@ -154,7 +154,7 @@ func (uc *RestoreMysqlBackupUsecase) executeMysqlRestore(
|
||||
args []string,
|
||||
myCnfFile string,
|
||||
backupReader io.ReadCloser,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
) error {
|
||||
fullArgs := append([]string{"--defaults-file=" + myCnfFile}, args...)
|
||||
|
||||
@@ -217,7 +217,7 @@ func (uc *RestoreMysqlBackupUsecase) executeMysqlRestore(
|
||||
|
||||
func (uc *RestoreMysqlBackupUsecase) setupDecryption(
|
||||
reader io.Reader,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
) (io.Reader, error) {
|
||||
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
|
||||
return nil, fmt.Errorf("backup is encrypted but missing encryption metadata")
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
"time"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
"databasus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -39,7 +39,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
|
||||
restoringToDB *databases.Database,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore models.Restore,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
isExcludeExtensions bool,
|
||||
) error {
|
||||
@@ -86,7 +86,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
|
||||
func (uc *RestorePostgresqlBackupUsecase) restoreCustomType(
|
||||
originalDB *databases.Database,
|
||||
pgBin string,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
pg *pgtypes.PostgresqlDatabase,
|
||||
isExcludeExtensions bool,
|
||||
@@ -113,7 +113,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreCustomType(
|
||||
func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
|
||||
originalDB *databases.Database,
|
||||
pgBin string,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
pg *pgtypes.PostgresqlDatabase,
|
||||
) error {
|
||||
@@ -321,7 +321,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
|
||||
func (uc *RestorePostgresqlBackupUsecase) restoreViaFile(
|
||||
originalDB *databases.Database,
|
||||
pgBin string,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
pg *pgtypes.PostgresqlDatabase,
|
||||
isExcludeExtensions bool,
|
||||
@@ -371,7 +371,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
|
||||
pgBin string,
|
||||
args []string,
|
||||
password string,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
pgConfig *pgtypes.PostgresqlDatabase,
|
||||
isExcludeExtensions bool,
|
||||
@@ -469,7 +469,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
|
||||
// downloadBackupToTempFile downloads backup data from storage to a temporary file
|
||||
func (uc *RestorePostgresqlBackupUsecase) downloadBackupToTempFile(
|
||||
ctx context.Context,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
) (string, func(), error) {
|
||||
// Create temporary directory for backup data
|
||||
|
||||
@@ -3,7 +3,7 @@ package usecases
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/restores/models"
|
||||
@@ -26,7 +26,7 @@ func (uc *RestoreBackupUsecase) Execute(
|
||||
restore models.Restore,
|
||||
originalDB *databases.Database,
|
||||
restoringToDB *databases.Database,
|
||||
backup *backups.Backup,
|
||||
backup *backups_core.Backup,
|
||||
storage *storages.Storage,
|
||||
isExcludeExtensions bool,
|
||||
) error {
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
package system_healthcheck
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
"databasus-backend/internal/features/disk"
|
||||
)
|
||||
|
||||
var healthcheckService = &HealthcheckService{
|
||||
disk.GetDiskService(),
|
||||
backups.GetBackupBackgroundService(),
|
||||
backuping.GetBackupsScheduler(),
|
||||
backuping.GetBackuperNode(),
|
||||
}
|
||||
var healthcheckController = &HealthcheckController{
|
||||
healthcheckService,
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
package system_healthcheck
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
"databasus-backend/internal/features/disk"
|
||||
"databasus-backend/internal/storage"
|
||||
"errors"
|
||||
@@ -9,7 +10,8 @@ import (
|
||||
|
||||
type HealthcheckService struct {
|
||||
diskService *disk.DiskService
|
||||
backupBackgroundService *backups.BackupBackgroundService
|
||||
backupBackgroundService *backuping.BackupsScheduler
|
||||
backuperNode *backuping.BackuperNode
|
||||
}
|
||||
|
||||
func (s *HealthcheckService) IsHealthy() error {
|
||||
@@ -29,8 +31,16 @@ func (s *HealthcheckService) IsHealthy() error {
|
||||
return errors.New("cannot connect to the database")
|
||||
}
|
||||
|
||||
if !s.backupBackgroundService.IsBackupsWorkerRunning() {
|
||||
return errors.New("backups are not running for more than 5 minutes")
|
||||
if config.GetEnv().IsPrimaryNode {
|
||||
if !s.backupBackgroundService.IsSchedulerRunning() {
|
||||
return errors.New("backups are not running for more than 5 minutes")
|
||||
}
|
||||
}
|
||||
|
||||
if config.GetEnv().IsBackupNode {
|
||||
if !s.backuperNode.IsBackuperRunning() {
|
||||
return errors.New("backuper node is not running for more than 5 minutes")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
mariadbtypes "databasus-backend/internal/features/databases/databases/mariadb"
|
||||
@@ -189,7 +189,7 @@ func testMariadbBackupRestoreForVersion(
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := "restoreddb_mariadb"
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
|
||||
@@ -286,7 +286,7 @@ func testMariadbBackupRestoreWithEncryptionForVersion(
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)
|
||||
|
||||
newDBName := "restoreddb_mariadb_encrypted"
|
||||
@@ -394,7 +394,7 @@ func testMariadbBackupRestoreWithReadOnlyUserForVersion(
|
||||
createBackupViaAPI(t, router, updatedDatabase.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, updatedDatabase.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := "restoreddb_mariadb_readonly"
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
mongodbtypes "databasus-backend/internal/features/databases/databases/mongodb"
|
||||
@@ -161,7 +161,7 @@ func testMongodbBackupRestoreForVersion(
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := "restoreddb_mongodb_" + uuid.New().String()[:8]
|
||||
|
||||
@@ -239,7 +239,7 @@ func testMongodbBackupRestoreWithEncryptionForVersion(
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)
|
||||
|
||||
newDBName := "restoreddb_mongodb_enc_" + uuid.New().String()[:8]
|
||||
@@ -328,7 +328,7 @@ func testMongodbBackupRestoreWithReadOnlyUserForVersion(
|
||||
createBackupViaAPI(t, router, updatedDatabase.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, updatedDatabase.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := "restoreddb_mongodb_ro_" + uuid.New().String()[:8]
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
mysqltypes "databasus-backend/internal/features/databases/databases/mysql"
|
||||
@@ -164,7 +164,7 @@ func testMysqlBackupRestoreForVersion(t *testing.T, mysqlVersion tools.MysqlVers
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := "restoreddb_mysql"
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
|
||||
@@ -261,7 +261,7 @@ func testMysqlBackupRestoreWithEncryptionForVersion(
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)
|
||||
|
||||
newDBName := "restoreddb_mysql_encrypted"
|
||||
@@ -369,7 +369,7 @@ func testMysqlBackupRestoreWithReadOnlyUserForVersion(
|
||||
createBackupViaAPI(t, router, updatedDatabase.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, updatedDatabase.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := "restoreddb_mysql_readonly"
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
pgtypes "databasus-backend/internal/features/databases/databases/postgresql"
|
||||
@@ -190,7 +191,7 @@ func Test_BackupAndRestoreSupabase_PublicSchemaOnly_RestoreIsSuccessful(t *testi
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
_, err = supabaseDB.Exec(fmt.Sprintf(`DELETE FROM public.%s`, tableName))
|
||||
assert.NoError(t, err)
|
||||
@@ -410,7 +411,7 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string, cp
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := fmt.Sprintf("restoreddb_%s_cpu%d_%s", pgVersion, cpuCount, uuid.New().String()[:8])
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
|
||||
@@ -527,7 +528,7 @@ func testSchemaSelectionAllSchemasForVersion(t *testing.T, pgVersion string, por
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := fmt.Sprintf("restored_all_schemas_%s_%s", pgVersion, uuid.New().String()[:8])
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
|
||||
@@ -655,7 +656,7 @@ func testBackupRestoreWithExcludeExtensionsForVersion(t *testing.T, pgVersion st
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := fmt.Sprintf("restored_exclude_ext_%s_%s", pgVersion, uuid.New().String()[:8])
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
|
||||
@@ -789,7 +790,7 @@ func testBackupRestoreWithoutExcludeExtensionsForVersion(
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := fmt.Sprintf("restored_with_ext_%s_%s", pgVersion, uuid.New().String()[:8])
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
|
||||
@@ -928,7 +929,7 @@ func testBackupRestoreWithReadOnlyUserForVersion(t *testing.T, pgVersion string,
|
||||
createBackupViaAPI(t, router, updatedDatabase.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, updatedDatabase.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := fmt.Sprintf("restoreddb_readonly_%s", uuid.New().String()[:8])
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
|
||||
@@ -1048,7 +1049,7 @@ func testSchemaSelectionOnlySpecifiedSchemasForVersion(
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := fmt.Sprintf("restored_specific_schemas_%s_%s", pgVersion, uuid.New().String()[:8])
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
|
||||
@@ -1161,7 +1162,7 @@ func testBackupRestoreWithEncryptionForVersion(t *testing.T, pgVersion string, p
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)
|
||||
|
||||
newDBName := fmt.Sprintf("restoreddb_encrypted_%s", uuid.New().String()[:8])
|
||||
@@ -1242,7 +1243,7 @@ func waitForBackupCompletion(
|
||||
databaseID uuid.UUID,
|
||||
token string,
|
||||
timeout time.Duration,
|
||||
) *backups.Backup {
|
||||
) *backups_core.Backup {
|
||||
startTime := time.Now()
|
||||
pollInterval := 500 * time.Millisecond
|
||||
|
||||
@@ -1263,10 +1264,10 @@ func waitForBackupCompletion(
|
||||
|
||||
if len(response.Backups) > 0 {
|
||||
backup := response.Backups[0]
|
||||
if backup.Status == backups.BackupStatusCompleted {
|
||||
if backup.Status == backups_core.BackupStatusCompleted {
|
||||
return backup
|
||||
}
|
||||
if backup.Status == backups.BackupStatusFailed {
|
||||
if backup.Status == backups_core.BackupStatusFailed {
|
||||
failMsg := "unknown error"
|
||||
if backup.FailMessage != nil {
|
||||
failMsg = *backup.FailMessage
|
||||
|
||||
22
backend/internal/features/tests/setup_test.go
Normal file
22
backend/internal/features/tests/setup_test.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"databasus-backend/internal/features/backups/backups/backuping"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
cache_utils.ClearAllCache()
|
||||
|
||||
backuperNode := backuping.CreateTestBackuperNode()
|
||||
cancel := backuping.StartBackuperNodeForTest(&testing.T{}, backuperNode)
|
||||
|
||||
exitCode := m.Run()
|
||||
|
||||
backuping.StopBackuperNodeForTest(&testing.T{}, cancel, backuperNode)
|
||||
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
@@ -2,13 +2,12 @@ package users_controllers
|
||||
|
||||
import (
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
)
|
||||
|
||||
var userController = &UserController{
|
||||
users_services.GetUserService(),
|
||||
rate.NewLimiter(rate.Limit(3), 3), // 3 rps with 3 burst
|
||||
cache_utils.NewRateLimiter(cache_utils.GetValkeyClient()),
|
||||
}
|
||||
|
||||
var settingsController = &SettingsController{
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
func Test_AdminLifecycleE2E_CompletesSuccessfully(t *testing.T) {
|
||||
@@ -185,7 +184,6 @@ func createUserTestRouter() *gin.Engine {
|
||||
// Register protected routes with auth middleware
|
||||
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
|
||||
GetUserController().RegisterProtectedRoutes(protected.(*gin.RouterGroup))
|
||||
GetUserController().SetSignInLimiter(rate.NewLimiter(rate.Limit(100), 100))
|
||||
|
||||
// Setup audit log service
|
||||
users_services.GetUserService().SetAuditLogWriter(&AuditLogWriterStub{})
|
||||
|
||||
@@ -3,20 +3,21 @@ package users_controllers
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
user_dto "databasus-backend/internal/features/users/dto"
|
||||
users_errors "databasus-backend/internal/features/users/errors"
|
||||
user_middleware "databasus-backend/internal/features/users/middleware"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type UserController struct {
|
||||
userService *users_services.UserService
|
||||
signinLimiter *rate.Limiter
|
||||
userService *users_services.UserService
|
||||
rateLimiter *cache_utils.RateLimiter
|
||||
}
|
||||
|
||||
func (c *UserController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
@@ -39,10 +40,6 @@ func (c *UserController) RegisterProtectedRoutes(router *gin.RouterGroup) {
|
||||
router.POST("/users/invite", c.InviteUser)
|
||||
}
|
||||
|
||||
func (c *UserController) SetSignInLimiter(limiter *rate.Limiter) {
|
||||
c.signinLimiter = limiter
|
||||
}
|
||||
|
||||
// SignUp
|
||||
// @Summary Register a new user
|
||||
// @Description Register a new user with email and password
|
||||
@@ -81,8 +78,14 @@ func (c *UserController) SignUp(ctx *gin.Context) {
|
||||
// @Failure 429 {object} map[string]string "Rate limit exceeded"
|
||||
// @Router /users/signin [post]
|
||||
func (c *UserController) SignIn(ctx *gin.Context) {
|
||||
// We use rate limiter to prevent brute force attacks
|
||||
if !c.signinLimiter.Allow() {
|
||||
var request user_dto.SignInRequestDTO
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"})
|
||||
return
|
||||
}
|
||||
|
||||
allowed, _ := c.rateLimiter.CheckLimit(request.Email, "signin", 10, 1*time.Minute)
|
||||
if !allowed {
|
||||
ctx.JSON(
|
||||
http.StatusTooManyRequests,
|
||||
gin.H{"error": "Rate limit exceeded. Please try again later."},
|
||||
@@ -90,12 +93,6 @@ func (c *UserController) SignIn(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var request user_dto.SignInRequestDTO
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.userService.SignIn(&request)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
|
||||
@@ -1146,3 +1146,48 @@ func Test_GoogleOAuth_WithInvitedUser_ActivatesUser(t *testing.T) {
|
||||
assert.Equal(t, email, response.Email)
|
||||
assert.False(t, response.IsNewUser)
|
||||
}
|
||||
|
||||
func Test_SignIn_WithExcessiveAttempts_RateLimitEnforced(t *testing.T) {
|
||||
router := createUserTestRouter()
|
||||
email := "ratelimit" + uuid.New().String() + "@example.com"
|
||||
password := "testpassword123"
|
||||
|
||||
// Create a user first
|
||||
signupRequest := users_dto.SignUpRequestDTO{
|
||||
Email: email,
|
||||
Password: password,
|
||||
Name: "Rate Limit Test User",
|
||||
}
|
||||
test_utils.MakePostRequest(t, router, "/api/v1/users/signup", "", signupRequest, http.StatusOK)
|
||||
|
||||
// Make 10 sign-in attempts (should succeed)
|
||||
for range 10 {
|
||||
signinRequest := users_dto.SignInRequestDTO{
|
||||
Email: email,
|
||||
Password: password,
|
||||
}
|
||||
test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/users/signin",
|
||||
"",
|
||||
signinRequest,
|
||||
http.StatusOK,
|
||||
)
|
||||
}
|
||||
|
||||
// 11th attempt should be rate limited
|
||||
signinRequest := users_dto.SignInRequestDTO{
|
||||
Email: email,
|
||||
Password: password,
|
||||
}
|
||||
resp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/users/signin",
|
||||
"",
|
||||
signinRequest,
|
||||
http.StatusTooManyRequests,
|
||||
)
|
||||
assert.Contains(t, string(resp.Body), "Rate limit exceeded")
|
||||
}
|
||||
|
||||
4
backend/internal/util/cache/cache.go
vendored
4
backend/internal/util/cache/cache.go
vendored
@@ -41,6 +41,10 @@ func getCache() valkey.Client {
|
||||
return valkeyClient
|
||||
}
|
||||
|
||||
func GetValkeyClient() valkey.Client {
|
||||
return getCache()
|
||||
}
|
||||
|
||||
func TestCacheConnection() {
|
||||
// Get Valkey client from cache package
|
||||
client := getCache()
|
||||
|
||||
85
backend/internal/util/cache/rate_limiter.go
vendored
Normal file
85
backend/internal/util/cache/rate_limiter.go
vendored
Normal file
@@ -0,0 +1,85 @@
|
||||
package cache_utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/valkey-io/valkey-go"
|
||||
)
|
||||
|
||||
type RateLimiter struct {
|
||||
client valkey.Client
|
||||
}
|
||||
|
||||
func NewRateLimiter(client valkey.Client) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RateLimiter) CheckLimit(
|
||||
identifier string,
|
||||
endpoint string,
|
||||
maxRequests int,
|
||||
windowDuration time.Duration,
|
||||
) (bool, error) {
|
||||
requestID := uuid.New().String()
|
||||
keyPrefix := fmt.Sprintf("ratelimit:%s:%s", endpoint, identifier)
|
||||
fullKey := fmt.Sprintf("%s:%s", keyPrefix, requestID)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), DefaultCacheTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Set the key with TTL
|
||||
setCmd := r.client.B().
|
||||
Set().
|
||||
Key(fullKey).
|
||||
Value("1").
|
||||
ExSeconds(int64(windowDuration.Seconds())).
|
||||
Build()
|
||||
if err := r.client.Do(ctx, setCmd).Error(); err != nil {
|
||||
return true, fmt.Errorf("failed to set rate limit key: %w", err)
|
||||
}
|
||||
|
||||
// Count keys matching the pattern
|
||||
count, err := r.countKeys(keyPrefix)
|
||||
if err != nil {
|
||||
return true, fmt.Errorf("failed to count rate limit keys: %w", err)
|
||||
}
|
||||
|
||||
return count <= maxRequests, nil
|
||||
}
|
||||
|
||||
func (r *RateLimiter) countKeys(keyPrefix string) (int, error) {
|
||||
pattern := keyPrefix + ":*"
|
||||
cursor := uint64(0)
|
||||
totalCount := 0
|
||||
|
||||
for {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), DefaultCacheTimeout)
|
||||
|
||||
scanCmd := r.client.B().Scan().Cursor(cursor).Match(pattern).Count(100).Build()
|
||||
result := r.client.Do(ctx, scanCmd)
|
||||
cancel()
|
||||
|
||||
if result.Error() != nil {
|
||||
return 0, result.Error()
|
||||
}
|
||||
|
||||
scanResult, err := result.AsScanEntry()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
totalCount += len(scanResult.Elements)
|
||||
cursor = scanResult.Cursor
|
||||
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return totalCount, nil
|
||||
}
|
||||
@@ -112,6 +112,12 @@ export function EditWebhookNotifierComponent({ notifier, setNotifier, setUnsaved
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{notifier.id && (
|
||||
<div className="mb-1 text-xs text-orange-700">
|
||||
*Saved headers hidden for security reasons
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="w-full max-w-[500px]">
|
||||
{headers.map((header: WebhookHeader, index: number) => (
|
||||
<div key={index} className="mb-1 flex items-center gap-2">
|
||||
@@ -204,11 +210,12 @@ export function EditWebhookNotifierComponent({ notifier, setNotifier, setUnsaved
|
||||
<div className="text-xs font-semibold text-gray-500 dark:text-gray-400">
|
||||
Headers:
|
||||
</div>
|
||||
|
||||
{headers
|
||||
.filter((h) => h.key)
|
||||
.map((h, i) => (
|
||||
<div key={i} className="text-xs">
|
||||
{h.key}: {h.value || '(empty)'}
|
||||
{h.key}: {h.value || '(hidden)'}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
@@ -29,7 +29,7 @@ export function ShowWebhookNotifierComponent({ notifier }: Props) {
|
||||
.filter((h: WebhookHeader) => h.key)
|
||||
.map((h: WebhookHeader, i: number) => (
|
||||
<div key={i} className="text-gray-600">
|
||||
<span className="font-medium">{h.key}:</span> {h.value || '(empty)'}
|
||||
<span className="font-medium">{h.key}:</span> {h.value || '(hidden)'}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
Reference in New Issue
Block a user