From 230cc27ea6a18ca8f1a3cd4b01b521b215c874e7 Mon Sep 17 00:00:00 2001 From: Rostislav Dugin Date: Tue, 3 Mar 2026 12:20:23 +0300 Subject: [PATCH] FEATURE (backups): Add WAL API --- .gitignore | 1 + CLAUDE.md | 1 + backend/cmd/main.go | 11 +- .../backups/backups/backuping/scheduler.go | 11 +- .../backups/{ => controllers}/controller.go | 18 +- .../{ => controllers}/controller_test.go | 25 +- .../backups/backups/controllers/di.go | 23 + .../controllers/postgres_wal_controller.go | 291 ++++ .../postgres_wal_controller_test.go | 1224 +++++++++++++++++ .../backups/{ => controllers}/testing.go | 4 +- .../features/backups/backups/core/di.go | 7 + .../features/backups/backups/core/enums.go | 7 + .../features/backups/backups/core/model.go | 30 +- .../backups/backups/core/repository.go | 131 ++ .../internal/features/backups/backups/dto.go | 29 - .../features/backups/backups/dto/dto.go | 78 ++ .../backups/backups/encryption/setup.go | 45 + .../backups/backups/{ => services}/di.go | 28 +- .../backups/services/postgres_wal_service.go | 613 +++++++++ .../backups/backups/{ => services}/service.go | 79 +- .../usecases/mariadb/create_backup_uc.go | 29 +- .../usecases/mongodb/create_backup_uc.go | 31 +- .../usecases/mysql/create_backup_uc.go | 29 +- .../usecases/postgresql/create_backup_uc.go | 29 +- .../internal/features/databases/controller.go | 63 + .../features/databases/controller_test.go | 224 ++- .../databases/databases/postgresql/model.go | 80 +- backend/internal/features/databases/dto.go | 4 + backend/internal/features/databases/model.go | 16 +- .../internal/features/databases/repository.go | 12 + .../internal/features/databases/service.go | 139 +- backend/internal/features/intervals/model.go | 125 ++ .../internal/features/intervals/model_test.go | 262 ++++ .../features/restores/controller_test.go | 4 +- backend/internal/features/restores/di.go | 6 +- .../features/restores/restoring/di.go | 58 +- .../features/restores/restoring/restorer.go | 4 +- .../restores/restoring/restorer_test.go | 6 +- .../features/restores/restoring/scheduler.go | 4 +- .../restores/restoring/scheduler_test.go | 16 +- .../features/restores/restoring/testing.go | 64 +- backend/internal/features/restores/service.go | 4 +- backend/internal/features/restores/testing.go | 6 +- .../features/storages/models/local/model.go | 9 +- .../internal/features/test_once_protection.go | 6 +- .../tests/postgresql_backup_restore_test.go | 9 +- backend/internal/util/wal/calculator.go | 125 ++ backend/internal/util/wal/calculator_test.go | 221 +++ .../20260306045548_add_wal_properties.sql | 72 + 49 files changed, 3941 insertions(+), 372 deletions(-) create mode 100644 CLAUDE.md rename backend/internal/features/backups/backups/{ => controllers}/controller.go (95%) rename backend/internal/features/backups/backups/{ => controllers}/controller_test.go (98%) create mode 100644 backend/internal/features/backups/backups/controllers/di.go create mode 100644 backend/internal/features/backups/backups/controllers/postgres_wal_controller.go create mode 100644 backend/internal/features/backups/backups/controllers/postgres_wal_controller_test.go rename backend/internal/features/backups/backups/{ => controllers}/testing.go (96%) create mode 100644 backend/internal/features/backups/backups/core/di.go delete mode 100644 backend/internal/features/backups/backups/dto.go create mode 100644 backend/internal/features/backups/backups/dto/dto.go create mode 100644 backend/internal/features/backups/backups/encryption/setup.go rename 
backend/internal/features/backups/backups/{ => services}/di.go (85%) create mode 100644 backend/internal/features/backups/backups/services/postgres_wal_service.go rename backend/internal/features/backups/backups/{ => services}/service.go (96%) create mode 100644 backend/internal/util/wal/calculator.go create mode 100644 backend/internal/util/wal/calculator_test.go create mode 100644 backend/migrations/20260306045548_add_wal_properties.sql diff --git a/.gitignore b/.gitignore index 4a9b959..16cf567 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ node_modules/ .DS_Store /scripts .vscode/settings.json +.claude \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..cc1c728 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1 @@ +Look at @AGENTS.md \ No newline at end of file diff --git a/backend/cmd/main.go b/backend/cmd/main.go index 0fa9a27..f1b7e44 100644 --- a/backend/cmd/main.go +++ b/backend/cmd/main.go @@ -14,9 +14,10 @@ import ( "databasus-backend/internal/config" "databasus-backend/internal/features/audit_logs" - "databasus-backend/internal/features/backups/backups" "databasus-backend/internal/features/backups/backups/backuping" + backups_controllers "databasus-backend/internal/features/backups/backups/controllers" backups_download "databasus-backend/internal/features/backups/backups/download" + backups_services "databasus-backend/internal/features/backups/backups/services" backups_config "databasus-backend/internal/features/backups/config" "databasus-backend/internal/features/databases" "databasus-backend/internal/features/disk" @@ -209,7 +210,9 @@ func setUpRoutes(r *gin.Engine) { userController := users_controllers.GetUserController() userController.RegisterRoutes(v1) system_healthcheck.GetHealthcheckController().RegisterRoutes(v1) - backups.GetBackupController().RegisterPublicRoutes(v1) + backups_controllers.GetBackupController().RegisterPublicRoutes(v1) + backups_controllers.GetPostgresWalBackupController().RegisterRoutes(v1) + databases.GetDatabaseController().RegisterPublicRoutes(v1) // Setup auth middleware userService := users_services.GetUserService() @@ -226,7 +229,7 @@ func setUpRoutes(r *gin.Engine) { notifiers.GetNotifierController().RegisterRoutes(protected) storages.GetStorageController().RegisterRoutes(protected) databases.GetDatabaseController().RegisterRoutes(protected) - backups.GetBackupController().RegisterRoutes(protected) + backups_controllers.GetBackupController().RegisterRoutes(protected) restores.GetRestoreController().RegisterRoutes(protected) healthcheck_config.GetHealthcheckConfigController().RegisterRoutes(protected) healthcheck_attempt.GetHealthcheckAttemptController().RegisterRoutes(protected) @@ -238,7 +241,7 @@ func setUpRoutes(r *gin.Engine) { func setUpDependencies() { databases.SetupDependencies() - backups.SetupDependencies() + backups_services.SetupDependencies() restores.SetupDependencies() healthcheck_config.SetupDependencies() audit_logs.SetupDependencies() diff --git a/backend/internal/features/backups/backups/backuping/scheduler.go b/backend/internal/features/backups/backups/backuping/scheduler.go index 27d211d..9113609 100644 --- a/backend/internal/features/backups/backups/backuping/scheduler.go +++ b/backend/internal/features/backups/backups/backuping/scheduler.go @@ -15,7 +15,6 @@ import ( backups_config "databasus-backend/internal/features/backups/config" "databasus-backend/internal/features/databases" task_cancellation "databasus-backend/internal/features/tasks/cancellation" - files_utils 
"databasus-backend/internal/util/files" ) const ( @@ -171,13 +170,7 @@ func (s *BackupsScheduler) StartBackup(database *databases.Database, isCallNotif timestamp := time.Now().UTC() backup := &backups_core.Backup{ - ID: backupID, - FileName: fmt.Sprintf( - "%s-%s-%s", - files_utils.SanitizeFilename(database.Name), - timestamp.Format("20060102-150405"), - backupID.String(), - ), + ID: backupID, DatabaseID: backupConfig.DatabaseID, StorageID: *backupConfig.StorageID, Status: backups_core.BackupStatusInProgress, @@ -185,6 +178,8 @@ func (s *BackupsScheduler) StartBackup(database *databases.Database, isCallNotif CreatedAt: timestamp, } + backup.GenerateFilename(database.Name) + if err := s.backupRepository.Save(backup); err != nil { s.logger.Error( "Failed to save backup", diff --git a/backend/internal/features/backups/backups/controller.go b/backend/internal/features/backups/backups/controllers/controller.go similarity index 95% rename from backend/internal/features/backups/backups/controller.go rename to backend/internal/features/backups/backups/controllers/controller.go index 4fd8c8f..8530816 100644 --- a/backend/internal/features/backups/backups/controller.go +++ b/backend/internal/features/backups/backups/controllers/controller.go @@ -1,9 +1,11 @@ -package backups +package backups_controllers import ( "context" backups_core "databasus-backend/internal/features/backups/backups/core" backups_download "databasus-backend/internal/features/backups/backups/download" + backups_dto "databasus-backend/internal/features/backups/backups/dto" + backups_services "databasus-backend/internal/features/backups/backups/services" "databasus-backend/internal/features/databases" users_middleware "databasus-backend/internal/features/users/middleware" files_utils "databasus-backend/internal/util/files" @@ -17,7 +19,7 @@ import ( ) type BackupController struct { - backupService *BackupService + backupService *backups_services.BackupService } func (c *BackupController) RegisterRoutes(router *gin.RouterGroup) { @@ -42,7 +44,7 @@ func (c *BackupController) RegisterPublicRoutes(router *gin.RouterGroup) { // @Param database_id query string true "Database ID" // @Param limit query int false "Number of items per page" default(10) // @Param offset query int false "Offset for pagination" default(0) -// @Success 200 {object} GetBackupsResponse +// @Success 200 {object} backups_dto.GetBackupsResponse // @Failure 400 // @Failure 401 // @Failure 500 @@ -54,7 +56,7 @@ func (c *BackupController) GetBackups(ctx *gin.Context) { return } - var request GetBackupsRequest + var request backups_dto.GetBackupsRequest if err := ctx.ShouldBindQuery(&request); err != nil { ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return @@ -81,7 +83,7 @@ func (c *BackupController) GetBackups(ctx *gin.Context) { // @Tags backups // @Accept json // @Produce json -// @Param request body MakeBackupRequest true "Backup creation data" +// @Param request body backups_dto.MakeBackupRequest true "Backup creation data" // @Success 200 {object} map[string]string // @Failure 400 // @Failure 401 @@ -94,7 +96,7 @@ func (c *BackupController) MakeBackup(ctx *gin.Context) { return } - var request MakeBackupRequest + var request backups_dto.MakeBackupRequest if err := ctx.ShouldBindJSON(&request); err != nil { ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return @@ -310,10 +312,6 @@ func (c *BackupController) GetFile(ctx *gin.Context) { c.backupService.WriteAuditLogForDownload(downloadToken.UserID, backup, database) } -type 
MakeBackupRequest struct { - DatabaseID uuid.UUID `json:"database_id" binding:"required"` -} - func (c *BackupController) generateBackupFilename( backup *backups_core.Backup, database *databases.Database, diff --git a/backend/internal/features/backups/backups/controller_test.go b/backend/internal/features/backups/backups/controllers/controller_test.go similarity index 98% rename from backend/internal/features/backups/backups/controller_test.go rename to backend/internal/features/backups/backups/controllers/controller_test.go index eac9fef..8d0992b 100644 --- a/backend/internal/features/backups/backups/controller_test.go +++ b/backend/internal/features/backups/backups/controllers/controller_test.go @@ -1,4 +1,4 @@ -package backups +package backups_controllers import ( "context" @@ -24,11 +24,14 @@ import ( backups_common "databasus-backend/internal/features/backups/backups/common" backups_core "databasus-backend/internal/features/backups/backups/core" backups_download "databasus-backend/internal/features/backups/backups/download" + backups_dto "databasus-backend/internal/features/backups/backups/dto" + backups_services "databasus-backend/internal/features/backups/backups/services" backups_config "databasus-backend/internal/features/backups/config" "databasus-backend/internal/features/databases" "databasus-backend/internal/features/databases/databases/postgresql" "databasus-backend/internal/features/storages" local_storage "databasus-backend/internal/features/storages/models/local" + task_cancellation "databasus-backend/internal/features/tasks/cancellation" users_dto "databasus-backend/internal/features/users/dto" users_enums "databasus-backend/internal/features/users/enums" users_services "databasus-backend/internal/features/users/services" @@ -119,7 +122,7 @@ func Test_GetBackups_PermissionsEnforced(t *testing.T) { ) if tt.expectSuccess { - var response GetBackupsResponse + var response backups_dto.GetBackupsResponse err := json.Unmarshal(testResp.Body, &response) assert.NoError(t, err) assert.GreaterOrEqual(t, len(response.Backups), 1) @@ -214,7 +217,7 @@ func Test_CreateBackup_PermissionsEnforced(t *testing.T) { testUserToken = nonMember.Token } - request := MakeBackupRequest{DatabaseID: database.ID} + request := backups_dto.MakeBackupRequest{DatabaseID: database.ID} testResp := test_utils.MakePostRequest( t, router, @@ -245,7 +248,7 @@ func Test_CreateBackup_AuditLogWritten(t *testing.T) { database := createTestDatabase("Test Database", workspace.ID, owner.Token, router) enableBackupForDatabase(database.ID) - request := MakeBackupRequest{DatabaseID: database.ID} + request := backups_dto.MakeBackupRequest{DatabaseID: database.ID} test_utils.MakePostRequest( t, router, @@ -373,7 +376,7 @@ func Test_DeleteBackup_PermissionsEnforced(t *testing.T) { ownerUser, err := userService.GetUserFromToken(owner.Token) assert.NoError(t, err) - response, err := GetBackupService().GetBackups(ownerUser, database.ID, 10, 0) + response, err := backups_services.GetBackupService().GetBackups(ownerUser, database.ID, 10, 0) assert.NoError(t, err) assert.Equal(t, 0, len(response.Backups)) } @@ -999,7 +1002,7 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) { assert.NoError(t, err) // Register a cancellable context for the backup - GetBackupService().taskCancelManager.RegisterTask(backup.ID, func() {}) + task_cancellation.GetTaskCancelManager().RegisterTask(backup.ID, func() {}) resp := test_utils.MakePostRequest( t, @@ -1091,7 +1094,7 @@ func Test_ConcurrentDownloadPrevention(t 
*testing.T) { time.Sleep(50 * time.Millisecond) - service := GetBackupService() + service := backups_services.GetBackupService() if !service.IsDownloadInProgress(owner.UserID) { t.Log("Warning: First download completed before we could test concurrency") <-downloadComplete @@ -1192,7 +1195,7 @@ func Test_GenerateDownloadToken_BlockedWhenDownloadInProgress(t *testing.T) { time.Sleep(50 * time.Millisecond) - service := GetBackupService() + service := backups_services.GetBackupService() if !service.IsDownloadInProgress(owner.UserID) { t.Log("Warning: First download completed before we could test token generation blocking") <-downloadComplete @@ -1268,7 +1271,7 @@ func Test_MakeBackup_VerifyBackupAndMetadataFilesExistInStorage(t *testing.T) { initialBackups, err := backupRepo.FindByDatabaseID(database.ID) assert.NoError(t, err) - request := MakeBackupRequest{DatabaseID: database.ID} + request := backups_dto.MakeBackupRequest{DatabaseID: database.ID} test_utils.MakePostRequest( t, router, @@ -1502,7 +1505,7 @@ func createTestBackup( } func createExpiredDownloadToken(backupID, userID uuid.UUID) string { - tokenService := GetBackupService().downloadTokenService + tokenService := backups_download.GetDownloadTokenService() token, err := tokenService.Generate(backupID, userID) if err != nil { panic(fmt.Sprintf("Failed to generate download token: %v", err)) @@ -1843,7 +1846,7 @@ func Test_DeleteBackup_RemovesBackupAndMetadataFilesFromDisk(t *testing.T) { initialBackups, err := backupRepo.FindByDatabaseID(database.ID) assert.NoError(t, err) - request := MakeBackupRequest{DatabaseID: database.ID} + request := backups_dto.MakeBackupRequest{DatabaseID: database.ID} test_utils.MakePostRequest( t, router, diff --git a/backend/internal/features/backups/backups/controllers/di.go b/backend/internal/features/backups/backups/controllers/di.go new file mode 100644 index 0000000..5085fdf --- /dev/null +++ b/backend/internal/features/backups/backups/controllers/di.go @@ -0,0 +1,23 @@ +package backups_controllers + +import ( + backups_services "databasus-backend/internal/features/backups/backups/services" + "databasus-backend/internal/features/databases" +) + +var backupController = &BackupController{ + backups_services.GetBackupService(), +} + +func GetBackupController() *BackupController { + return backupController +} + +var postgresWalBackupController = &PostgreWalBackupController{ + databases.GetDatabaseService(), + backups_services.GetWalService(), +} + +func GetPostgresWalBackupController() *PostgreWalBackupController { + return postgresWalBackupController +} diff --git a/backend/internal/features/backups/backups/controllers/postgres_wal_controller.go b/backend/internal/features/backups/backups/controllers/postgres_wal_controller.go new file mode 100644 index 0000000..4de5c53 --- /dev/null +++ b/backend/internal/features/backups/backups/controllers/postgres_wal_controller.go @@ -0,0 +1,291 @@ +package backups_controllers + +import ( + "io" + "net/http" + "strconv" + + backups_core "databasus-backend/internal/features/backups/backups/core" + backups_dto "databasus-backend/internal/features/backups/backups/dto" + backups_services "databasus-backend/internal/features/backups/backups/services" + "databasus-backend/internal/features/databases" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" +) + +// PostgreWalBackupController handles WAL backup endpoints used by the databasus-cli agent. +// Authentication is via a plain agent token in the Authorization header (no Bearer prefix). 
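+//
+// Illustrative call sequence (a reviewer sketch only; $HOST, $AGENT_TOKEN and the
+// file names are placeholders, not values produced by this patch):
+//
+//	# stream a basebackup plus the WAL range that makes it consistent
+//	curl -X POST "$HOST/api/v1/backups/postgres/wal/upload?fullBackupWalStartSegment=000000010000000100000001&fullBackupWalStopSegment=000000010000000100000010" \
+//	  -H "Authorization: $AGENT_TOKEN" -H "X-Upload-Type: basebackup" \
+//	  --data-binary @base.tar.zst
+//
+//	# stream the next WAL segment; 409 signals a chain gap, 400 signals no full backup yet
+//	curl -X POST "$HOST/api/v1/backups/postgres/wal/upload" \
+//	  -H "Authorization: $AGENT_TOKEN" -H "X-Upload-Type: wal" \
+//	  -H "X-Wal-Segment-Name: 000000010000000100000011" \
+//	  --data-binary @000000010000000100000011.zst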
+type PostgreWalBackupController struct { + databaseService *databases.DatabaseService + walService *backups_services.PostgreWalBackupService +} + +func (c *PostgreWalBackupController) RegisterRoutes(router *gin.RouterGroup) { + walRoutes := router.Group("/backups/postgres/wal") + + walRoutes.GET("/next-full-backup-time", c.GetNextFullBackupTime) + walRoutes.POST("/error", c.ReportError) + walRoutes.POST("/upload", c.Upload) + walRoutes.GET("/restore/plan", c.GetRestorePlan) + walRoutes.GET("/restore/download", c.DownloadBackupFile) +} + +// GetNextFullBackupTime +// @Summary Get next full backup time +// @Description Returns the next scheduled full basebackup time for the authenticated database +// @Tags backups-wal +// @Produce json +// @Security AgentToken +// @Success 200 {object} backups_dto.GetNextFullBackupTimeResponse +// @Failure 401 {object} map[string]string +// @Failure 500 {object} map[string]string +// @Router /backups/postgres/wal/next-full-backup-time [get] +func (c *PostgreWalBackupController) GetNextFullBackupTime(ctx *gin.Context) { + database, err := c.getDatabase(ctx) + if err != nil { + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"}) + return + } + + response, err := c.walService.GetNextFullBackupTime(database) + if err != nil { + ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + ctx.JSON(http.StatusOK, response) +} + +// ReportError +// @Summary Report agent error +// @Description Records a fatal error from the agent against the database record and marks it as errored +// @Tags backups-wal +// @Accept json +// @Security AgentToken +// @Param request body backups_dto.ReportErrorRequest true "Error details" +// @Success 200 +// @Failure 400 {object} map[string]string +// @Failure 401 {object} map[string]string +// @Failure 500 {object} map[string]string +// @Router /backups/postgres/wal/error [post] +func (c *PostgreWalBackupController) ReportError(ctx *gin.Context) { + database, err := c.getDatabase(ctx) + if err != nil { + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"}) + return + } + + var request backups_dto.ReportErrorRequest + if err := ctx.ShouldBindJSON(&request); err != nil { + ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if err := c.walService.ReportError(database, request.Error); err != nil { + ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + ctx.Status(http.StatusOK) +} + +// Upload +// @Summary Stream upload a basebackup or WAL segment +// @Description Accepts a zstd-compressed binary stream and stores it in the database's configured storage. +// The server generates the storage filename; agents do not control the destination path. +// For WAL segment uploads the server validates the WAL chain and returns 409 if a gap is detected +// or 400 if no full backup exists yet (agent should trigger a full basebackup in both cases). +// @Tags backups-wal +// @Accept application/octet-stream +// @Produce json +// @Security AgentToken +// @Param X-Upload-Type header string true "Upload type" Enums(basebackup, wal) +// @Param X-Wal-Segment-Name header string false "24-hex WAL segment identifier (required for wal uploads, e.g. 
0000000100000001000000AB)" +// @Param X-Wal-Segment-Size header int false "WAL segment size in bytes reported by the PostgreSQL instance (default: 16777216)" +// @Param fullBackupWalStartSegment query string false "First WAL segment needed to make the basebackup consistent (required for basebackup uploads)" +// @Param fullBackupWalStopSegment query string false "Last WAL segment included in the basebackup (required for basebackup uploads)" +// @Success 204 +// @Failure 400 {object} backups_dto.UploadGapResponse "No full backup exists (error: no_full_backup)" +// @Failure 401 {object} map[string]string +// @Failure 409 {object} backups_dto.UploadGapResponse "WAL chain gap detected (error: gap_detected)" +// @Failure 500 {object} map[string]string +// @Router /backups/postgres/wal/upload [post] +func (c *PostgreWalBackupController) Upload(ctx *gin.Context) { + database, err := c.getDatabase(ctx) + if err != nil { + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"}) + return + } + + uploadType := backups_core.PgWalUploadType(ctx.GetHeader("X-Upload-Type")) + if uploadType != backups_core.PgWalUploadTypeBasebackup && + uploadType != backups_core.PgWalUploadTypeWal { + ctx.JSON( + http.StatusBadRequest, + gin.H{"error": "X-Upload-Type must be 'basebackup' or 'wal'"}, + ) + return + } + + walSegmentName := "" + if uploadType == backups_core.PgWalUploadTypeWal { + walSegmentName = ctx.GetHeader("X-Wal-Segment-Name") + if walSegmentName == "" { + ctx.JSON( + http.StatusBadRequest, + gin.H{"error": "X-Wal-Segment-Name is required for wal uploads"}, + ) + return + } + } + + if uploadType == backups_core.PgWalUploadTypeBasebackup { + if ctx.Query("fullBackupWalStartSegment") == "" || + ctx.Query("fullBackupWalStopSegment") == "" { + ctx.JSON( + http.StatusBadRequest, + gin.H{ + "error": "fullBackupWalStartSegment and fullBackupWalStopSegment are required for basebackup uploads", + }, + ) + return + } + } + + walSegmentSizeBytes := int64(0) + if raw := ctx.GetHeader("X-Wal-Segment-Size"); raw != "" { + parsed, parseErr := strconv.ParseInt(raw, 10, 64) + if parseErr != nil || parsed <= 0 { + ctx.JSON( + http.StatusBadRequest, + gin.H{"error": "X-Wal-Segment-Size must be a positive integer"}, + ) + return + } + + walSegmentSizeBytes = parsed + } + + gapResp, uploadErr := c.walService.UploadWal( + ctx.Request.Context(), + database, + uploadType, + walSegmentName, + ctx.Query("fullBackupWalStartSegment"), + ctx.Query("fullBackupWalStopSegment"), + walSegmentSizeBytes, + ctx.Request.Body, + ) + + if uploadErr != nil { + ctx.JSON(http.StatusInternalServerError, gin.H{"error": uploadErr.Error()}) + return + } + + if gapResp != nil { + if gapResp.Error == "no_full_backup" { + ctx.JSON(http.StatusBadRequest, gapResp) + return + } + + ctx.JSON(http.StatusConflict, gapResp) + return + } + + ctx.Status(http.StatusNoContent) +} + +// GetRestorePlan +// @Summary Get restore plan +// @Description Resolves the full backup and all required WAL segments needed for recovery. Validates the WAL chain is continuous. 
+// @Tags backups-wal +// @Produce json +// @Security AgentToken +// @Param backupId query string false "UUID of a specific full backup to restore from; defaults to the most recent" +// @Success 200 {object} backups_dto.GetRestorePlanResponse +// @Failure 400 {object} map[string]string "Broken WAL chain or no backups available" +// @Failure 401 {object} map[string]string +// @Failure 500 {object} map[string]string +// @Router /backups/postgres/wal/restore/plan [get] +func (c *PostgreWalBackupController) GetRestorePlan(ctx *gin.Context) { + database, err := c.getDatabase(ctx) + if err != nil { + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"}) + return + } + + var backupID *uuid.UUID + if raw := ctx.Query("backupId"); raw != "" { + parsed, parseErr := uuid.Parse(raw) + if parseErr != nil { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backupId format"}) + return + } + + backupID = &parsed + } + + response, planErr, err := c.walService.GetRestorePlan(database, backupID) + if err != nil { + ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if planErr != nil { + ctx.JSON(http.StatusBadRequest, planErr) + return + } + + ctx.JSON(http.StatusOK, response) +} + +// DownloadBackupFile +// @Summary Download a backup or WAL segment file for restore +// @Description Retrieves the backup file by ID (validated against the authenticated database), decrypts it server-side if encrypted, and streams the zstd-compressed result to the agent +// @Tags backups-wal +// @Produce application/octet-stream +// @Security AgentToken +// @Param backupId query string true "Backup ID from the restore plan response" +// @Success 200 {file} file +// @Failure 400 {object} map[string]string +// @Failure 401 {object} map[string]string +// @Router /backups/postgres/wal/restore/download [get] +func (c *PostgreWalBackupController) DownloadBackupFile(ctx *gin.Context) { + database, err := c.getDatabase(ctx) + if err != nil { + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"}) + return + } + + backupIDRaw := ctx.Query("backupId") + if backupIDRaw == "" { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "backupId is required"}) + return + } + + backupID, err := uuid.Parse(backupIDRaw) + if err != nil { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backupId format"}) + return + } + + reader, err := c.walService.DownloadBackupFile(database, backupID) + if err != nil { + ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + defer func() { _ = reader.Close() }() + + ctx.Header("Content-Type", "application/octet-stream") + ctx.Status(http.StatusOK) + + _, _ = io.Copy(ctx.Writer, reader) +} + +func (c *PostgreWalBackupController) getDatabase( + ctx *gin.Context, +) (*databases.Database, error) { + token := ctx.GetHeader("Authorization") + return c.databaseService.GetDatabaseByAgentToken(token) +} diff --git a/backend/internal/features/backups/backups/controllers/postgres_wal_controller_test.go b/backend/internal/features/backups/backups/controllers/postgres_wal_controller_test.go new file mode 100644 index 0000000..da77d40 --- /dev/null +++ b/backend/internal/features/backups/backups/controllers/postgres_wal_controller_test.go @@ -0,0 +1,1224 @@ +package backups_controllers + +import ( + "bytes" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + backups_core "databasus-backend/internal/features/backups/backups/core" + backups_dto 
"databasus-backend/internal/features/backups/backups/dto" + backups_config "databasus-backend/internal/features/backups/config" + "databasus-backend/internal/features/databases" + "databasus-backend/internal/features/databases/databases/postgresql" + "databasus-backend/internal/features/intervals" + "databasus-backend/internal/features/storages" + local_storage "databasus-backend/internal/features/storages/models/local" + users_enums "databasus-backend/internal/features/users/enums" + users_testing "databasus-backend/internal/features/users/testing" + workspaces_controllers "databasus-backend/internal/features/workspaces/controllers" + workspaces_testing "databasus-backend/internal/features/workspaces/testing" + test_utils "databasus-backend/internal/util/testing" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_WalUpload_InProgressStatusSetBeforeStream(t *testing.T) { + router, db, storage, agentToken, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + // Upload a completed full backup so WAL upload chain validation passes. + uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010") + + pr, pw := io.Pipe() + req := newWalUploadRequest(pr, agentToken, "wal", "000000010000000100000011", "", "") + + w := httptest.NewRecorder() + done := make(chan struct{}) + go func() { + router.ServeHTTP(w, req) + close(done) + }() + + // The SaveFile call blocks until the body reader is closed — check status while it's open. + time.Sleep(150 * time.Millisecond) + + backups, err := backups_core.GetBackupRepository().FindByDatabaseID(db.ID) + require.NoError(t, err) + require.NotEmpty(t, backups) + assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status) + + // Allow the upload to finish. + _ = pw.Close() + <-done +} + +func Test_WalUpload_CompletedStatusAfterSuccessfulStream(t *testing.T) { + router, db, storage, agentToken, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010") + + body := bytes.NewReader([]byte("wal segment content")) + req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000011", "", "") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusNoContent, w.Code) + + WaitForBackupCompletion(t, db.ID, 1, 5*time.Second) + + backups, err := backups_core.GetBackupRepository().FindByDatabaseID(db.ID) + require.NoError(t, err) + + var walBackup *backups_core.Backup + for _, b := range backups { + if b.PgWalBackupType != nil && + *b.PgWalBackupType == backups_core.PgWalBackupTypeWalSegment { + walBackup = b + break + } + } + + require.NotNil(t, walBackup) + assert.Equal(t, backups_core.BackupStatusCompleted, walBackup.Status) +} + +func Test_WalUpload_FailedStatusWithErrorOnStreamError(t *testing.T) { + router, db, storage, agentToken, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010") + + pr, pw := io.Pipe() + req := newWalUploadRequest(pr, agentToken, "wal", "000000010000000100000011", "", "") + + w := httptest.NewRecorder() + done := make(chan struct{}) + go func() { + router.ServeHTTP(w, req) + close(done) + }() + + // Simulate a body read error mid-stream. 
+	_ = pw.CloseWithError(errors.New("simulated network error"))
+	<-done
+
+	backups, err := backups_core.GetBackupRepository().FindByDatabaseID(db.ID)
+	require.NoError(t, err)
+
+	var walBackup *backups_core.Backup
+	for _, b := range backups {
+		if b.PgWalBackupType != nil &&
+			*b.PgWalBackupType == backups_core.PgWalBackupTypeWalSegment {
+			walBackup = b
+			break
+		}
+	}
+
+	require.NotNil(t, walBackup)
+	assert.Equal(t, backups_core.BackupStatusFailed, walBackup.Status)
+	assert.NotNil(t, walBackup.FailMessage)
+}
+
+func Test_WalUpload_Basebackup_MissingWalSegments_Returns400(t *testing.T) {
+	router, db, storage, agentToken, _ := createWalTestSetup(t)
+	defer removeWalTestSetup(db, storage)
+
+	body := bytes.NewReader([]byte("basebackup content"))
+	req := newWalUploadRequest(body, agentToken, backups_core.PgWalUploadTypeBasebackup, "", "", "")
+	w := httptest.NewRecorder()
+
+	router.ServeHTTP(w, req)
+
+	assert.Equal(t, http.StatusBadRequest, w.Code)
+}
+
+func Test_WalUpload_WalSegment_NoFullBackup_Returns400(t *testing.T) {
+	router, db, storage, agentToken, _ := createWalTestSetup(t)
+	defer removeWalTestSetup(db, storage)
+
+	// No full backup inserted — chain anchor is missing.
+	body := bytes.NewReader([]byte("wal content"))
+	req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000001", "", "")
+	w := httptest.NewRecorder()
+
+	router.ServeHTTP(w, req)
+
+	assert.Equal(t, http.StatusBadRequest, w.Code)
+
+	var resp backups_dto.UploadGapResponse
+	require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
+	assert.Equal(t, "no_full_backup", resp.Error)
+}
+
+func Test_WalUpload_WalSegment_GapDetected_Returns409WithExpectedAndReceived(t *testing.T) {
+	router, db, storage, agentToken, _ := createWalTestSetup(t)
+	defer removeWalTestSetup(db, storage)
+
+	// Full backup stops at ...0010; upload one WAL segment at ...0011.
+	uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
+	uploadWalSegment(t, router, agentToken, "000000010000000100000011")
+
+	// Send ...0013 — should be rejected because ...0012 is missing.
+	body := bytes.NewReader([]byte("wal content"))
+	req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000013", "", "")
+	w := httptest.NewRecorder()
+
+	router.ServeHTTP(w, req)
+
+	assert.Equal(t, http.StatusConflict, w.Code)
+
+	var resp backups_dto.UploadGapResponse
+	require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
+	assert.Equal(t, "gap_detected", resp.Error)
+	assert.Equal(t, "000000010000000100000012", resp.ExpectedSegmentName)
+	assert.Equal(t, "000000010000000100000013", resp.ReceivedSegmentName)
+}
+
+func Test_WalUpload_WalSegment_DuplicateSegment_Returns204Idempotent(t *testing.T) {
+	router, db, storage, agentToken, _ := createWalTestSetup(t)
+	defer removeWalTestSetup(db, storage)
+
+	uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
+
+	// Upload ...0011 once.
+	body1 := bytes.NewReader([]byte("wal content"))
+	req1 := newWalUploadRequest(body1, agentToken, "wal", "000000010000000100000011", "", "")
+	w1 := httptest.NewRecorder()
+	router.ServeHTTP(w1, req1)
+	require.Equal(t, http.StatusNoContent, w1.Code)
+
+	// Upload the same segment again — must return 204 (idempotent).
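+	// (Re-sends happen in practice when the agent retries after a transient
+	// upload failure, so a duplicate must be acknowledged rather than rejected.)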
+	body2 := bytes.NewReader([]byte("wal content"))
+	req2 := newWalUploadRequest(body2, agentToken, "wal", "000000010000000100000011", "", "")
+	w2 := httptest.NewRecorder()
+	router.ServeHTTP(w2, req2)
+
+	assert.Equal(t, http.StatusNoContent, w2.Code)
+
+	// Ensure only ONE WAL segment record exists (no duplicate created).
+	backups, err := backups_core.GetBackupRepository().FindByDatabaseID(db.ID)
+	require.NoError(t, err)
+
+	walCount := 0
+	for _, b := range backups {
+		if b.PgWalBackupType != nil &&
+			*b.PgWalBackupType == backups_core.PgWalBackupTypeWalSegment {
+			walCount++
+		}
+	}
+
+	assert.Equal(t, 1, walCount, "duplicate upload must not create a second backup record")
+}
+
+func Test_WalUpload_WalSegment_ValidNextSegment_Returns204AndCreatesRecord(t *testing.T) {
+	router, db, storage, agentToken, _ := createWalTestSetup(t)
+	defer removeWalTestSetup(db, storage)
+
+	uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010")
+
+	// First WAL segment after the full backup stop segment.
+	body := bytes.NewReader([]byte("wal segment data"))
+	req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000011", "", "")
+	w := httptest.NewRecorder()
+
+	router.ServeHTTP(w, req)
+
+	require.Equal(t, http.StatusNoContent, w.Code)
+
+	WaitForBackupCompletion(t, db.ID, 1, 5*time.Second)
+
+	backups, err := backups_core.GetBackupRepository().FindByDatabaseID(db.ID)
+	require.NoError(t, err)
+
+	var walBackup *backups_core.Backup
+	for _, b := range backups {
+		if b.PgWalBackupType != nil &&
+			*b.PgWalBackupType == backups_core.PgWalBackupTypeWalSegment {
+			walBackup = b
+			break
+		}
+	}
+
+	require.NotNil(t, walBackup)
+	assert.Equal(t, backups_core.BackupStatusCompleted, walBackup.Status)
+	require.NotNil(t, walBackup.PgWalSegmentName)
+	assert.Equal(t, "000000010000000100000011", *walBackup.PgWalSegmentName)
+}
+
+func Test_ReportError_ValidTokenAndError_CreatesFailedBackupRecord(t *testing.T) {
+	router, db, storage, agentToken, _ := createWalTestSetup(t)
+	defer removeWalTestSetup(db, storage)
+
+	errorMsg := "failed to parse pg_basebackup stderr: start WAL location not found"
+	body, _ := json.Marshal(map[string]string{"error": errorMsg})
+
+	req, _ := http.NewRequest(
+		http.MethodPost,
+		"/api/v1/backups/postgres/wal/error",
+		bytes.NewReader(body),
+	)
+	req.Header.Set("Authorization", agentToken)
+	req.Header.Set("Content-Type", "application/json")
+
+	w := httptest.NewRecorder()
+	router.ServeHTTP(w, req)
+
+	require.Equal(t, http.StatusOK, w.Code)
+
+	backups, err := backups_core.GetBackupRepository().FindByDatabaseID(db.ID)
+	require.NoError(t, err)
+	require.NotEmpty(t, backups)
+
+	assert.Equal(t, backups_core.BackupStatusFailed, backups[0].Status)
+	require.NotNil(t, backups[0].FailMessage)
+	assert.Equal(t, errorMsg, *backups[0].FailMessage)
+}
+
+func Test_ReportError_WithInvalidToken_ReturnsUnauthorized(t *testing.T) {
+	router, db, storage, _, _ := createWalTestSetup(t)
+	defer removeWalTestSetup(db, storage)
+
+	body, _ := json.Marshal(map[string]string{"error": "some error"})
+
+	req, _ := http.NewRequest(
+		http.MethodPost,
+		"/api/v1/backups/postgres/wal/error",
+		bytes.NewReader(body),
+	)
+	req.Header.Set("Authorization", "invalid-token")
+	req.Header.Set("Content-Type", "application/json")
+
+	w := httptest.NewRecorder()
+	router.ServeHTTP(w, req)
+
+	assert.Equal(t, http.StatusUnauthorized, w.Code)
+}
+
+func Test_ReportError_WithMissingErrorField_ReturnsBadRequest(t *testing.T) {
+	router, db, storage, agentToken,
_ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + body, _ := json.Marshal(map[string]string{}) + + req, _ := http.NewRequest( + http.MethodPost, + "/api/v1/backups/postgres/wal/error", + bytes.NewReader(body), + ) + req.Header.Set("Authorization", agentToken) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) +} + +func Test_GetNextFullBackupTime_WithValidToken_NoFullBackup_ReturnsNull(t *testing.T) { + router, db, storage, agentToken, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + var response backups_dto.GetNextFullBackupTimeResponse + test_utils.MakeGetRequestAndUnmarshal( + t, + router, + "/api/v1/backups/postgres/wal/next-full-backup-time", + agentToken, + http.StatusOK, + &response, + ) + + assert.Nil(t, response.NextFullBackupTime, "should be nil when no full backup exists") +} + +func Test_GetNextFullBackupTime_WithValidToken_HasFullBackup_ReturnsTime(t *testing.T) { + cronExpr := "0 3 * * *" + customTime := "14:30" + + tests := []struct { + name string + interval *intervals.Interval + expectedHour int + expectedMin int + checkHourMin bool + }{ + { + name: "daily interval returns time at 04:00", + interval: nil, // use default (daily 04:00) + expectedHour: 4, + expectedMin: 0, + checkHourMin: true, + }, + { + name: "hourly interval returns future time", + interval: &intervals.Interval{ + Interval: intervals.IntervalHourly, + }, + checkHourMin: false, + }, + { + name: "cron interval returns future time", + interval: &intervals.Interval{ + Interval: intervals.IntervalCron, + CronExpression: &cronExpr, + }, + checkHourMin: false, + }, + { + name: "daily interval with custom time 14:30", + interval: &intervals.Interval{ + Interval: intervals.IntervalDaily, + TimeOfDay: &customTime, + }, + expectedHour: 14, + expectedMin: 30, + checkHourMin: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + router, db, storage, agentToken, ownerToken := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + if tt.interval != nil { + var cfg backups_config.BackupConfig + test_utils.MakeGetRequestAndUnmarshal( + t, router, + "/api/v1/backup-configs/database/"+db.ID.String(), + "Bearer "+ownerToken, + http.StatusOK, &cfg, + ) + + cfg.BackupInterval = tt.interval + + test_utils.MakePostRequestAndUnmarshal( + t, router, + "/api/v1/backup-configs/save", + "Bearer "+ownerToken, + cfg, + http.StatusOK, &cfg, + ) + } + + uploadBasebackup( + t, + router, + agentToken, + "000000010000000100000001", + "000000010000000100000010", + ) + + now := time.Now().UTC() + + var response backups_dto.GetNextFullBackupTimeResponse + test_utils.MakeGetRequestAndUnmarshal( + t, + router, + "/api/v1/backups/postgres/wal/next-full-backup-time", + agentToken, + http.StatusOK, + &response, + ) + + require.NotNil(t, response.NextFullBackupTime) + nextTime := response.NextFullBackupTime.UTC() + + if tt.checkHourMin { + assert.Equal(t, tt.expectedHour, nextTime.Hour(), "expected hour") + assert.Equal(t, tt.expectedMin, nextTime.Minute(), "expected minute") + } + + assert.True(t, + nextTime.After(now.Add(-1*time.Minute)), + "next backup time should not be in the past", + ) + assert.True(t, + nextTime.Before(now.Add(25*time.Hour)), + "next backup time should be within 25 hours", + ) + }) + } +} + +func Test_GetNextFullBackupTime_WalSegmentAfterFullBackup_DoesNotImpactTime(t *testing.T) { + router, db, storage, agentToken, ownerToken := 
createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + setHourlyInterval(t, router, db.ID, ownerToken) + + // Upload basebackup via API. + bbBody := bytes.NewReader([]byte("basebackup content")) + bbReq := newWalUploadRequest( + bbBody, agentToken, backups_core.PgWalUploadTypeBasebackup, "", + "000000010000000100000001", "000000010000000100000010", + ) + bbW := httptest.NewRecorder() + router.ServeHTTP(bbW, bbReq) + require.Equal(t, http.StatusNoContent, bbW.Code) + + // Shift the full backup's CreatedAt to 2 hours ago. + twoHoursAgo := time.Now().UTC().Add(-2 * time.Hour) + updateLastFullBackupTime(t, db.ID, twoHoursAgo) + + // Upload WAL segment via API. + walBody := bytes.NewReader([]byte("wal segment content")) + walReq := newWalUploadRequest( + walBody, agentToken, backups_core.PgWalUploadTypeWal, + "000000010000000100000011", "", "", + ) + walW := httptest.NewRecorder() + router.ServeHTTP(walW, walReq) + require.Equal(t, http.StatusNoContent, walW.Code) + + var response backups_dto.GetNextFullBackupTimeResponse + test_utils.MakeGetRequestAndUnmarshal( + t, + router, + "/api/v1/backups/postgres/wal/next-full-backup-time", + agentToken, + http.StatusOK, + &response, + ) + + require.NotNil(t, response.NextFullBackupTime) + nextTime := response.NextFullBackupTime.UTC() + + // Hourly: nextTime = fullBackup.CreatedAt + 1h ≈ 1 hour ago (already past). + // WAL segment should not have shifted it forward. + expectedApprox := twoHoursAgo.Add(time.Hour) + assert.WithinDuration(t, expectedApprox, nextTime, 5*time.Second, + "next time should be based on full backup, not WAL segment", + ) +} + +func Test_GetNextFullBackupTime_FailedBasebackup_DoesNotImpactTime(t *testing.T) { + router, db, storage, agentToken, ownerToken := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + setHourlyInterval(t, router, db.ID, ownerToken) + + // Upload a successful basebackup via API. + bbBody := bytes.NewReader([]byte("basebackup content")) + bbReq := newWalUploadRequest( + bbBody, agentToken, backups_core.PgWalUploadTypeBasebackup, "", + "000000010000000100000001", "000000010000000100000010", + ) + bbW := httptest.NewRecorder() + router.ServeHTTP(bbW, bbReq) + require.Equal(t, http.StatusNoContent, bbW.Code) + + // Shift the full backup's CreatedAt to 2 hours ago. + twoHoursAgo := time.Now().UTC().Add(-2 * time.Hour) + updateLastFullBackupTime(t, db.ID, twoHoursAgo) + + // Report an error via the error endpoint. + errorMsg := "pg_basebackup failed: connection refused" + errBody, _ := json.Marshal(map[string]string{"error": errorMsg}) + errReq, _ := http.NewRequest( + http.MethodPost, + "/api/v1/backups/postgres/wal/error", + bytes.NewReader(errBody), + ) + errReq.Header.Set("Authorization", agentToken) + errReq.Header.Set("Content-Type", "application/json") + errW := httptest.NewRecorder() + router.ServeHTTP(errW, errReq) + require.Equal(t, http.StatusOK, errW.Code) + + var response backups_dto.GetNextFullBackupTimeResponse + test_utils.MakeGetRequestAndUnmarshal( + t, + router, + "/api/v1/backups/postgres/wal/next-full-backup-time", + agentToken, + http.StatusOK, + &response, + ) + + require.NotNil(t, response.NextFullBackupTime) + nextTime := response.NextFullBackupTime.UTC() + + // Hourly: nextTime = completedFullBackup.CreatedAt + 1h ≈ 1 hour ago. + // The error report should not have shifted it forward. 
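+	// The 5-second tolerance below absorbs the wall-clock time that elapses
+	// between capturing twoHoursAgo and running the assertion.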
+ expectedApprox := twoHoursAgo.Add(time.Hour) + assert.WithinDuration(t, expectedApprox, nextTime, 5*time.Second, + "next time should be based on completed full backup, not error report", + ) +} + +func Test_GetNextFullBackupTime_NewCompletedFullBackup_ImpactsTime(t *testing.T) { + router, db, storage, agentToken, ownerToken := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + setHourlyInterval(t, router, db.ID, ownerToken) + + // Upload first basebackup via API. + bb1 := bytes.NewReader([]byte("first basebackup")) + bb1Req := newWalUploadRequest( + bb1, agentToken, backups_core.PgWalUploadTypeBasebackup, "", + "000000010000000100000001", "000000010000000100000010", + ) + bb1W := httptest.NewRecorder() + router.ServeHTTP(bb1W, bb1Req) + require.Equal(t, http.StatusNoContent, bb1W.Code) + + // Shift the first backup's CreatedAt to 3 hours ago. + threeHoursAgo := time.Now().UTC().Add(-3 * time.Hour) + updateLastFullBackupTime(t, db.ID, threeHoursAgo) + + var firstResponse backups_dto.GetNextFullBackupTimeResponse + test_utils.MakeGetRequestAndUnmarshal( + t, + router, + "/api/v1/backups/postgres/wal/next-full-backup-time", + agentToken, + http.StatusOK, + &firstResponse, + ) + + require.NotNil(t, firstResponse.NextFullBackupTime) + firstNextTime := firstResponse.NextFullBackupTime.UTC() + + // First result: 3h ago + 1h = 2h ago (in the past). + assert.True(t, firstNextTime.Before(time.Now().UTC()), + "first next time should be in the past (old backup)", + ) + + // Upload second basebackup via API (created now). + bb2 := bytes.NewReader([]byte("second basebackup")) + bb2Req := newWalUploadRequest( + bb2, agentToken, backups_core.PgWalUploadTypeBasebackup, "", + "000000010000000100000011", "000000010000000100000020", + ) + bb2W := httptest.NewRecorder() + router.ServeHTTP(bb2W, bb2Req) + require.Equal(t, http.StatusNoContent, bb2W.Code) + + var secondResponse backups_dto.GetNextFullBackupTimeResponse + test_utils.MakeGetRequestAndUnmarshal( + t, + router, + "/api/v1/backups/postgres/wal/next-full-backup-time", + agentToken, + http.StatusOK, + &secondResponse, + ) + + require.NotNil(t, secondResponse.NextFullBackupTime) + secondNextTime := secondResponse.NextFullBackupTime.UTC() + + // Second result: now + 1h (in the future). 
+ assert.True(t, secondNextTime.After(firstNextTime), + "new full backup should shift next time forward", + ) + assert.True(t, secondNextTime.After(time.Now().UTC()), + "second next time should be in the future", + ) +} + +func Test_GetNextFullBackupTime_WithInvalidToken_ReturnsUnauthorized(t *testing.T) { + router, db, storage, _, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + resp := test_utils.MakeGetRequest( + t, + router, + "/api/v1/backups/postgres/wal/next-full-backup-time", + "invalid-token", + http.StatusUnauthorized, + ) + + assert.Contains(t, string(resp.Body), "invalid agent token") +} + +func Test_GetRestorePlan_NoFullBackup_Returns400(t *testing.T) { + router, db, storage, agentToken, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + resp := test_utils.MakeGetRequest( + t, router, + "/api/v1/backups/postgres/wal/restore/plan", + agentToken, + http.StatusBadRequest, + ) + + var errResp backups_dto.GetRestorePlanErrorResponse + require.NoError(t, json.Unmarshal(resp.Body, &errResp)) + assert.Equal(t, "no_backups", errResp.Error) +} + +func Test_GetRestorePlan_WithFullBackupOnly_Returns200(t *testing.T) { + router, db, storage, agentToken, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010") + + var response backups_dto.GetRestorePlanResponse + test_utils.MakeGetRequestAndUnmarshal( + t, router, + "/api/v1/backups/postgres/wal/restore/plan", + agentToken, + http.StatusOK, + &response, + ) + + assert.NotEqual(t, uuid.Nil, response.FullBackup.BackupID) + assert.Equal(t, "000000010000000100000001", response.FullBackup.FullBackupWalStartSegment) + assert.Equal(t, "000000010000000100000010", response.FullBackup.FullBackupWalStopSegment) + assert.Empty(t, response.WalSegments) + assert.Greater(t, response.TotalSizeBytes, int64(0)) +} + +func Test_GetRestorePlan_WithFullBackupAndWalSegments_Returns200(t *testing.T) { + router, db, storage, agentToken, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010") + uploadWalSegment(t, router, agentToken, "000000010000000100000011") + uploadWalSegment(t, router, agentToken, "000000010000000100000012") + uploadWalSegment(t, router, agentToken, "000000010000000100000013") + + var response backups_dto.GetRestorePlanResponse + test_utils.MakeGetRequestAndUnmarshal( + t, router, + "/api/v1/backups/postgres/wal/restore/plan", + agentToken, + http.StatusOK, + &response, + ) + + assert.NotEqual(t, uuid.Nil, response.FullBackup.BackupID) + require.Len(t, response.WalSegments, 3) + assert.Equal(t, "000000010000000100000011", response.WalSegments[0].SegmentName) + assert.Equal(t, "000000010000000100000012", response.WalSegments[1].SegmentName) + assert.Equal(t, "000000010000000100000013", response.WalSegments[2].SegmentName) + assert.Equal(t, "000000010000000100000013", response.LatestAvailableSegment) + assert.Greater(t, response.TotalSizeBytes, int64(0)) + + for _, seg := range response.WalSegments { + assert.NotEqual(t, uuid.Nil, seg.BackupID) + } +} + +func Test_GetRestorePlan_WithSpecificBackupId_Returns200(t *testing.T) { + router, db, storage, agentToken, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010") + + firstBackup, err := backups_core.GetBackupRepository(). 
+ FindLastCompletedFullWalBackupByDatabaseID(db.ID) + require.NoError(t, err) + require.NotNil(t, firstBackup) + + uploadWalSegment(t, router, agentToken, "000000010000000100000011") + + uploadBasebackup(t, router, agentToken, "000000010000000100000011", "000000010000000100000020") + + var response backups_dto.GetRestorePlanResponse + test_utils.MakeGetRequestAndUnmarshal( + t, router, + "/api/v1/backups/postgres/wal/restore/plan?backupId="+firstBackup.ID.String(), + agentToken, + http.StatusOK, + &response, + ) + + assert.Equal(t, firstBackup.ID, response.FullBackup.BackupID) + assert.Equal(t, "000000010000000100000001", response.FullBackup.FullBackupWalStartSegment) + assert.Equal(t, "000000010000000100000010", response.FullBackup.FullBackupWalStopSegment) +} + +func Test_GetRestorePlan_WithInvalidBackupId_Returns400(t *testing.T) { + router, db, storage, agentToken, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010") + + nonExistentID := uuid.New() + + resp := test_utils.MakeGetRequest( + t, router, + "/api/v1/backups/postgres/wal/restore/plan?backupId="+nonExistentID.String(), + agentToken, + http.StatusBadRequest, + ) + + var errResp backups_dto.GetRestorePlanErrorResponse + require.NoError(t, json.Unmarshal(resp.Body, &errResp)) + assert.Equal(t, "no_backups", errResp.Error) +} + +func Test_GetRestorePlan_WithInvalidToken_Returns401(t *testing.T) { + router, db, storage, _, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + resp := test_utils.MakeGetRequest( + t, router, + "/api/v1/backups/postgres/wal/restore/plan", + "invalid-token", + http.StatusUnauthorized, + ) + + assert.Contains(t, string(resp.Body), "invalid agent token") +} + +func Test_GetRestorePlan_WalChainBroken_Returns400(t *testing.T) { + router, db, storage, agentToken, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010") + uploadWalSegment(t, router, agentToken, "000000010000000100000011") + uploadWalSegment(t, router, agentToken, "000000010000000100000012") + uploadWalSegment(t, router, agentToken, "000000010000000100000013") + + middleSeg, err := backups_core.GetBackupRepository().FindWalSegmentByName( + db.ID, "000000010000000100000012", + ) + require.NoError(t, err) + require.NotNil(t, middleSeg) + require.NoError(t, backups_core.GetBackupRepository().DeleteByID(middleSeg.ID)) + + resp := test_utils.MakeGetRequest( + t, router, + "/api/v1/backups/postgres/wal/restore/plan", + agentToken, + http.StatusBadRequest, + ) + + var errResp backups_dto.GetRestorePlanErrorResponse + require.NoError(t, json.Unmarshal(resp.Body, &errResp)) + assert.Equal(t, "wal_chain_broken", errResp.Error) + assert.Equal(t, "000000010000000100000011", errResp.LastContiguousSegment) +} + +func Test_GetRestorePlan_WithInvalidBackupIdFormat_Returns400(t *testing.T) { + router, db, storage, agentToken, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + resp := test_utils.MakeGetRequest( + t, router, + "/api/v1/backups/postgres/wal/restore/plan?backupId=not-a-uuid", + agentToken, + http.StatusBadRequest, + ) + + assert.Contains(t, string(resp.Body), "invalid backupId format") +} + +func Test_DownloadRestoreFile_UploadThenDownload_ContentMatches(t *testing.T) { + tests := []struct { + name string + encryption backups_config.BackupEncryption + }{ + { + name: "unencrypted", + 
encryption: backups_config.BackupEncryptionNone, + }, + { + name: "encrypted", + encryption: backups_config.BackupEncryptionEncrypted, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + router, db, storage, agentToken, ownerToken := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + setEncryption(t, router, db.ID, ownerToken, tt.encryption) + + uploadContent := "test-basebackup-content-for-download" + body := bytes.NewReader([]byte(uploadContent)) + req := newWalUploadRequest( + body, agentToken, backups_core.PgWalUploadTypeBasebackup, "", + "000000010000000100000001", "000000010000000100000010", + ) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + require.Equal(t, http.StatusNoContent, w.Code) + + WaitForBackupCompletion(t, db.ID, 0, 5*time.Second) + + var planResp backups_dto.GetRestorePlanResponse + test_utils.MakeGetRequestAndUnmarshal( + t, router, + "/api/v1/backups/postgres/wal/restore/plan", + agentToken, + http.StatusOK, + &planResp, + ) + + require.NotEqual(t, uuid.Nil, planResp.FullBackup.BackupID) + + downloadResp := test_utils.MakeGetRequest( + t, + router, + "/api/v1/backups/postgres/wal/restore/download?backupId="+planResp.FullBackup.BackupID.String(), + agentToken, + http.StatusOK, + ) + + assert.Equal(t, uploadContent, string(downloadResp.Body)) + }) + } +} + +func Test_DownloadRestoreFile_WalSegment_UploadThenDownload_ContentMatches(t *testing.T) { + router, db, storage, agentToken, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + uploadBasebackup(t, router, agentToken, "000000010000000100000001", "000000010000000100000010") + + walContent := "test-wal-segment-content-for-download" + body := bytes.NewReader([]byte(walContent)) + req := newWalUploadRequest(body, agentToken, "wal", "000000010000000100000011", "", "") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + require.Equal(t, http.StatusNoContent, w.Code) + + WaitForBackupCompletion(t, db.ID, 1, 5*time.Second) + + var planResp backups_dto.GetRestorePlanResponse + test_utils.MakeGetRequestAndUnmarshal( + t, router, + "/api/v1/backups/postgres/wal/restore/plan", + agentToken, + http.StatusOK, + &planResp, + ) + + require.Len(t, planResp.WalSegments, 1) + + downloadResp := test_utils.MakeGetRequest( + t, + router, + "/api/v1/backups/postgres/wal/restore/download?backupId="+planResp.WalSegments[0].BackupID.String(), + agentToken, + http.StatusOK, + ) + + assert.Equal(t, walContent, string(downloadResp.Body)) +} + +func Test_DownloadRestoreFile_InvalidBackupId_Returns400(t *testing.T) { + router, db, storage, agentToken, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + nonExistentID := uuid.New() + + resp := test_utils.MakeGetRequest( + t, router, + "/api/v1/backups/postgres/wal/restore/download?backupId="+nonExistentID.String(), + agentToken, + http.StatusBadRequest, + ) + + assert.Contains(t, string(resp.Body), "backup not found") +} + +func Test_DownloadRestoreFile_InvalidToken_Returns401(t *testing.T) { + router, db, storage, _, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + resp := test_utils.MakeGetRequest( + t, router, + "/api/v1/backups/postgres/wal/restore/download?backupId="+uuid.New().String(), + "invalid-token", + http.StatusUnauthorized, + ) + + assert.Contains(t, string(resp.Body), "invalid agent token") +} + +func Test_DownloadRestoreFile_BackupFromOtherDatabase_Returns400(t *testing.T) { + router, db1, storage1, agentToken1, _ := createWalTestSetup(t) + defer 
removeWalTestSetup(db1, storage1) + + _, db2, storage2, agentToken2, _ := createWalTestSetup(t) + defer removeWalTestSetup(db2, storage2) + + uploadBasebackup(t, router, agentToken1, "000000010000000100000001", "000000010000000100000010") + + WaitForBackupCompletion(t, db1.ID, 0, 5*time.Second) + + var planResp backups_dto.GetRestorePlanResponse + test_utils.MakeGetRequestAndUnmarshal( + t, router, + "/api/v1/backups/postgres/wal/restore/plan", + agentToken1, + http.StatusOK, + &planResp, + ) + + resp := test_utils.MakeGetRequest( + t, + router, + "/api/v1/backups/postgres/wal/restore/download?backupId="+planResp.FullBackup.BackupID.String(), + agentToken2, + http.StatusBadRequest, + ) + + assert.Contains(t, string(resp.Body), "backup does not belong to this database") +} + +func Test_DownloadRestoreFile_MissingBackupId_Returns400(t *testing.T) { + router, db, storage, agentToken, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + resp := test_utils.MakeGetRequest( + t, router, + "/api/v1/backups/postgres/wal/restore/download", + agentToken, + http.StatusBadRequest, + ) + + assert.Contains(t, string(resp.Body), "backupId is required") +} + +func Test_DownloadRestoreFile_InvalidBackupIdFormat_Returns400(t *testing.T) { + router, db, storage, agentToken, _ := createWalTestSetup(t) + defer removeWalTestSetup(db, storage) + + resp := test_utils.MakeGetRequest( + t, router, + "/api/v1/backups/postgres/wal/restore/download?backupId=not-a-uuid", + agentToken, + http.StatusBadRequest, + ) + + assert.Contains(t, string(resp.Body), "invalid backupId format") +} + +func createWalTestRouter() *gin.Engine { + router := workspaces_testing.CreateTestRouter( + workspaces_controllers.GetWorkspaceController(), + workspaces_controllers.GetMembershipController(), + databases.GetDatabaseController(), + backups_config.GetBackupConfigController(), + GetBackupController(), + ) + + v1 := router.Group("/api/v1") + GetPostgresWalBackupController().RegisterRoutes(v1) + + return router +} + +func createWalTestSetup(t *testing.T) ( + router *gin.Engine, + db *databases.Database, + storage *storages.Storage, + agentToken string, + ownerToken string, +) { + t.Helper() + + router = createWalTestRouter() + + owner := users_testing.CreateTestUser(users_enums.UserRoleMember) + workspace := workspaces_testing.CreateTestWorkspace("WAL Test Workspace", owner, router) + + db = createTestDatabase("WAL Test DB", workspace.ID, owner.Token, router) + + // Set backup type to WAL_V1 so the WAL service accepts requests. 
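+	// validateWalBackupType is called by every WAL endpoint and rejects
+	// databases whose Postgresql.BackupType is anything other than WAL_V1.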
+ db.Postgresql.BackupType = postgresql.PostgresBackupTypeWalV1 + dbRepo := &databases.DatabaseRepository{} + if _, err := dbRepo.Save(db); err != nil { + t.Fatalf("failed to update database backup type: %v", err) + } + + storage = &storages.Storage{ + WorkspaceID: workspace.ID, + Type: storages.StorageTypeLocal, + Name: "WAL Test Storage " + uuid.New().String(), + LocalStorage: &local_storage.LocalStorage{}, + } + + repo := &storages.StorageRepository{} + storage, err := repo.Save(storage) + if err != nil { + t.Fatalf("failed to create test storage: %v", err) + } + + configService := backups_config.GetBackupConfigService() + cfg, err := configService.GetBackupConfigByDbId(db.ID) + if err != nil { + t.Fatalf("failed to get backup config: %v", err) + } + + cfg.IsBackupsEnabled = true + cfg.StorageID = &storage.ID + cfg.Storage = storage + _, err = configService.SaveBackupConfig(cfg) + if err != nil { + t.Fatalf("failed to save backup config: %v", err) + } + + var tokenResp map[string]string + test_utils.MakePostRequestAndUnmarshal( + t, + router, + "/api/v1/databases/"+db.ID.String()+"/regenerate-token", + "Bearer "+owner.Token, + nil, + http.StatusOK, + &tokenResp, + ) + + agentToken = tokenResp["token"] + ownerToken = owner.Token + + return router, db, storage, agentToken, ownerToken +} + +func removeWalTestSetup(db *databases.Database, storage *storages.Storage) { + databases.RemoveTestDatabase(db) + storages.RemoveTestStorage(storage.ID) +} + +func newWalUploadRequest( + body io.Reader, + agentToken string, + uploadType backups_core.PgWalUploadType, + walSegmentName string, + walStart string, + walStop string, +) *http.Request { + url := "/api/v1/backups/postgres/wal/upload" + if walStart != "" || walStop != "" { + url += "?fullBackupWalStartSegment=" + walStart + "&fullBackupWalStopSegment=" + walStop + } + + req, err := http.NewRequest(http.MethodPost, url, body) + if err != nil { + panic(err) + } + + req.Header.Set("Authorization", agentToken) + req.Header.Set("Content-Type", "application/octet-stream") + req.Header.Set("X-Upload-Type", string(uploadType)) + + if walSegmentName != "" { + req.Header.Set("X-Wal-Segment-Name", walSegmentName) + } + + return req +} + +func uploadBasebackup( + t *testing.T, + router *gin.Engine, + agentToken string, + walStart string, + walStop string, +) { + t.Helper() + + body := bytes.NewReader([]byte("test-basebackup-content")) + req := newWalUploadRequest( + body, agentToken, backups_core.PgWalUploadTypeBasebackup, "", + walStart, walStop, + ) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusNoContent, w.Code) +} + +func uploadWalSegment( + t *testing.T, + router *gin.Engine, + agentToken string, + segmentName string, +) { + t.Helper() + + body := bytes.NewReader([]byte("test-wal-segment-content")) + req := newWalUploadRequest( + body, agentToken, backups_core.PgWalUploadTypeWal, segmentName, "", "", + ) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusNoContent, w.Code) +} + +func updateLastFullBackupTime(t *testing.T, databaseID uuid.UUID, createdAt time.Time) { + t.Helper() + + repo := backups_core.GetBackupRepository() + + backup, err := repo.FindLastCompletedFullWalBackupByDatabaseID(databaseID) + if err != nil { + t.Fatalf("updateLastFullBackupTime: find: %v", err) + } + + require.NotNil(t, backup, "no completed full backup found to update") + + backup.CreatedAt = createdAt + if err := repo.Save(backup); err != nil { + t.Fatalf("updateLastFullBackupTime: save: %v", 
err) + } +} + +func setEncryption( + t *testing.T, + router *gin.Engine, + databaseID uuid.UUID, + ownerToken string, + encryption backups_config.BackupEncryption, +) { + t.Helper() + + var cfg backups_config.BackupConfig + test_utils.MakeGetRequestAndUnmarshal( + t, router, + "/api/v1/backup-configs/database/"+databaseID.String(), + "Bearer "+ownerToken, + http.StatusOK, &cfg, + ) + + cfg.Encryption = encryption + + test_utils.MakePostRequestAndUnmarshal( + t, router, + "/api/v1/backup-configs/save", + "Bearer "+ownerToken, + cfg, + http.StatusOK, &cfg, + ) +} + +func setHourlyInterval(t *testing.T, router *gin.Engine, databaseID uuid.UUID, ownerToken string) { + t.Helper() + + var cfg backups_config.BackupConfig + test_utils.MakeGetRequestAndUnmarshal( + t, router, + "/api/v1/backup-configs/database/"+databaseID.String(), + "Bearer "+ownerToken, + http.StatusOK, &cfg, + ) + + cfg.BackupInterval = &intervals.Interval{Interval: intervals.IntervalHourly} + + test_utils.MakePostRequestAndUnmarshal( + t, router, + "/api/v1/backup-configs/save", + "Bearer "+ownerToken, + cfg, + http.StatusOK, &cfg, + ) +} diff --git a/backend/internal/features/backups/backups/testing.go b/backend/internal/features/backups/backups/controllers/testing.go similarity index 96% rename from backend/internal/features/backups/backups/testing.go rename to backend/internal/features/backups/backups/controllers/testing.go index 152a8ea..71fa33d 100644 --- a/backend/internal/features/backups/backups/testing.go +++ b/backend/internal/features/backups/backups/controllers/testing.go @@ -1,4 +1,4 @@ -package backups +package backups_controllers import ( "testing" @@ -41,7 +41,7 @@ func WaitForBackupCompletion( deadline := time.Now().UTC().Add(timeout) for time.Now().UTC().Before(deadline) { - backups, err := backupRepository.FindByDatabaseID(databaseID) + backups, err := backups_core.GetBackupRepository().FindByDatabaseID(databaseID) if err != nil { t.Logf("WaitForBackupCompletion: error finding backups: %v", err) time.Sleep(50 * time.Millisecond) diff --git a/backend/internal/features/backups/backups/core/di.go b/backend/internal/features/backups/backups/core/di.go new file mode 100644 index 0000000..5089568 --- /dev/null +++ b/backend/internal/features/backups/backups/core/di.go @@ -0,0 +1,7 @@ +package backups_core + +var backupRepository = &BackupRepository{} + +func GetBackupRepository() *BackupRepository { + return backupRepository +} diff --git a/backend/internal/features/backups/backups/core/enums.go b/backend/internal/features/backups/backups/core/enums.go index b94e76c..cc422b4 100644 --- a/backend/internal/features/backups/backups/core/enums.go +++ b/backend/internal/features/backups/backups/core/enums.go @@ -8,3 +8,10 @@ const ( BackupStatusFailed BackupStatus = "FAILED" BackupStatusCanceled BackupStatus = "CANCELED" ) + +type PgWalUploadType string + +const ( + PgWalUploadTypeBasebackup PgWalUploadType = "basebackup" + PgWalUploadTypeWal PgWalUploadType = "wal" +) diff --git a/backend/internal/features/backups/backups/core/model.go b/backend/internal/features/backups/backups/core/model.go index ddef4ff..cd8e214 100644 --- a/backend/internal/features/backups/backups/core/model.go +++ b/backend/internal/features/backups/backups/core/model.go @@ -1,12 +1,22 @@ package backups_core import ( - backups_config "databasus-backend/internal/features/backups/config" + "fmt" "time" + backups_config "databasus-backend/internal/features/backups/config" + files_utils "databasus-backend/internal/util/files" + 
"github.com/google/uuid" ) +type PgWalBackupType string + +const ( + PgWalBackupTypeFullBackup PgWalBackupType = "PG_FULL_BACKUP" + PgWalBackupTypeWalSegment PgWalBackupType = "PG_WAL_SEGMENT" +) + type Backup struct { ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"` FileName string `json:"fileName" gorm:"column:file_name;type:text;not null"` @@ -26,5 +36,23 @@ type Backup struct { EncryptionIV *string `json:"-" gorm:"column:encryption_iv"` Encryption backups_config.BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"` + // Postgres WAL backup specific fields + PgWalBackupType *PgWalBackupType `json:"pgWalBackupType" gorm:"column:pg_wal_backup_type;type:text"` + PgFullBackupWalStartSegmentName *string `json:"pgFullBackupWalStartSegmentName" gorm:"column:pg_wal_start_segment;type:text"` + PgFullBackupWalStopSegmentName *string `json:"pgFullBackupWalStopSegmentName" gorm:"column:pg_wal_stop_segment;type:text"` + PgVersion *string `json:"pgVersion" gorm:"column:pg_version;type:text"` + PgWalSegmentName *string `json:"pgWalSegmentName" gorm:"column:pg_wal_segment_name;type:text"` + CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"` } + +func (b *Backup) GenerateFilename(dbName string) { + timestamp := time.Now().UTC() + + b.FileName = fmt.Sprintf( + "%s-%s-%s", + files_utils.SanitizeFilename(dbName), + timestamp.Format("20060102-150405"), + b.ID.String(), + ) +} diff --git a/backend/internal/features/backups/backups/core/repository.go b/backend/internal/features/backups/backups/core/repository.go index dfa81f0..a1c8247 100644 --- a/backend/internal/features/backups/backups/core/repository.go +++ b/backend/internal/features/backups/backups/core/repository.go @@ -245,3 +245,134 @@ func (r *BackupRepository) FindOldestByDatabaseExcludingInProgress( return backups, nil } + +func (r *BackupRepository) FindCompletedFullWalBackupByID( + databaseID uuid.UUID, + backupID uuid.UUID, +) (*Backup, error) { + var backup Backup + + err := storage. + GetDb(). + Where( + "database_id = ? AND id = ? AND pg_wal_backup_type = ? AND status = ?", + databaseID, + backupID, + PgWalBackupTypeFullBackup, + BackupStatusCompleted, + ). + First(&backup).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } + + return nil, err + } + + return &backup, nil +} + +func (r *BackupRepository) FindCompletedWalSegmentsAfter( + databaseID uuid.UUID, + afterSegmentName string, +) ([]*Backup, error) { + var backups []*Backup + + err := storage. + GetDb(). + Where( + "database_id = ? AND pg_wal_backup_type = ? AND pg_wal_segment_name >= ? AND status = ?", + databaseID, + PgWalBackupTypeWalSegment, + afterSegmentName, + BackupStatusCompleted, + ). + Order("pg_wal_segment_name ASC"). + Find(&backups).Error + if err != nil { + return nil, err + } + + return backups, nil +} + +func (r *BackupRepository) FindLastCompletedFullWalBackupByDatabaseID( + databaseID uuid.UUID, +) (*Backup, error) { + var backup Backup + + err := storage. + GetDb(). + Where( + "database_id = ? AND pg_wal_backup_type = ? AND status = ?", + databaseID, + PgWalBackupTypeFullBackup, + BackupStatusCompleted, + ). + Order("created_at DESC"). + First(&backup).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } + + return nil, err + } + + return &backup, nil +} + +func (r *BackupRepository) FindWalSegmentByName( + databaseID uuid.UUID, + segmentName string, +) (*Backup, error) { + var backup Backup + + err := storage. 
+ GetDb(). + Where( + "database_id = ? AND pg_wal_backup_type = ? AND pg_wal_segment_name = ?", + databaseID, + PgWalBackupTypeWalSegment, + segmentName, + ). + First(&backup).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } + + return nil, err + } + + return &backup, nil +} + +func (r *BackupRepository) FindLastWalSegmentAfter( + databaseID uuid.UUID, + afterSegmentName string, +) (*Backup, error) { + var backup Backup + + err := storage. + GetDb(). + Where( + "database_id = ? AND pg_wal_backup_type = ? AND pg_wal_segment_name > ? AND status = ?", + databaseID, + PgWalBackupTypeWalSegment, + afterSegmentName, + BackupStatusCompleted, + ). + Order("pg_wal_segment_name DESC"). + First(&backup).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } + + return nil, err + } + + return &backup, nil +} diff --git a/backend/internal/features/backups/backups/dto.go b/backend/internal/features/backups/backups/dto.go deleted file mode 100644 index bb1b961..0000000 --- a/backend/internal/features/backups/backups/dto.go +++ /dev/null @@ -1,29 +0,0 @@ -package backups - -import ( - backups_core "databasus-backend/internal/features/backups/backups/core" - "databasus-backend/internal/features/backups/backups/encryption" - "io" -) - -type GetBackupsRequest struct { - DatabaseID string `form:"database_id" binding:"required"` - Limit int `form:"limit"` - Offset int `form:"offset"` -} - -type GetBackupsResponse struct { - Backups []*backups_core.Backup `json:"backups"` - Total int64 `json:"total"` - Limit int `json:"limit"` - Offset int `json:"offset"` -} - -type DecryptionReaderCloser struct { - *encryption.DecryptionReader - BaseReader io.ReadCloser -} - -func (r *DecryptionReaderCloser) Close() error { - return r.BaseReader.Close() -} diff --git a/backend/internal/features/backups/backups/dto/dto.go b/backend/internal/features/backups/backups/dto/dto.go new file mode 100644 index 0000000..f6813d3 --- /dev/null +++ b/backend/internal/features/backups/backups/dto/dto.go @@ -0,0 +1,78 @@ +package backups_dto + +import ( + backups_core "databasus-backend/internal/features/backups/backups/core" + "databasus-backend/internal/features/backups/backups/encryption" + "io" + "time" + + "github.com/google/uuid" +) + +type GetBackupsRequest struct { + DatabaseID string `form:"database_id" binding:"required"` + Limit int `form:"limit"` + Offset int `form:"offset"` +} + +type GetBackupsResponse struct { + Backups []*backups_core.Backup `json:"backups"` + Total int64 `json:"total"` + Limit int `json:"limit"` + Offset int `json:"offset"` +} + +type DecryptionReaderCloser struct { + *encryption.DecryptionReader + BaseReader io.ReadCloser +} + +func (r *DecryptionReaderCloser) Close() error { + return r.BaseReader.Close() +} + +type MakeBackupRequest struct { + DatabaseID uuid.UUID `json:"database_id" binding:"required"` +} + +type GetNextFullBackupTimeResponse struct { + NextFullBackupTime *time.Time `json:"nextFullBackupTime"` +} + +type ReportErrorRequest struct { + Error string `json:"error" binding:"required"` +} + +type UploadGapResponse struct { + Error string `json:"error"` + ExpectedSegmentName string `json:"expectedSegmentName"` + ReceivedSegmentName string `json:"receivedSegmentName"` +} + +type RestorePlanFullBackup struct { + BackupID uuid.UUID `json:"id"` + FullBackupWalStartSegment string `json:"fullBackupWalStartSegment"` + FullBackupWalStopSegment string `json:"fullBackupWalStopSegment"` + PgVersion string `json:"pgVersion"` + 
CreatedAt time.Time `json:"createdAt"` + SizeBytes int64 `json:"sizeBytes"` +} + +type RestorePlanWalSegment struct { + BackupID uuid.UUID `json:"backupId"` + SegmentName string `json:"segmentName"` + SizeBytes int64 `json:"sizeBytes"` +} + +type GetRestorePlanErrorResponse struct { + Error string `json:"error"` + Message string `json:"message"` + LastContiguousSegment string `json:"lastContiguousSegment,omitempty"` +} + +type GetRestorePlanResponse struct { + FullBackup RestorePlanFullBackup `json:"fullBackup"` + WalSegments []RestorePlanWalSegment `json:"walSegments"` + TotalSizeBytes int64 `json:"totalSizeBytes"` + LatestAvailableSegment string `json:"latestAvailableSegment"` +} diff --git a/backend/internal/features/backups/backups/encryption/setup.go b/backend/internal/features/backups/backups/encryption/setup.go new file mode 100644 index 0000000..764885e --- /dev/null +++ b/backend/internal/features/backups/backups/encryption/setup.go @@ -0,0 +1,45 @@ +package encryption + +import ( + "encoding/base64" + "fmt" + "io" + + "github.com/google/uuid" +) + +// EncryptionSetup holds the result of setting up encryption for a backup stream. +type EncryptionSetup struct { + Writer *EncryptionWriter + SaltBase64 string + NonceBase64 string +} + +// SetupEncryptionWriter generates salt/nonce, creates an EncryptionWriter, and +// returns the base64-encoded salt and nonce for storage on the backup record. +func SetupEncryptionWriter( + baseWriter io.Writer, + masterKey string, + backupID uuid.UUID, +) (*EncryptionSetup, error) { + salt, err := GenerateSalt() + if err != nil { + return nil, fmt.Errorf("failed to generate salt: %w", err) + } + + nonce, err := GenerateNonce() + if err != nil { + return nil, fmt.Errorf("failed to generate nonce: %w", err) + } + + encWriter, err := NewEncryptionWriter(baseWriter, masterKey, backupID, salt, nonce) + if err != nil { + return nil, fmt.Errorf("failed to create encryption writer: %w", err) + } + + return &EncryptionSetup{ + Writer: encWriter, + SaltBase64: base64.StdEncoding.EncodeToString(salt), + NonceBase64: base64.StdEncoding.EncodeToString(nonce), + }, nil +} diff --git a/backend/internal/features/backups/backups/di.go b/backend/internal/features/backups/backups/services/di.go similarity index 85% rename from backend/internal/features/backups/backups/di.go rename to backend/internal/features/backups/backups/services/di.go index d79dc6a..2613f39 100644 --- a/backend/internal/features/backups/backups/di.go +++ b/backend/internal/features/backups/backups/services/di.go @@ -1,9 +1,6 @@ -package backups +package backups_services import ( - "sync" - "sync/atomic" - audit_logs "databasus-backend/internal/features/audit_logs" "databasus-backend/internal/features/backups/backups/backuping" backups_core "databasus-backend/internal/features/backups/backups/core" @@ -18,16 +15,16 @@ import ( workspaces_services "databasus-backend/internal/features/workspaces/services" "databasus-backend/internal/util/encryption" "databasus-backend/internal/util/logger" + "sync" + "sync/atomic" ) -var backupRepository = &backups_core.BackupRepository{} - var taskCancelManager = task_cancellation.GetTaskCancelManager() var backupService = &BackupService{ databases.GetDatabaseService(), storages.GetStorageService(), - backupRepository, + backups_core.GetBackupRepository(), notifiers.GetNotifierService(), notifiers.GetNotifierService(), backups_config.GetBackupConfigService(), @@ -44,16 +41,21 @@ var backupService = &BackupService{ backuping.GetBackupCleaner(), } -var 
backupController = &BackupController{ - backupService: backupService, -} - func GetBackupService() *BackupService { return backupService } -func GetBackupController() *BackupController { - return backupController +var walService = &PostgreWalBackupService{ + backups_config.GetBackupConfigService(), + backups_core.GetBackupRepository(), + encryption.GetFieldEncryptor(), + encryption_secrets.GetSecretKeyService(), + logger.GetLogger(), + backupService, +} + +func GetWalService() *PostgreWalBackupService { + return walService } var ( diff --git a/backend/internal/features/backups/backups/services/postgres_wal_service.go b/backend/internal/features/backups/backups/services/postgres_wal_service.go new file mode 100644 index 0000000..b906254 --- /dev/null +++ b/backend/internal/features/backups/backups/services/postgres_wal_service.go @@ -0,0 +1,613 @@ +package backups_services + +import ( + "context" + "fmt" + "io" + "log/slog" + "time" + + backups_core "databasus-backend/internal/features/backups/backups/core" + backups_dto "databasus-backend/internal/features/backups/backups/dto" + backup_encryption "databasus-backend/internal/features/backups/backups/encryption" + backups_config "databasus-backend/internal/features/backups/config" + "databasus-backend/internal/features/databases" + "databasus-backend/internal/features/databases/databases/postgresql" + encryption_secrets "databasus-backend/internal/features/encryption/secrets" + util_encryption "databasus-backend/internal/util/encryption" + util_wal "databasus-backend/internal/util/wal" + + "github.com/google/uuid" +) + +// PostgreWalBackupService handles WAL segment and basebackup uploads from the databasus-cli agent. +type PostgreWalBackupService struct { + backupConfigService *backups_config.BackupConfigService + backupRepository *backups_core.BackupRepository + fieldEncryptor util_encryption.FieldEncryptor + secretKeyService *encryption_secrets.SecretKeyService + logger *slog.Logger + backupService *BackupService +} + +// UploadWal accepts a streaming WAL segment or basebackup upload from the agent. +// For WAL segments it validates the WAL chain before accepting. Returns an UploadGapResponse +// (409) when the chain is broken so the agent knows to trigger a full basebackup. +func (s *PostgreWalBackupService) UploadWal( + ctx context.Context, + database *databases.Database, + uploadType backups_core.PgWalUploadType, + walSegmentName string, + fullBackupWalStartSegment string, + fullBackupWalStopSegment string, + walSegmentSizeBytes int64, + body io.Reader, +) (*backups_dto.UploadGapResponse, error) { + if err := s.validateWalBackupType(database); err != nil { + return nil, err + } + + if uploadType == backups_core.PgWalUploadTypeBasebackup { + if fullBackupWalStartSegment == "" || fullBackupWalStopSegment == "" { + return nil, fmt.Errorf( + "fullBackupWalStartSegment and fullBackupWalStopSegment are required for basebackup uploads", + ) + } + } + + backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID) + if err != nil { + return nil, fmt.Errorf("failed to get backup config: %w", err) + } + + if backupConfig.Storage == nil { + return nil, fmt.Errorf("no storage configured for database %s", database.ID) + } + + if uploadType == backups_core.PgWalUploadTypeWal { + // Idempotency: check before chain validation so a successful re-upload is + // not misidentified as a gap. 
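+		// (PostgreSQL may retry archive_command and resend a segment that was
+		// already stored; returning nil, nil below reports success to the agent
+		// without rewriting the stored file.)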
+ existing, err := s.backupRepository.FindWalSegmentByName(database.ID, walSegmentName) + if err != nil { + return nil, fmt.Errorf("failed to check for duplicate WAL segment: %w", err) + } + + if existing != nil { + return nil, nil + } + + gapResp, err := s.validateWalChain(database.ID, walSegmentName, walSegmentSizeBytes) + if err != nil { + return nil, err + } + + if gapResp != nil { + return gapResp, nil + } + } + + backup := s.createBackupRecord( + database.ID, + backupConfig.Storage.ID, + uploadType, + database.Name, + walSegmentName, + fullBackupWalStartSegment, + fullBackupWalStopSegment, + backupConfig.Encryption, + ) + + if err := s.backupRepository.Save(backup); err != nil { + return nil, fmt.Errorf("failed to create backup record: %w", err) + } + + sizeBytes, streamErr := s.streamToStorage(ctx, backup, backupConfig, body) + if streamErr != nil { + errMsg := streamErr.Error() + s.markFailed(backup, errMsg) + + return nil, fmt.Errorf("upload failed: %w", streamErr) + } + + s.markCompleted(backup, sizeBytes) + + return nil, nil +} + +func (s *PostgreWalBackupService) GetRestorePlan( + database *databases.Database, + backupID *uuid.UUID, +) (*backups_dto.GetRestorePlanResponse, *backups_dto.GetRestorePlanErrorResponse, error) { + if err := s.validateWalBackupType(database); err != nil { + return nil, nil, err + } + + fullBackup, err := s.resolveFullBackup(database.ID, backupID) + if err != nil { + return nil, nil, err + } + + if fullBackup == nil { + msg := "no full backups available for this database" + if backupID != nil { + msg = fmt.Sprintf("full backup %s not found or not completed", backupID) + } + + return nil, &backups_dto.GetRestorePlanErrorResponse{ + Error: "no_backups", + Message: msg, + }, nil + } + + startSegment := "" + if fullBackup.PgFullBackupWalStartSegmentName != nil { + startSegment = *fullBackup.PgFullBackupWalStartSegmentName + } + + walSegments, err := s.backupRepository.FindCompletedWalSegmentsAfter(database.ID, startSegment) + if err != nil { + return nil, nil, fmt.Errorf("failed to query WAL segments: %w", err) + } + + chainErr := s.validateRestoreWalChain(fullBackup, walSegments) + if chainErr != nil { + return nil, chainErr, nil + } + + fullBackupSizeBytes := int64(fullBackup.BackupSizeMb * 1024 * 1024) + + pgVersion := "" + if fullBackup.PgVersion != nil { + pgVersion = *fullBackup.PgVersion + } + + stopSegment := "" + if fullBackup.PgFullBackupWalStopSegmentName != nil { + stopSegment = *fullBackup.PgFullBackupWalStopSegmentName + } + + response := &backups_dto.GetRestorePlanResponse{ + FullBackup: backups_dto.RestorePlanFullBackup{ + BackupID: fullBackup.ID, + FullBackupWalStartSegment: startSegment, + FullBackupWalStopSegment: stopSegment, + PgVersion: pgVersion, + CreatedAt: fullBackup.CreatedAt, + SizeBytes: fullBackupSizeBytes, + }, + TotalSizeBytes: fullBackupSizeBytes, + } + + for _, seg := range walSegments { + segName := "" + if seg.PgWalSegmentName != nil { + segName = *seg.PgWalSegmentName + } + + segSizeBytes := int64(seg.BackupSizeMb * 1024 * 1024) + + response.WalSegments = append(response.WalSegments, backups_dto.RestorePlanWalSegment{ + BackupID: seg.ID, + SegmentName: segName, + SizeBytes: segSizeBytes, + }) + + response.TotalSizeBytes += segSizeBytes + response.LatestAvailableSegment = segName + } + + return response, nil, nil +} + +// DownloadBackupFile returns a reader for a backup file belonging to the given database. +// Decryption is handled transparently if the backup is encrypted. 
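+// The caller owns the returned ReadCloser and must close it when done.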
+func (s *PostgreWalBackupService) DownloadBackupFile( + database *databases.Database, + backupID uuid.UUID, +) (io.ReadCloser, error) { + if err := s.validateWalBackupType(database); err != nil { + return nil, err + } + + backup, err := s.backupRepository.FindByID(backupID) + if err != nil { + return nil, fmt.Errorf("backup not found: %w", err) + } + + if backup.DatabaseID != database.ID { + return nil, fmt.Errorf("backup does not belong to this database") + } + + if backup.Status != backups_core.BackupStatusCompleted { + return nil, fmt.Errorf("backup is not completed") + } + + return s.backupService.GetBackupReader(backupID) +} + +func (s *PostgreWalBackupService) validateWalChain( + databaseID uuid.UUID, + incomingSegment string, + walSegmentSizeBytes int64, +) (*backups_dto.UploadGapResponse, error) { + fullBackup, err := s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(databaseID) + if err != nil { + return nil, fmt.Errorf("failed to query full backup: %w", err) + } + + // No full backup exists yet: cannot accept WAL segments without a chain anchor. + if fullBackup == nil || fullBackup.PgFullBackupWalStopSegmentName == nil { + return &backups_dto.UploadGapResponse{ + Error: "no_full_backup", + ExpectedSegmentName: "", + ReceivedSegmentName: incomingSegment, + }, nil + } + + stopSegment := *fullBackup.PgFullBackupWalStopSegmentName + + lastWal, err := s.backupRepository.FindLastWalSegmentAfter(databaseID, stopSegment) + if err != nil { + return nil, fmt.Errorf("failed to query last WAL segment: %w", err) + } + + walCalculator := util_wal.NewWalCalculator(walSegmentSizeBytes) + + var chainTail string + if lastWal != nil && lastWal.PgWalSegmentName != nil { + chainTail = *lastWal.PgWalSegmentName + } else { + chainTail = stopSegment + } + + expectedNext, err := walCalculator.NextSegment(chainTail) + if err != nil { + return nil, fmt.Errorf("WAL arithmetic failed for %q: %w", chainTail, err) + } + + if incomingSegment != expectedNext { + return &backups_dto.UploadGapResponse{ + Error: "gap_detected", + ExpectedSegmentName: expectedNext, + ReceivedSegmentName: incomingSegment, + }, nil + } + + return nil, nil +} + +func (s *PostgreWalBackupService) createBackupRecord( + databaseID uuid.UUID, + storageID uuid.UUID, + uploadType backups_core.PgWalUploadType, + dbName string, + walSegmentName string, + fullBackupWalStartSegment string, + fullBackupWalStopSegment string, + encryption backups_config.BackupEncryption, +) *backups_core.Backup { + now := time.Now().UTC() + + backup := &backups_core.Backup{ + ID: uuid.New(), + DatabaseID: databaseID, + StorageID: storageID, + Status: backups_core.BackupStatusInProgress, + Encryption: encryption, + CreatedAt: now, + } + + backup.GenerateFilename(dbName) + + if uploadType == backups_core.PgWalUploadTypeBasebackup { + walBackupType := backups_core.PgWalBackupTypeFullBackup + backup.PgWalBackupType = &walBackupType + + if fullBackupWalStartSegment != "" { + backup.PgFullBackupWalStartSegmentName = &fullBackupWalStartSegment + } + + if fullBackupWalStopSegment != "" { + backup.PgFullBackupWalStopSegmentName = &fullBackupWalStopSegment + } + } else { + walBackupType := backups_core.PgWalBackupTypeWalSegment + backup.PgWalBackupType = &walBackupType + backup.PgWalSegmentName = &walSegmentName + } + + return backup +} + +func (s *PostgreWalBackupService) streamToStorage( + ctx context.Context, + backup *backups_core.Backup, + backupConfig *backups_config.BackupConfig, + body io.Reader, +) (int64, error) { + if backupConfig.Encryption == 
backups_config.BackupEncryptionEncrypted {
+		return s.streamEncrypted(ctx, backup, backupConfig, body, backup.FileName)
+	}
+
+	return s.streamDirect(ctx, backupConfig, body, backup.FileName)
+}
+
+func (s *PostgreWalBackupService) streamDirect(
+	ctx context.Context,
+	backupConfig *backups_config.BackupConfig,
+	body io.Reader,
+	fileName string,
+) (int64, error) {
+	cr := &countingReader{r: body}
+
+	if err := backupConfig.Storage.SaveFile(ctx, s.fieldEncryptor, s.logger, fileName, cr); err != nil {
+		return 0, err
+	}
+
+	return cr.n, nil
+}
+
+func (s *PostgreWalBackupService) streamEncrypted(
+	ctx context.Context,
+	backup *backups_core.Backup,
+	backupConfig *backups_config.BackupConfig,
+	body io.Reader,
+	fileName string,
+) (int64, error) {
+	masterKey, err := s.secretKeyService.GetSecretKey()
+	if err != nil {
+		return 0, fmt.Errorf("failed to get master encryption key: %w", err)
+	}
+
+	pipeReader, pipeWriter := io.Pipe()
+
+	encryptionSetup, err := backup_encryption.SetupEncryptionWriter(
+		pipeWriter,
+		masterKey,
+		backup.ID,
+	)
+	if err != nil {
+		_ = pipeWriter.Close()
+		return 0, err
+	}
+
+	copyErrCh := make(chan error, 1)
+	go func() {
+		_, copyErr := io.Copy(encryptionSetup.Writer, body)
+		if copyErr != nil {
+			_ = encryptionSetup.Writer.Close()
+			_ = pipeWriter.CloseWithError(copyErr)
+			copyErrCh <- copyErr
+			return
+		}
+
+		if closeErr := encryptionSetup.Writer.Close(); closeErr != nil {
+			_ = pipeWriter.CloseWithError(closeErr)
+			copyErrCh <- closeErr
+			return
+		}
+
+		copyErrCh <- pipeWriter.Close()
+	}()
+
+	cr := &countingReader{r: pipeReader}
+	saveErr := backupConfig.Storage.SaveFile(ctx, s.fieldEncryptor, s.logger, fileName, cr)
+
+	if saveErr != nil {
+		// Unblock the copy goroutine: if SaveFile failed before draining the
+		// pipe, io.Copy would block forever on a pipe write and the receive
+		// below would deadlock.
+		_ = pipeReader.CloseWithError(saveErr)
+	}
+
+	copyErr := <-copyErrCh
+
+	if copyErr != nil {
+		return 0, copyErr
+	}
+
+	if saveErr != nil {
+		return 0, saveErr
+	}
+
+	backup.EncryptionSalt = &encryptionSetup.SaltBase64
+	backup.EncryptionIV = &encryptionSetup.NonceBase64
+
+	return cr.n, nil
+}
+
+func (s *PostgreWalBackupService) markCompleted(backup *backups_core.Backup, sizeBytes int64) {
+	backup.Status = backups_core.BackupStatusCompleted
+	backup.BackupSizeMb = float64(sizeBytes) / (1024 * 1024)
+
+	if err := s.backupRepository.Save(backup); err != nil {
+		s.logger.Error(
+			"failed to mark WAL backup as completed",
+			"backupId",
+			backup.ID,
+			"error",
+			err,
+		)
+	}
+}
+
+func (s *PostgreWalBackupService) markFailed(backup *backups_core.Backup, errMsg string) {
+	backup.Status = backups_core.BackupStatusFailed
+	backup.FailMessage = &errMsg
+
+	if err := s.backupRepository.Save(backup); err != nil {
+		s.logger.Error("failed to mark WAL backup as failed", "backupId", backup.ID, "error", err)
+	}
+}
+
+func (s *PostgreWalBackupService) GetNextFullBackupTime(
+	database *databases.Database,
+) (*backups_dto.GetNextFullBackupTimeResponse, error) {
+	if err := s.validateWalBackupType(database); err != nil {
+		return nil, err
+	}
+
+	backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get backup config: %w", err)
+	}
+
+	if backupConfig.BackupInterval == nil {
+		return nil, fmt.Errorf("no backup interval configured for database %s", database.ID)
+	}
+
+	lastFullBackup, err := s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(
+		database.ID,
+	)
+	if err != nil {
+		return nil, fmt.Errorf("failed to query last full backup: %w", err)
+	}
+
+	var lastBackupTime *time.Time
+	if lastFullBackup != nil {
+		lastBackupTime = &lastFullBackup.CreatedAt
+	}
+
+	now := time.Now().UTC()
+	nextTime
:= backupConfig.BackupInterval.NextTriggerTime(now, lastBackupTime) + + return &backups_dto.GetNextFullBackupTimeResponse{ + NextFullBackupTime: nextTime, + }, nil +} + +// ReportError creates a FAILED backup record with the agent's error message. +func (s *PostgreWalBackupService) ReportError( + database *databases.Database, + errorMsg string, +) error { + if err := s.validateWalBackupType(database); err != nil { + return err + } + + backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID) + if err != nil { + return fmt.Errorf("failed to get backup config: %w", err) + } + + if backupConfig.Storage == nil { + return fmt.Errorf("no storage configured for database %s", database.ID) + } + + now := time.Now().UTC() + backup := &backups_core.Backup{ + ID: uuid.New(), + DatabaseID: database.ID, + StorageID: backupConfig.Storage.ID, + Status: backups_core.BackupStatusFailed, + FailMessage: &errorMsg, + Encryption: backupConfig.Encryption, + CreatedAt: now, + } + + backup.GenerateFilename(database.Name) + + if err := s.backupRepository.Save(backup); err != nil { + return fmt.Errorf("failed to save error backup record: %w", err) + } + + return nil +} + +func (s *PostgreWalBackupService) resolveFullBackup( + databaseID uuid.UUID, + backupID *uuid.UUID, +) (*backups_core.Backup, error) { + if backupID != nil { + return s.backupRepository.FindCompletedFullWalBackupByID(databaseID, *backupID) + } + + return s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(databaseID) +} + +func (s *PostgreWalBackupService) validateRestoreWalChain( + fullBackup *backups_core.Backup, + walSegments []*backups_core.Backup, +) *backups_dto.GetRestorePlanErrorResponse { + if len(walSegments) == 0 { + return nil + } + + stopSegment := "" + if fullBackup.PgFullBackupWalStopSegmentName != nil { + stopSegment = *fullBackup.PgFullBackupWalStopSegmentName + } + + walCalculator := util_wal.NewWalCalculator(0) + expectedNext, err := walCalculator.NextSegment(stopSegment) + if err != nil { + return nil + } + + for _, seg := range walSegments { + segName := "" + if seg.PgWalSegmentName != nil { + segName = *seg.PgWalSegmentName + } + + cmp, cmpErr := walCalculator.Compare(segName, stopSegment) + if cmpErr != nil { + return nil + } + + // Skip segments that are <= stopSegment (they are part of the basebackup range) + if cmp <= 0 { + continue + } + + if segName != expectedNext { + lastContiguous := stopSegment + // Walk back to find the segment before the gap + for _, prev := range walSegments { + prevName := "" + if prev.PgWalSegmentName != nil { + prevName = *prev.PgWalSegmentName + } + + prevCmp, _ := walCalculator.Compare(prevName, stopSegment) + if prevCmp <= 0 { + continue + } + + if prevName == segName { + break + } + + lastContiguous = prevName + } + + return &backups_dto.GetRestorePlanErrorResponse{ + Error: "wal_chain_broken", + Message: fmt.Sprintf( + "WAL chain has a gap after segment %s. 
Recovery is only possible up to this segment.", + lastContiguous, + ), + LastContiguousSegment: lastContiguous, + } + } + + expectedNext, err = walCalculator.NextSegment(segName) + if err != nil { + return nil + } + } + + return nil +} + +func (s *PostgreWalBackupService) validateWalBackupType(database *databases.Database) error { + if database.Postgresql == nil || + database.Postgresql.BackupType != postgresql.PostgresBackupTypeWalV1 { + return fmt.Errorf("database %s is not configured for WAL backups", database.ID) + } + return nil +} + +type countingReader struct { + r io.Reader + n int64 +} + +func (cr *countingReader) Read(p []byte) (n int, err error) { + n, err = cr.r.Read(p) + cr.n += int64(n) + + return +} diff --git a/backend/internal/features/backups/backups/service.go b/backend/internal/features/backups/backups/services/service.go similarity index 96% rename from backend/internal/features/backups/backups/service.go rename to backend/internal/features/backups/backups/services/service.go index cfb46f5..1cd738e 100644 --- a/backend/internal/features/backups/backups/service.go +++ b/backend/internal/features/backups/backups/services/service.go @@ -1,4 +1,4 @@ -package backups +package backups_services import ( "encoding/base64" @@ -11,6 +11,7 @@ import ( "databasus-backend/internal/features/backups/backups/backuping" backups_core "databasus-backend/internal/features/backups/backups/core" backups_download "databasus-backend/internal/features/backups/backups/download" + backups_dto "databasus-backend/internal/features/backups/backups/dto" "databasus-backend/internal/features/backups/backups/encryption" backups_config "databasus-backend/internal/features/backups/config" "databasus-backend/internal/features/databases" @@ -108,7 +109,7 @@ func (s *BackupService) GetBackups( user *users_models.User, databaseID uuid.UUID, limit, offset int, -) (*GetBackupsResponse, error) { +) (*backups_dto.GetBackupsResponse, error) { database, err := s.databaseService.GetDatabaseByID(databaseID) if err != nil { return nil, err @@ -143,7 +144,7 @@ func (s *BackupService) GetBackups( return nil, err } - return &GetBackupsResponse{ + return &backups_dto.GetBackupsResponse{ Backups: backups, Total: total, Limit: limit, @@ -274,7 +275,7 @@ func (s *BackupService) GetBackupFile( database.WorkspaceID, ) - reader, err := s.getBackupReader(backupID) + reader, err := s.GetBackupReader(backupID) if err != nil { return nil, nil, nil, err } @@ -282,39 +283,9 @@ func (s *BackupService) GetBackupFile( return reader, backup, database, nil } -func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error { - dbBackupsInProgress, err := s.backupRepository.FindByDatabaseIdAndStatus( - databaseID, - backups_core.BackupStatusInProgress, - ) - if err != nil { - return err - } - - if len(dbBackupsInProgress) > 0 { - return errors.New("backup is in progress, storage cannot be removed") - } - - dbBackups, err := s.backupRepository.FindByDatabaseID( - databaseID, - ) - if err != nil { - return err - } - - for _, dbBackup := range dbBackups { - err := s.backupCleaner.DeleteBackup(dbBackup) - if err != nil { - return err - } - } - - return nil -} - -// GetBackupReader returns a reader for the backup file -// If encrypted, wraps with DecryptionReader -func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, error) { +// GetBackupReader returns a reader for the backup file. +// If encrypted, wraps with DecryptionReader. 
+func (s *BackupService) GetBackupReader(backupID uuid.UUID) (io.ReadCloser, error) { backup, err := s.backupRepository.FindByID(backupID) if err != nil { return nil, fmt.Errorf("failed to find backup: %w", err) @@ -394,7 +365,7 @@ func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, erro s.logger.Info("Returning encrypted backup with decryption", "backupId", backupID) - return &DecryptionReaderCloser{ + return &backups_dto.DecryptionReaderCloser{ DecryptionReader: decryptionReader, BaseReader: fileReader, }, nil @@ -465,7 +436,7 @@ func (s *BackupService) GetBackupFileWithoutAuth( return nil, nil, nil, err } - reader, err := s.getBackupReader(backupID) + reader, err := s.GetBackupReader(backupID) if err != nil { return nil, nil, nil, err } @@ -501,6 +472,36 @@ func (s *BackupService) UnregisterDownload(userID uuid.UUID) { s.downloadTokenService.UnregisterDownload(userID) } +func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error { + dbBackupsInProgress, err := s.backupRepository.FindByDatabaseIdAndStatus( + databaseID, + backups_core.BackupStatusInProgress, + ) + if err != nil { + return err + } + + if len(dbBackupsInProgress) > 0 { + return errors.New("backup is in progress, storage cannot be removed") + } + + dbBackups, err := s.backupRepository.FindByDatabaseID( + databaseID, + ) + if err != nil { + return err + } + + for _, dbBackup := range dbBackups { + err := s.backupCleaner.DeleteBackup(dbBackup) + if err != nil { + return err + } + } + + return nil +} + func (s *BackupService) generateBackupFilename( backup *backups_core.Backup, database *databases.Database, diff --git a/backend/internal/features/backups/backups/usecases/mariadb/create_backup_uc.go b/backend/internal/features/backups/backups/usecases/mariadb/create_backup_uc.go index 0914ac4..4c66dc1 100644 --- a/backend/internal/features/backups/backups/usecases/mariadb/create_backup_uc.go +++ b/backend/internal/features/backups/backups/usecases/mariadb/create_backup_uc.go @@ -2,7 +2,6 @@ package usecases_mariadb import ( "context" - "encoding/base64" "errors" "fmt" "io" @@ -437,40 +436,22 @@ func (uc *CreateMariadbBackupUsecase) setupBackupEncryption( return storageWriter, nil, metadata, nil } - salt, err := backup_encryption.GenerateSalt() - if err != nil { - return nil, nil, metadata, fmt.Errorf("failed to generate salt: %w", err) - } - - nonce, err := backup_encryption.GenerateNonce() - if err != nil { - return nil, nil, metadata, fmt.Errorf("failed to generate nonce: %w", err) - } - masterKey, err := uc.secretKeyService.GetSecretKey() if err != nil { return nil, nil, metadata, fmt.Errorf("failed to get master key: %w", err) } - encWriter, err := backup_encryption.NewEncryptionWriter( - storageWriter, - masterKey, - backupID, - salt, - nonce, - ) + encSetup, err := backup_encryption.SetupEncryptionWriter(storageWriter, masterKey, backupID) if err != nil { - return nil, nil, metadata, fmt.Errorf("failed to create encrypting writer: %w", err) + return nil, nil, metadata, err } - saltBase64 := base64.StdEncoding.EncodeToString(salt) - nonceBase64 := base64.StdEncoding.EncodeToString(nonce) - metadata.EncryptionSalt = &saltBase64 - metadata.EncryptionIV = &nonceBase64 + metadata.EncryptionSalt = &encSetup.SaltBase64 + metadata.EncryptionIV = &encSetup.NonceBase64 metadata.Encryption = backups_config.BackupEncryptionEncrypted uc.logger.Info("Encryption enabled for backup", "backupId", backupID) - return encWriter, encWriter, metadata, nil + return encSetup.Writer, encSetup.Writer, metadata, nil 
} func (uc *CreateMariadbBackupUsecase) cleanupOnCancellation( diff --git a/backend/internal/features/backups/backups/usecases/mongodb/create_backup_uc.go b/backend/internal/features/backups/backups/usecases/mongodb/create_backup_uc.go index a2787af..17ec0bd 100644 --- a/backend/internal/features/backups/backups/usecases/mongodb/create_backup_uc.go +++ b/backend/internal/features/backups/backups/usecases/mongodb/create_backup_uc.go @@ -2,7 +2,6 @@ package usecases_mongodb import ( "context" - "encoding/base64" "errors" "fmt" "io" @@ -277,41 +276,21 @@ func (uc *CreateMongodbBackupUsecase) setupBackupEncryption( return storageWriter, nil, backupMetadata, nil } - salt, err := backup_encryption.GenerateSalt() - if err != nil { - return nil, nil, backupMetadata, fmt.Errorf("failed to generate salt: %w", err) - } - - nonce, err := backup_encryption.GenerateNonce() - if err != nil { - return nil, nil, backupMetadata, fmt.Errorf("failed to generate nonce: %w", err) - } - masterKey, err := uc.secretKeyService.GetSecretKey() if err != nil { return nil, nil, backupMetadata, fmt.Errorf("failed to get master key: %w", err) } - encryptionWriter, err := backup_encryption.NewEncryptionWriter( - storageWriter, - masterKey, - backupID, - salt, - nonce, - ) + encSetup, err := backup_encryption.SetupEncryptionWriter(storageWriter, masterKey, backupID) if err != nil { - return nil, nil, backupMetadata, fmt.Errorf("failed to create encryption writer: %w", err) + return nil, nil, backupMetadata, err } - saltBase64 := base64.StdEncoding.EncodeToString(salt) - nonceBase64 := base64.StdEncoding.EncodeToString(nonce) - - backupMetadata.BackupID = backupID backupMetadata.Encryption = backups_config.BackupEncryptionEncrypted - backupMetadata.EncryptionSalt = &saltBase64 - backupMetadata.EncryptionIV = &nonceBase64 + backupMetadata.EncryptionSalt = &encSetup.SaltBase64 + backupMetadata.EncryptionIV = &encSetup.NonceBase64 - return encryptionWriter, encryptionWriter, backupMetadata, nil + return encSetup.Writer, encSetup.Writer, backupMetadata, nil } func (uc *CreateMongodbBackupUsecase) copyWithShutdownCheck( diff --git a/backend/internal/features/backups/backups/usecases/mysql/create_backup_uc.go b/backend/internal/features/backups/backups/usecases/mysql/create_backup_uc.go index 7350b69..8253e92 100644 --- a/backend/internal/features/backups/backups/usecases/mysql/create_backup_uc.go +++ b/backend/internal/features/backups/backups/usecases/mysql/create_backup_uc.go @@ -2,7 +2,6 @@ package usecases_mysql import ( "context" - "encoding/base64" "errors" "fmt" "io" @@ -448,40 +447,22 @@ func (uc *CreateMysqlBackupUsecase) setupBackupEncryption( return storageWriter, nil, metadata, nil } - salt, err := backup_encryption.GenerateSalt() - if err != nil { - return nil, nil, metadata, fmt.Errorf("failed to generate salt: %w", err) - } - - nonce, err := backup_encryption.GenerateNonce() - if err != nil { - return nil, nil, metadata, fmt.Errorf("failed to generate nonce: %w", err) - } - masterKey, err := uc.secretKeyService.GetSecretKey() if err != nil { return nil, nil, metadata, fmt.Errorf("failed to get master key: %w", err) } - encWriter, err := backup_encryption.NewEncryptionWriter( - storageWriter, - masterKey, - backupID, - salt, - nonce, - ) + encSetup, err := backup_encryption.SetupEncryptionWriter(storageWriter, masterKey, backupID) if err != nil { - return nil, nil, metadata, fmt.Errorf("failed to create encrypting writer: %w", err) + return nil, nil, metadata, err } - saltBase64 := 
base64.StdEncoding.EncodeToString(salt) - nonceBase64 := base64.StdEncoding.EncodeToString(nonce) - metadata.EncryptionSalt = &saltBase64 - metadata.EncryptionIV = &nonceBase64 + metadata.EncryptionSalt = &encSetup.SaltBase64 + metadata.EncryptionIV = &encSetup.NonceBase64 metadata.Encryption = backups_config.BackupEncryptionEncrypted uc.logger.Info("Encryption enabled for backup", "backupId", backupID) - return encWriter, encWriter, metadata, nil + return encSetup.Writer, encSetup.Writer, metadata, nil } func (uc *CreateMysqlBackupUsecase) cleanupOnCancellation( diff --git a/backend/internal/features/backups/backups/usecases/postgresql/create_backup_uc.go b/backend/internal/features/backups/backups/usecases/postgresql/create_backup_uc.go index be1dc7a..b656478 100644 --- a/backend/internal/features/backups/backups/usecases/postgresql/create_backup_uc.go +++ b/backend/internal/features/backups/backups/usecases/postgresql/create_backup_uc.go @@ -2,7 +2,6 @@ package usecases_postgresql import ( "context" - "encoding/base64" "errors" "fmt" "io" @@ -492,40 +491,22 @@ func (uc *CreatePostgresqlBackupUsecase) setupBackupEncryption( return storageWriter, nil, metadata, nil } - salt, err := backup_encryption.GenerateSalt() - if err != nil { - return nil, nil, metadata, fmt.Errorf("failed to generate salt: %w", err) - } - - nonce, err := backup_encryption.GenerateNonce() - if err != nil { - return nil, nil, metadata, fmt.Errorf("failed to generate nonce: %w", err) - } - masterKey, err := uc.secretKeyService.GetSecretKey() if err != nil { return nil, nil, metadata, fmt.Errorf("failed to get master key: %w", err) } - encWriter, err := backup_encryption.NewEncryptionWriter( - storageWriter, - masterKey, - backupID, - salt, - nonce, - ) + encSetup, err := backup_encryption.SetupEncryptionWriter(storageWriter, masterKey, backupID) if err != nil { - return nil, nil, metadata, fmt.Errorf("failed to create encrypting writer: %w", err) + return nil, nil, metadata, err } - saltBase64 := base64.StdEncoding.EncodeToString(salt) - nonceBase64 := base64.StdEncoding.EncodeToString(nonce) - metadata.EncryptionSalt = &saltBase64 - metadata.EncryptionIV = &nonceBase64 + metadata.EncryptionSalt = &encSetup.SaltBase64 + metadata.EncryptionIV = &encSetup.NonceBase64 metadata.Encryption = backups_config.BackupEncryptionEncrypted uc.logger.Info("Encryption enabled for backup", "backupId", backupID) - return encWriter, encWriter, metadata, nil + return encSetup.Writer, encSetup.Writer, metadata, nil } func (uc *CreatePostgresqlBackupUsecase) cleanupOnCancellation( diff --git a/backend/internal/features/databases/controller.go b/backend/internal/features/databases/controller.go index d200809..ec11bf8 100644 --- a/backend/internal/features/databases/controller.go +++ b/backend/internal/features/databases/controller.go @@ -29,6 +29,11 @@ func (c *DatabaseController) RegisterRoutes(router *gin.RouterGroup) { router.GET("/databases/notifier/:id/databases-count", c.CountDatabasesByNotifier) router.POST("/databases/is-readonly", c.IsUserReadOnly) router.POST("/databases/create-readonly-user", c.CreateReadOnlyUser) + router.POST("/databases/:id/regenerate-token", c.RegenerateAgentToken) +} + +func (c *DatabaseController) RegisterPublicRoutes(router *gin.RouterGroup) { + router.POST("/databases/verify-token", c.VerifyAgentToken) } // CreateDatabase @@ -438,3 +443,61 @@ func (c *DatabaseController) CreateReadOnlyUser(ctx *gin.Context) { Password: password, }) } + +// RegenerateAgentToken +// @Summary Regenerate agent token for a 
database +// @Description Generate a new agent token for the database. The token is returned once and stored as a hash. +// @Tags databases +// @Produce json +// @Security BearerAuth +// @Param id path string true "Database ID" +// @Success 200 {object} map[string]string +// @Failure 400 {object} map[string]string +// @Failure 401 {object} map[string]string +// @Router /databases/{id}/regenerate-token [post] +func (c *DatabaseController) RegenerateAgentToken(ctx *gin.Context) { + user, ok := users_middleware.GetUserFromContext(ctx) + if !ok { + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"}) + return + } + + id, err := uuid.Parse(ctx.Param("id")) + if err != nil { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"}) + return + } + + token, err := c.databaseService.RegenerateAgentToken(user, id) + if err != nil { + ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + ctx.JSON(http.StatusOK, gin.H{"token": token}) +} + +// VerifyAgentToken +// @Summary Verify agent token +// @Description Verify that a given agent token is valid for any database +// @Tags databases +// @Accept json +// @Produce json +// @Param request body VerifyAgentTokenRequest true "Token to verify" +// @Success 200 {object} map[string]string +// @Failure 401 {object} map[string]string +// @Router /databases/verify-token [post] +func (c *DatabaseController) VerifyAgentToken(ctx *gin.Context) { + var request VerifyAgentTokenRequest + if err := ctx.ShouldBindJSON(&request); err != nil { + ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if err := c.databaseService.VerifyAgentToken(request.Token); err != nil { + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"}) + return + } + + ctx.JSON(http.StatusOK, gin.H{"message": "token is valid"}) +} diff --git a/backend/internal/features/databases/controller_test.go b/backend/internal/features/databases/controller_test.go index b17dd1d..214cba3 100644 --- a/backend/internal/features/databases/controller_test.go +++ b/backend/internal/features/databases/controller_test.go @@ -13,10 +13,13 @@ import ( "github.com/stretchr/testify/assert" "databasus-backend/internal/config" + "databasus-backend/internal/features/audit_logs" "databasus-backend/internal/features/databases/databases/mariadb" "databasus-backend/internal/features/databases/databases/mongodb" "databasus-backend/internal/features/databases/databases/postgresql" users_enums "databasus-backend/internal/features/users/enums" + users_middleware "databasus-backend/internal/features/users/middleware" + users_services "databasus-backend/internal/features/users/services" users_testing "databasus-backend/internal/features/users/testing" workspaces_controllers "databasus-backend/internal/features/workspaces/controllers" workspaces_testing "databasus-backend/internal/features/workspaces/testing" @@ -144,6 +147,66 @@ func Test_CreateDatabase_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testin assert.Contains(t, string(testResp.Body), "insufficient permissions") } +func Test_CreateDatabase_WalV1Type_NoConnectionFieldsRequired(t *testing.T) { + router := createTestRouter() + owner := users_testing.CreateTestUser(users_enums.UserRoleMember) + workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router) + defer workspaces_testing.RemoveTestWorkspace(workspace, router) + + request := Database{ + Name: "Test WAL Database", + WorkspaceID: &workspace.ID, + Type: DatabaseTypePostgres, + Postgresql: 
&postgresql.PostgresqlDatabase{ + BackupType: postgresql.PostgresBackupTypeWalV1, + CpuCount: 1, + }, + } + + var response Database + test_utils.MakePostRequestAndUnmarshal( + t, + router, + "/api/v1/databases/create", + "Bearer "+owner.Token, + request, + http.StatusCreated, + &response, + ) + defer RemoveTestDatabase(&response) + + assert.Equal(t, "Test WAL Database", response.Name) + assert.NotEqual(t, uuid.Nil, response.ID) +} + +func Test_CreateDatabase_PgDumpType_ConnectionFieldsRequired(t *testing.T) { + router := createTestRouter() + owner := users_testing.CreateTestUser(users_enums.UserRoleMember) + workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router) + defer workspaces_testing.RemoveTestWorkspace(workspace, router) + + request := Database{ + Name: "Test PG_DUMP Database", + WorkspaceID: &workspace.ID, + Type: DatabaseTypePostgres, + Postgresql: &postgresql.PostgresqlDatabase{ + BackupType: postgresql.PostgresBackupTypePgDump, + CpuCount: 1, + }, + } + + testResp := test_utils.MakePostRequest( + t, + router, + "/api/v1/databases/create", + "Bearer "+owner.Token, + request, + http.StatusBadRequest, + ) + + assert.Contains(t, string(testResp.Body), "host is required") +} + func Test_UpdateDatabase_PermissionsEnforced(t *testing.T) { tests := []struct { name string @@ -256,6 +319,52 @@ func Test_UpdateDatabase_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testin assert.Contains(t, string(testResp.Body), "insufficient permissions") } +func Test_UpdateDatabase_WhenDatabaseTypeChanged_ReturnsBadRequest(t *testing.T) { + router := createTestRouter() + owner := users_testing.CreateTestUser(users_enums.UserRoleMember) + workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router) + defer workspaces_testing.RemoveTestWorkspace(workspace, router) + + database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router) + defer RemoveTestDatabase(database) + + database.Type = DatabaseTypeMysql + + testResp := test_utils.MakePostRequest( + t, + router, + "/api/v1/databases/update", + "Bearer "+owner.Token, + database, + http.StatusBadRequest, + ) + + assert.Contains(t, string(testResp.Body), "database type cannot be changed") +} + +func Test_UpdateDatabase_WhenBackupTypeChanged_ReturnsBadRequest(t *testing.T) { + router := createTestRouter() + owner := users_testing.CreateTestUser(users_enums.UserRoleMember) + workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router) + defer workspaces_testing.RemoveTestWorkspace(workspace, router) + + database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router) + defer RemoveTestDatabase(database) + + database.Postgresql.BackupType = postgresql.PostgresBackupTypeWalV1 + + testResp := test_utils.MakePostRequest( + t, + router, + "/api/v1/databases/update", + "Bearer "+owner.Token, + database, + http.StatusBadRequest, + ) + + assert.Contains(t, string(testResp.Body), "backup type cannot be changed") +} + func Test_DeleteDatabase_PermissionsEnforced(t *testing.T) { tests := []struct { name string @@ -1050,6 +1159,87 @@ func Test_TestConnection_PermissionsEnforced(t *testing.T) { } } +func Test_RegenerateAgentToken_ReturnsToken(t *testing.T) { + router := createTestRouter() + owner := users_testing.CreateTestUser(users_enums.UserRoleMember) + workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router) + defer workspaces_testing.RemoveTestWorkspace(workspace, router) + + database := 
createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router) + defer RemoveTestDatabase(database) + + var response map[string]string + test_utils.MakePostRequestAndUnmarshal( + t, + router, + "/api/v1/databases/"+database.ID.String()+"/regenerate-token", + "Bearer "+owner.Token, + nil, + http.StatusOK, + &response, + ) + + assert.NotEmpty(t, response["token"]) + assert.Len(t, response["token"], 32) + + var updatedDatabase Database + test_utils.MakeGetRequestAndUnmarshal( + t, + router, + "/api/v1/databases/"+database.ID.String(), + "Bearer "+owner.Token, + http.StatusOK, + &updatedDatabase, + ) + assert.True(t, updatedDatabase.IsAgentTokenGenerated) +} + +func Test_VerifyAgentToken_WithValidToken_Succeeds(t *testing.T) { + router := createTestRouter() + owner := users_testing.CreateTestUser(users_enums.UserRoleMember) + workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router) + defer workspaces_testing.RemoveTestWorkspace(workspace, router) + + database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router) + defer RemoveTestDatabase(database) + + var regenerateResponse map[string]string + test_utils.MakePostRequestAndUnmarshal( + t, + router, + "/api/v1/databases/"+database.ID.String()+"/regenerate-token", + "Bearer "+owner.Token, + nil, + http.StatusOK, + ®enerateResponse, + ) + + token := regenerateResponse["token"] + assert.NotEmpty(t, token) + + w := workspaces_testing.MakeAPIRequest( + router, + "POST", + "/api/v1/databases/verify-token", + "", + VerifyAgentTokenRequest{Token: token}, + ) + assert.Equal(t, http.StatusOK, w.Code) +} + +func Test_VerifyAgentToken_WithInvalidToken_Returns401(t *testing.T) { + router := createTestRouter() + + w := workspaces_testing.MakeAPIRequest( + router, + "POST", + "/api/v1/databases/verify-token", + "", + VerifyAgentTokenRequest{Token: "invalidtoken00000000000000000000"}, + ) + assert.Equal(t, http.StatusUnauthorized, w.Code) +} + func createTestDatabaseViaAPI( name string, workspaceID uuid.UUID, @@ -1101,11 +1291,20 @@ func createTestDatabaseViaAPI( } func createTestRouter() *gin.Engine { - router := workspaces_testing.CreateTestRouter( - workspaces_controllers.GetWorkspaceController(), - workspaces_controllers.GetMembershipController(), - GetDatabaseController(), - ) + gin.SetMode(gin.TestMode) + router := gin.New() + + v1 := router.Group("/api/v1") + protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService())) + + workspaces_controllers.GetWorkspaceController().RegisterRoutes(protected.(*gin.RouterGroup)) + workspaces_controllers.GetMembershipController().RegisterRoutes(protected.(*gin.RouterGroup)) + GetDatabaseController().RegisterRoutes(protected.(*gin.RouterGroup)) + + GetDatabaseController().RegisterPublicRoutes(v1) + + audit_logs.SetupDependencies() + return router } @@ -1118,13 +1317,14 @@ func getTestPostgresConfig() *postgresql.PostgresqlDatabase { testDbName := "testdb" return &postgresql.PostgresqlDatabase{ - Version: tools.PostgresqlVersion16, - Host: config.GetEnv().TestLocalhost, - Port: port, - Username: "testuser", - Password: "testpassword", - Database: &testDbName, - CpuCount: 1, + BackupType: postgresql.PostgresBackupTypePgDump, + Version: tools.PostgresqlVersion16, + Host: config.GetEnv().TestLocalhost, + Port: port, + Username: "testuser", + Password: "testpassword", + Database: &testDbName, + CpuCount: 1, } } diff --git a/backend/internal/features/databases/databases/postgresql/model.go 
b/backend/internal/features/databases/databases/postgresql/model.go index e2a9fa5..7d8a0ba 100644 --- a/backend/internal/features/databases/databases/postgresql/model.go +++ b/backend/internal/features/databases/databases/postgresql/model.go @@ -2,6 +2,7 @@ package postgresql import ( "context" + "databasus-backend/internal/config" "databasus-backend/internal/util/encryption" "databasus-backend/internal/util/tools" "errors" @@ -17,6 +18,13 @@ import ( "gorm.io/gorm" ) +type PostgresBackupType string + +const ( + PostgresBackupTypePgDump PostgresBackupType = "PG_DUMP" + PostgresBackupTypeWalV1 PostgresBackupType = "WAL_V1" +) + type PostgresqlDatabase struct { ID uuid.UUID `json:"id" gorm:"primaryKey;type:uuid;default:gen_random_uuid()"` @@ -24,11 +32,13 @@ type PostgresqlDatabase struct { Version tools.PostgresqlVersion `json:"version" gorm:"type:text;not null"` - // connection data - Host string `json:"host" gorm:"type:text;not null"` - Port int `json:"port" gorm:"type:int;not null"` - Username string `json:"username" gorm:"type:text;not null"` - Password string `json:"password" gorm:"type:text;not null"` + BackupType PostgresBackupType `json:"backupType" gorm:"column:backup_type;type:text;not null;default:'PG_DUMP'"` + + // connection data — required for PG_DUMP, optional for WAL_V1 + Host string `json:"host" gorm:"type:text"` + Port int `json:"port" gorm:"type:int"` + Username string `json:"username" gorm:"type:text"` + Password string `json:"password" gorm:"type:text"` Database *string `json:"database" gorm:"type:text"` IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"` @@ -66,20 +76,30 @@ func (p *PostgresqlDatabase) AfterFind(_ *gorm.DB) error { } func (p *PostgresqlDatabase) Validate() error { - if p.Host == "" { - return errors.New("host is required") + if p.BackupType == "" { + p.BackupType = PostgresBackupTypePgDump } - if p.Port == 0 { - return errors.New("port is required") + if p.BackupType == PostgresBackupTypePgDump && config.GetEnv().IsCloud { + return errors.New("PG_DUMP backup type is not supported in cloud mode") } - if p.Username == "" { - return errors.New("username is required") - } + if p.BackupType == PostgresBackupTypePgDump { + if p.Host == "" { + return errors.New("host is required") + } - if p.Password == "" { - return errors.New("password is required") + if p.Port == 0 { + return errors.New("port is required") + } + + if p.Username == "" { + return errors.New("username is required") + } + + if p.Password == "" { + return errors.New("password is required") + } } if p.CpuCount <= 0 { @@ -90,7 +110,7 @@ func (p *PostgresqlDatabase) Validate() error { // Databasus runs an internal PostgreSQL instance that should not be backed up through the UI // because it would expose internal metadata to non-system administrators. 
// To properly backup Databasus, see: https://databasus.com/faq#backup-databasus - if p.Database != nil && *p.Database != "" { + if p.BackupType == PostgresBackupTypePgDump && p.Database != nil && *p.Database != "" { localhostHosts := []string{ "localhost", "127.0.0.1", @@ -130,6 +150,10 @@ func (p *PostgresqlDatabase) TestConnection( encryptor encryption.FieldEncryptor, databaseID uuid.UUID, ) error { + if p.BackupType == PostgresBackupTypeWalV1 { + return errors.New("test connection is not supported for WAL backup type") + } + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() @@ -144,7 +168,21 @@ func (p *PostgresqlDatabase) HideSensitiveData() { p.Password = "" } +func (p *PostgresqlDatabase) ValidateUpdate(old *PostgresqlDatabase) error { + // BackupType cannot be changed after creation — the full backup structure + // (WAL hierarchy, storage files, cleanup logic) is built around + // the type chosen at creation time. Automatically migrating this state is + // error-prone; it is safer for the user to create a new database and + // remove the old one. + if old.BackupType != p.BackupType { + return errors.New("backup type cannot be changed; create a new database instead") + } + + return nil +} + func (p *PostgresqlDatabase) Update(incoming *PostgresqlDatabase) { + p.BackupType = incoming.BackupType p.Version = incoming.Version p.Host = incoming.Host p.Port = incoming.Port @@ -181,6 +219,10 @@ func (p *PostgresqlDatabase) PopulateDbData( encryptor encryption.FieldEncryptor, databaseID uuid.UUID, ) error { + if p.BackupType == PostgresBackupTypeWalV1 { + return nil + } + return p.PopulateVersion(logger, encryptor, databaseID) } @@ -243,6 +285,10 @@ func (p *PostgresqlDatabase) IsUserReadOnly( encryptor encryption.FieldEncryptor, databaseID uuid.UUID, ) (bool, []string, error) { + if p.BackupType == PostgresBackupTypeWalV1 { + return false, nil, errors.New("read-only check is not supported for WAL backup type") + } + password, err := decryptPasswordIfNeeded(p.Password, encryptor, databaseID) if err != nil { return false, nil, fmt.Errorf("failed to decrypt password: %w", err) @@ -415,6 +461,10 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser( encryptor encryption.FieldEncryptor, databaseID uuid.UUID, ) (string, string, error) { + if p.BackupType == PostgresBackupTypeWalV1 { + return "", "", errors.New("read-only user creation is not supported for WAL backup type") + } + password, err := decryptPasswordIfNeeded(p.Password, encryptor, databaseID) if err != nil { return "", "", fmt.Errorf("failed to decrypt password: %w", err) diff --git a/backend/internal/features/databases/dto.go b/backend/internal/features/databases/dto.go index 63b5a16..38ae46c 100644 --- a/backend/internal/features/databases/dto.go +++ b/backend/internal/features/databases/dto.go @@ -9,3 +9,7 @@ type IsReadOnlyResponse struct { IsReadOnly bool `json:"isReadOnly"` Privileges []string `json:"privileges"` } + +type VerifyAgentTokenRequest struct { + Token string `json:"token" binding:"required"` +} diff --git a/backend/internal/features/databases/model.go b/backend/internal/features/databases/model.go index e325ce4..7306b59 100644 --- a/backend/internal/features/databases/model.go +++ b/backend/internal/features/databases/model.go @@ -37,6 +37,9 @@ type Database struct { LastBackupErrorMessage *string `json:"lastBackupErrorMessage,omitempty" gorm:"column:last_backup_error_message;type:text"` HealthStatus *HealthStatus `json:"healthStatus" gorm:"column:health_status;type:text;not null"` 
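+
+	// AgentToken stores only the SHA-256 hex digest of the issued token; the
+	// plaintext is returned once by RegenerateAgentToken and is never persisted.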
+ + AgentToken *string `json:"-" gorm:"column:agent_token;type:text"` + IsAgentTokenGenerated bool `json:"isAgentTokenGenerated" gorm:"column:is_agent_token_generated;not null;default:false"` } func (d *Database) Validate() error { @@ -71,8 +74,19 @@ func (d *Database) Validate() error { } func (d *Database) ValidateUpdate(old, new Database) error { + // Database type cannot be changed after creation — the entire backup + // structure (storage files, schedulers, WAL hierarchy, etc.) is tied to + // the type at creation time. Recreating that state automatically is + // error-prone; it is safer for the user to create a new database and + // remove the old one. if old.Type != new.Type { - return errors.New("database type is not allowed to change") + return errors.New("database type cannot be changed; create a new database instead") + } + + if old.Type == DatabaseTypePostgres && old.Postgresql != nil && new.Postgresql != nil { + if err := new.Postgresql.ValidateUpdate(old.Postgresql); err != nil { + return err + } } return nil diff --git a/backend/internal/features/databases/repository.go b/backend/internal/features/databases/repository.go index c5dc8a7..1277c05 100644 --- a/backend/internal/features/databases/repository.go +++ b/backend/internal/features/databases/repository.go @@ -244,6 +244,18 @@ func (r *DatabaseRepository) GetAllDatabases() ([]*Database, error) { return databases, nil } +func (r *DatabaseRepository) FindByAgentTokenHash(hash string) (*Database, error) { + var database Database + + if err := storage.GetDb(). + Where("agent_token = ?", hash). + First(&database).Error; err != nil { + return nil, err + } + + return &database, nil +} + func (r *DatabaseRepository) GetDatabasesIDsByNotifierID( notifierID uuid.UUID, ) ([]uuid.UUID, error) { diff --git a/backend/internal/features/databases/service.go b/backend/internal/features/databases/service.go index 9392bac..6ad99aa 100644 --- a/backend/internal/features/databases/service.go +++ b/backend/internal/features/databases/service.go @@ -2,9 +2,11 @@ package databases import ( "context" + "crypto/sha256" "errors" "fmt" "log/slog" + "strings" "time" "databasus-backend/internal/config" @@ -87,21 +89,8 @@ func (s *DatabaseService) CreateDatabase( return nil, fmt.Errorf("failed to auto-detect database data: %w", err) } - if config.GetEnv().IsCloud { - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) - defer cancel() - - isReadOnly, permissions, err := database.IsUserReadOnly(ctx, s.logger, s.fieldEncryptor) - if err != nil { - return nil, fmt.Errorf("failed to verify user permissions: %w", err) - } - - if !isReadOnly { - return nil, fmt.Errorf( - "in cloud mode, only read-only database users are allowed (user has permissions: %v)", - permissions, - ) - } + if err := s.verifyReadOnlyUserIfNeeded(database); err != nil { + return nil, err } if err := database.EncryptSensitiveFields(s.fieldEncryptor); err != nil { @@ -171,25 +160,8 @@ func (s *DatabaseService) UpdateDatabase( return fmt.Errorf("failed to auto-detect database data: %w", err) } - if config.GetEnv().IsCloud { - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) - defer cancel() - - isReadOnly, permissions, err := existingDatabase.IsUserReadOnly( - ctx, - s.logger, - s.fieldEncryptor, - ) - if err != nil { - return fmt.Errorf("failed to verify user permissions: %w", err) - } - - if !isReadOnly { - return fmt.Errorf( - "in cloud mode, only read-only database users are allowed (user has permissions: %v)", - permissions, - ) - } + if 
err := s.verifyReadOnlyUserIfNeeded(existingDatabase); err != nil { + return err } oldName := existingDatabase.Name @@ -485,6 +457,7 @@ func (s *DatabaseService) CopyDatabase( newDatabase.Postgresql = &postgresql.PostgresqlDatabase{ ID: uuid.Nil, DatabaseID: nil, + BackupType: existingDatabase.Postgresql.BackupType, Version: existingDatabase.Postgresql.Version, Host: existingDatabase.Postgresql.Host, Port: existingDatabase.Postgresql.Port, @@ -638,6 +611,71 @@ func (s *DatabaseService) SetHealthStatus( return nil } +func (s *DatabaseService) RegenerateAgentToken( + user *users_models.User, + databaseID uuid.UUID, +) (string, error) { + database, err := s.dbRepository.FindByID(databaseID) + if err != nil { + return "", err + } + + if database.WorkspaceID == nil { + return "", errors.New("cannot regenerate token for database without workspace") + } + + canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user) + if err != nil { + return "", err + } + if !canManage { + return "", errors.New( + "insufficient permissions to regenerate agent token for this database", + ) + } + + plainToken := strings.ReplaceAll(uuid.New().String(), "-", "") + tokenHash := hashAgentToken(plainToken) + + database.AgentToken = &tokenHash + database.IsAgentTokenGenerated = true + + _, err = s.dbRepository.Save(database) + if err != nil { + return "", err + } + + s.auditLogService.WriteAuditLog( + fmt.Sprintf("Agent token regenerated for database: %s", database.Name), + &user.ID, + database.WorkspaceID, + ) + + return plainToken, nil +} + +func (s *DatabaseService) VerifyAgentToken(token string) error { + hash := hashAgentToken(token) + + _, err := s.dbRepository.FindByAgentTokenHash(hash) + if err != nil { + return errors.New("invalid token") + } + + return nil +} + +func (s *DatabaseService) GetDatabaseByAgentToken(token string) (*Database, error) { + hash := hashAgentToken(token) + + partial, err := s.dbRepository.FindByAgentTokenHash(hash) + if err != nil { + return nil, errors.New("invalid agent token") + } + + return s.dbRepository.FindByID(partial.ID) +} + func (s *DatabaseService) OnBeforeWorkspaceDeletion(workspaceID uuid.UUID) error { databases, err := s.dbRepository.FindByWorkspaceID(workspaceID) if err != nil { @@ -809,3 +847,36 @@ func (s *DatabaseService) CreateReadOnlyUser( return username, password, nil } + +func (s *DatabaseService) verifyReadOnlyUserIfNeeded(database *Database) error { + if !config.GetEnv().IsCloud { + return nil + } + + if database.Postgresql != nil && + database.Postgresql.BackupType == postgresql.PostgresBackupTypeWalV1 { + return nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + isReadOnly, permissions, err := database.IsUserReadOnly(ctx, s.logger, s.fieldEncryptor) + if err != nil { + return fmt.Errorf("failed to verify user permissions: %w", err) + } + + if !isReadOnly { + return fmt.Errorf( + "in cloud mode, only read-only database users are allowed (user has permissions: %v)", + permissions, + ) + } + + return nil +} + +func hashAgentToken(token string) string { + hash := sha256.Sum256([]byte(token)) + return fmt.Sprintf("%x", hash) +} diff --git a/backend/internal/features/intervals/model.go b/backend/internal/features/intervals/model.go index 921aa96..d852e17 100644 --- a/backend/internal/features/intervals/model.go +++ b/backend/internal/features/intervals/model.go @@ -79,6 +79,38 @@ func (i *Interval) ShouldTriggerBackup(now time.Time, lastBackupTime *time.Time) } } +// NextTriggerTime 
computes the next time a backup should trigger based on the interval and last backup time. +// Returns nil when a backup is due immediately (no previous backup exists). +func (i *Interval) NextTriggerTime(now time.Time, lastBackupTime *time.Time) *time.Time { + if lastBackupTime == nil { + return nil + } + + switch i.Interval { + case IntervalHourly: + next := lastBackupTime.Add(time.Hour) + return &next + + case IntervalDaily: + next := i.nextDailyTrigger(now) + return &next + + case IntervalWeekly: + next := i.nextWeeklyTrigger(now) + return &next + + case IntervalMonthly: + next := i.nextMonthlyTrigger(now) + return &next + + case IntervalCron: + return i.nextCronTrigger(*lastBackupTime) + + default: + return nil + } +} + func (i *Interval) Copy() *Interval { return &Interval{ ID: uuid.Nil, @@ -240,6 +272,99 @@ func (i *Interval) shouldTriggerCron(now, lastBackup time.Time) bool { return now.After(nextAfterLastBackup) || now.Equal(nextAfterLastBackup) } +func (i *Interval) nextDailyTrigger(now time.Time) time.Time { + t, err := time.Parse("15:04", *i.TimeOfDay) + if err != nil { + return now + } + + todaySlot := time.Date( + now.Year(), now.Month(), now.Day(), + t.Hour(), t.Minute(), 0, 0, now.Location(), + ) + + if now.Before(todaySlot) { + return todaySlot + } + + return todaySlot.AddDate(0, 0, 1) +} + +func (i *Interval) nextWeeklyTrigger(now time.Time) time.Time { + targetWd := time.Weekday(0) + if i.Weekday != nil { + targetWd = time.Weekday(*i.Weekday) + } + + startOfWeek := getStartOfWeek(now) + + var daysFromMonday int + if targetWd == time.Sunday { + daysFromMonday = 6 + } else { + daysFromMonday = int(targetWd) - 1 + } + + targetThisWeek := startOfWeek.AddDate(0, 0, daysFromMonday) + + if i.TimeOfDay != nil { + t, err := time.Parse("15:04", *i.TimeOfDay) + if err == nil { + targetThisWeek = time.Date( + targetThisWeek.Year(), targetThisWeek.Month(), targetThisWeek.Day(), + t.Hour(), t.Minute(), 0, 0, targetThisWeek.Location(), + ) + } + } + + if now.Before(targetThisWeek) { + return targetThisWeek + } + + return targetThisWeek.AddDate(0, 0, 7) +} + +func (i *Interval) nextMonthlyTrigger(now time.Time) time.Time { + day := 1 + if i.DayOfMonth != nil { + day = *i.DayOfMonth + } + + targetThisMonth := time.Date(now.Year(), now.Month(), day, 0, 0, 0, 0, now.Location()) + + if i.TimeOfDay != nil { + t, err := time.Parse("15:04", *i.TimeOfDay) + if err == nil { + targetThisMonth = time.Date( + targetThisMonth.Year(), targetThisMonth.Month(), targetThisMonth.Day(), + t.Hour(), t.Minute(), 0, 0, targetThisMonth.Location(), + ) + } + } + + if now.Before(targetThisMonth) { + return targetThisMonth + } + + return targetThisMonth.AddDate(0, 1, 0) +} + +func (i *Interval) nextCronTrigger(lastBackup time.Time) *time.Time { + if i.CronExpression == nil || *i.CronExpression == "" { + return nil + } + + parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow) + schedule, err := parser.Parse(*i.CronExpression) + if err != nil { + return nil + } + + next := schedule.Next(lastBackup) + + return &next +} + func (i *Interval) validateCronExpression(expr string) error { parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow) _, err := parser.Parse(expr) diff --git a/backend/internal/features/intervals/model_test.go b/backend/internal/features/intervals/model_test.go index b5197d5..fdd0022 100644 --- a/backend/internal/features/intervals/model_test.go +++ b/backend/internal/features/intervals/model_test.go @@ -721,3 +721,265 @@ func 
TestInterval_Validate(t *testing.T) { assert.NoError(t, err) }) } + +func TestInterval_NextTriggerTime_NilLastBackup(t *testing.T) { + now := time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC) + + t.Run("Hourly with nil lastBackup returns nil", func(t *testing.T) { + interval := &Interval{ID: uuid.New(), Interval: IntervalHourly} + result := interval.NextTriggerTime(now, nil) + assert.Nil(t, result) + }) + + t.Run("Daily with nil lastBackup returns nil", func(t *testing.T) { + timeOfDay := "09:00" + interval := &Interval{ID: uuid.New(), Interval: IntervalDaily, TimeOfDay: &timeOfDay} + result := interval.NextTriggerTime(now, nil) + assert.Nil(t, result) + }) + + t.Run("Weekly with nil lastBackup returns nil", func(t *testing.T) { + timeOfDay := "15:00" + weekday := 3 + interval := &Interval{ + ID: uuid.New(), + Interval: IntervalWeekly, + TimeOfDay: &timeOfDay, + Weekday: &weekday, + } + result := interval.NextTriggerTime(now, nil) + assert.Nil(t, result) + }) + + t.Run("Monthly with nil lastBackup returns nil", func(t *testing.T) { + timeOfDay := "08:00" + dayOfMonth := 10 + interval := &Interval{ + ID: uuid.New(), + Interval: IntervalMonthly, + TimeOfDay: &timeOfDay, + DayOfMonth: &dayOfMonth, + } + result := interval.NextTriggerTime(now, nil) + assert.Nil(t, result) + }) + + t.Run("Cron with nil lastBackup returns nil", func(t *testing.T) { + cronExpr := "0 2 * * *" + interval := &Interval{ID: uuid.New(), Interval: IntervalCron, CronExpression: &cronExpr} + result := interval.NextTriggerTime(now, nil) + assert.Nil(t, result) + }) +} + +func TestInterval_NextTriggerTime_Hourly(t *testing.T) { + interval := &Interval{ID: uuid.New(), Interval: IntervalHourly} + now := time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC) + + t.Run("Returns lastBackup + 1 hour", func(t *testing.T) { + lastBackup := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC) + result := interval.NextTriggerTime(now, &lastBackup) + + assert.NotNil(t, result) + assert.Equal(t, time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC), *result) + }) + + t.Run("Returns future time when last backup was recent", func(t *testing.T) { + lastBackup := time.Date(2024, 1, 15, 11, 30, 0, 0, time.UTC) + result := interval.NextTriggerTime(now, &lastBackup) + + assert.NotNil(t, result) + assert.Equal(t, time.Date(2024, 1, 15, 12, 30, 0, 0, time.UTC), *result) + }) +} + +func TestInterval_NextTriggerTime_Daily(t *testing.T) { + timeOfDay := "09:00" + interval := &Interval{ID: uuid.New(), Interval: IntervalDaily, TimeOfDay: &timeOfDay} + + t.Run("Before today's slot: returns today's slot", func(t *testing.T) { + now := time.Date(2024, 1, 15, 8, 0, 0, 0, time.UTC) + lastBackup := time.Date(2024, 1, 14, 9, 0, 0, 0, time.UTC) + result := interval.NextTriggerTime(now, &lastBackup) + + assert.NotNil(t, result) + assert.Equal(t, time.Date(2024, 1, 15, 9, 0, 0, 0, time.UTC), *result) + }) + + t.Run("After today's slot: returns tomorrow's slot", func(t *testing.T) { + now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC) + lastBackup := time.Date(2024, 1, 15, 9, 0, 0, 0, time.UTC) + result := interval.NextTriggerTime(now, &lastBackup) + + assert.NotNil(t, result) + assert.Equal(t, time.Date(2024, 1, 16, 9, 0, 0, 0, time.UTC), *result) + }) + + t.Run("Exactly at today's slot: returns tomorrow's slot", func(t *testing.T) { + now := time.Date(2024, 1, 15, 9, 0, 0, 0, time.UTC) + lastBackup := time.Date(2024, 1, 14, 9, 0, 0, 0, time.UTC) + result := interval.NextTriggerTime(now, &lastBackup) + + assert.NotNil(t, result) + assert.Equal(t, time.Date(2024, 1, 16, 9, 0, 0, 0, 
time.UTC), *result) + }) +} + +func TestInterval_NextTriggerTime_Weekly(t *testing.T) { + timeOfDay := "15:00" + weekday := 3 // Wednesday + interval := &Interval{ + ID: uuid.New(), + Interval: IntervalWeekly, + TimeOfDay: &timeOfDay, + Weekday: &weekday, + } + + t.Run("Before this week's target: returns this week's target", func(t *testing.T) { + // Tuesday Jan 16, 2024 + now := time.Date(2024, 1, 16, 10, 0, 0, 0, time.UTC) + lastBackup := time.Date(2024, 1, 10, 15, 0, 0, 0, time.UTC) + result := interval.NextTriggerTime(now, &lastBackup) + + assert.NotNil(t, result) + // Wednesday Jan 17 at 15:00 + assert.Equal(t, time.Date(2024, 1, 17, 15, 0, 0, 0, time.UTC), *result) + }) + + t.Run("After this week's target: returns next week's target", func(t *testing.T) { + // Thursday Jan 18, 2024 + now := time.Date(2024, 1, 18, 10, 0, 0, 0, time.UTC) + lastBackup := time.Date(2024, 1, 17, 15, 0, 0, 0, time.UTC) + result := interval.NextTriggerTime(now, &lastBackup) + + assert.NotNil(t, result) + // Next Wednesday Jan 24 at 15:00 + assert.Equal(t, time.Date(2024, 1, 24, 15, 0, 0, 0, time.UTC), *result) + }) + + t.Run("Friday interval: returns correct target", func(t *testing.T) { + fridayTimeOfDay := "00:00" + fridayWeekday := 5 // Friday + fridayInterval := &Interval{ + ID: uuid.New(), + Interval: IntervalWeekly, + TimeOfDay: &fridayTimeOfDay, + Weekday: &fridayWeekday, + } + + // Wednesday Jan 17, 2024 + now := time.Date(2024, 1, 17, 10, 0, 0, 0, time.UTC) + lastBackup := time.Date(2024, 1, 12, 0, 0, 0, 0, time.UTC) + result := fridayInterval.NextTriggerTime(now, &lastBackup) + + assert.NotNil(t, result) + // Friday Jan 19 at 00:00 + assert.Equal(t, time.Date(2024, 1, 19, 0, 0, 0, 0, time.UTC), *result) + }) +} + +func TestInterval_NextTriggerTime_Monthly(t *testing.T) { + timeOfDay := "08:00" + dayOfMonth := 10 + interval := &Interval{ + ID: uuid.New(), + Interval: IntervalMonthly, + TimeOfDay: &timeOfDay, + DayOfMonth: &dayOfMonth, + } + + t.Run("Before this month's target: returns this month's target", func(t *testing.T) { + now := time.Date(2024, 1, 5, 10, 0, 0, 0, time.UTC) + lastBackup := time.Date(2023, 12, 10, 8, 0, 0, 0, time.UTC) + result := interval.NextTriggerTime(now, &lastBackup) + + assert.NotNil(t, result) + assert.Equal(t, time.Date(2024, 1, 10, 8, 0, 0, 0, time.UTC), *result) + }) + + t.Run("After this month's target: returns next month's target", func(t *testing.T) { + now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC) + lastBackup := time.Date(2024, 1, 10, 8, 0, 0, 0, time.UTC) + result := interval.NextTriggerTime(now, &lastBackup) + + assert.NotNil(t, result) + assert.Equal(t, time.Date(2024, 2, 10, 8, 0, 0, 0, time.UTC), *result) + }) + + t.Run("Exactly at this month's target: returns next month's target", func(t *testing.T) { + now := time.Date(2024, 1, 10, 8, 0, 0, 0, time.UTC) + lastBackup := time.Date(2023, 12, 10, 8, 0, 0, 0, time.UTC) + result := interval.NextTriggerTime(now, &lastBackup) + + assert.NotNil(t, result) + assert.Equal(t, time.Date(2024, 2, 10, 8, 0, 0, 0, time.UTC), *result) + }) +} + +func TestInterval_NextTriggerTime_Cron(t *testing.T) { + t.Run("Daily cron: returns next trigger after lastBackup", func(t *testing.T) { + cronExpr := "0 2 * * *" // Daily at 2:00 AM + interval := &Interval{ID: uuid.New(), Interval: IntervalCron, CronExpression: &cronExpr} + + now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC) + lastBackup := time.Date(2024, 1, 14, 2, 0, 0, 0, time.UTC) + result := interval.NextTriggerTime(now, &lastBackup) + + assert.NotNil(t, result) + 
assert.Equal(t, time.Date(2024, 1, 15, 2, 0, 0, 0, time.UTC), *result) + }) + + t.Run("Complex cron: 1st and 15th at 4:30", func(t *testing.T) { + cronExpr := "30 4 1,15 * *" + interval := &Interval{ID: uuid.New(), Interval: IntervalCron, CronExpression: &cronExpr} + + now := time.Date(2024, 1, 10, 10, 0, 0, 0, time.UTC) + lastBackup := time.Date(2024, 1, 1, 4, 30, 0, 0, time.UTC) + result := interval.NextTriggerTime(now, &lastBackup) + + assert.NotNil(t, result) + assert.Equal(t, time.Date(2024, 1, 15, 4, 30, 0, 0, time.UTC), *result) + }) + + t.Run("Invalid cron expression returns nil", func(t *testing.T) { + invalidCron := "invalid cron" + interval := &Interval{ID: uuid.New(), Interval: IntervalCron, CronExpression: &invalidCron} + + now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC) + lastBackup := time.Date(2024, 1, 14, 10, 0, 0, 0, time.UTC) + result := interval.NextTriggerTime(now, &lastBackup) + + assert.Nil(t, result) + }) + + t.Run("Empty cron expression returns nil", func(t *testing.T) { + emptyCron := "" + interval := &Interval{ID: uuid.New(), Interval: IntervalCron, CronExpression: &emptyCron} + + now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC) + lastBackup := time.Date(2024, 1, 14, 10, 0, 0, 0, time.UTC) + result := interval.NextTriggerTime(now, &lastBackup) + + assert.Nil(t, result) + }) + + t.Run("Nil cron expression returns nil", func(t *testing.T) { + interval := &Interval{ID: uuid.New(), Interval: IntervalCron, CronExpression: nil} + + now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC) + lastBackup := time.Date(2024, 1, 14, 10, 0, 0, 0, time.UTC) + result := interval.NextTriggerTime(now, &lastBackup) + + assert.Nil(t, result) + }) +} + +func TestInterval_NextTriggerTime_UnknownInterval(t *testing.T) { + interval := &Interval{ID: uuid.New(), Interval: IntervalType("UNKNOWN")} + + now := time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC) + lastBackup := time.Date(2024, 1, 14, 12, 0, 0, 0, time.UTC) + result := interval.NextTriggerTime(now, &lastBackup) + + assert.Nil(t, result) +} diff --git a/backend/internal/features/restores/controller_test.go b/backend/internal/features/restores/controller_test.go index f17e0ac..7951b0b 100644 --- a/backend/internal/features/restores/controller_test.go +++ b/backend/internal/features/restores/controller_test.go @@ -18,7 +18,7 @@ import ( env_config "databasus-backend/internal/config" audit_logs "databasus-backend/internal/features/audit_logs" - "databasus-backend/internal/features/backups/backups" + backups_controllers "databasus-backend/internal/features/backups/backups/controllers" backups_core "databasus-backend/internal/features/backups/backups/core" backups_config "databasus-backend/internal/features/backups/config" "databasus-backend/internal/features/databases" @@ -440,7 +440,7 @@ func Test_CancelRestore_InProgressRestore_SuccessfullyCancelled(t *testing.T) { }() backups_config.EnableBackupsForTestDatabase(database.ID, storage) - backup := backups.CreateTestBackup(database.ID, storage.ID) + backup := backups_controllers.CreateTestBackup(database.ID, storage.ID) mockUsecase := &restoring.MockBlockingRestoreUsecase{ StartedChan: make(chan bool, 1), diff --git a/backend/internal/features/restores/di.go b/backend/internal/features/restores/di.go index 483549e..b74cc12 100644 --- a/backend/internal/features/restores/di.go +++ b/backend/internal/features/restores/di.go @@ -5,8 +5,8 @@ import ( "sync/atomic" audit_logs "databasus-backend/internal/features/audit_logs" - "databasus-backend/internal/features/backups/backups" 
"databasus-backend/internal/features/backups/backups/backuping" + backups_services "databasus-backend/internal/features/backups/backups/services" backups_config "databasus-backend/internal/features/backups/config" "databasus-backend/internal/features/databases" "databasus-backend/internal/features/disk" @@ -21,7 +21,7 @@ import ( var restoreRepository = &restores_core.RestoreRepository{} var restoreService = &RestoreService{ - backups.GetBackupService(), + backups_services.GetBackupService(), restoreRepository, storages.GetStorageService(), backups_config.GetBackupConfigService(), @@ -51,7 +51,7 @@ func SetupDependencies() { wasAlreadySetup := isSetup.Load() setupOnce.Do(func() { - backups.GetBackupService().AddBackupRemoveListener(restoreService) + backups_services.GetBackupService().AddBackupRemoveListener(restoreService) backuping.GetBackupCleaner().AddBackupRemoveListener(restoreService) isSetup.Store(true) diff --git a/backend/internal/features/restores/restoring/di.go b/backend/internal/features/restores/restoring/di.go index 29b850d..c2000ac 100644 --- a/backend/internal/features/restores/restoring/di.go +++ b/backend/internal/features/restores/restoring/di.go @@ -7,7 +7,7 @@ import ( "github.com/google/uuid" - "databasus-backend/internal/features/backups/backups" + backups_services "databasus-backend/internal/features/backups/backups/services" backups_config "databasus-backend/internal/features/backups/config" "databasus-backend/internal/features/databases" restores_core "databasus-backend/internal/features/restores/core" @@ -39,37 +39,37 @@ var restoreDatabaseCache = cache_utils.NewCacheUtil[RestoreDatabaseCache]( var restoreCancelManager = tasks_cancellation.GetTaskCancelManager() var restorerNode = &RestorerNode{ - nodeID: uuid.New(), - databaseService: databases.GetDatabaseService(), - backupService: backups.GetBackupService(), - fieldEncryptor: encryption.GetFieldEncryptor(), - restoreRepository: restoreRepository, - backupConfigService: backups_config.GetBackupConfigService(), - storageService: storages.GetStorageService(), - restoreNodesRegistry: restoreNodesRegistry, - logger: logger.GetLogger(), - restoreBackupUsecase: usecases.GetRestoreBackupUsecase(), - cacheUtil: restoreDatabaseCache, - restoreCancelManager: restoreCancelManager, - lastHeartbeat: time.Time{}, - runOnce: sync.Once{}, - hasRun: atomic.Bool{}, + uuid.New(), + databases.GetDatabaseService(), + backups_services.GetBackupService(), + encryption.GetFieldEncryptor(), + restoreRepository, + backups_config.GetBackupConfigService(), + storages.GetStorageService(), + restoreNodesRegistry, + logger.GetLogger(), + usecases.GetRestoreBackupUsecase(), + restoreDatabaseCache, + restoreCancelManager, + time.Time{}, + sync.Once{}, + atomic.Bool{}, } var restoresScheduler = &RestoresScheduler{ - restoreRepository: restoreRepository, - backupService: backups.GetBackupService(), - storageService: storages.GetStorageService(), - backupConfigService: backups_config.GetBackupConfigService(), - restoreNodesRegistry: restoreNodesRegistry, - lastCheckTime: time.Now().UTC(), - logger: logger.GetLogger(), - restoreToNodeRelations: make(map[uuid.UUID]RestoreToNodeRelation), - restorerNode: restorerNode, - cacheUtil: restoreDatabaseCache, - completionSubscriptionID: uuid.Nil, - runOnce: sync.Once{}, - hasRun: atomic.Bool{}, + restoreRepository, + backups_services.GetBackupService(), + storages.GetStorageService(), + backups_config.GetBackupConfigService(), + restoreNodesRegistry, + time.Now().UTC(), + logger.GetLogger(), + 
make(map[uuid.UUID]RestoreToNodeRelation), + restorerNode, + restoreDatabaseCache, + uuid.Nil, + sync.Once{}, + atomic.Bool{}, } func GetRestoresScheduler() *RestoresScheduler { diff --git a/backend/internal/features/restores/restoring/restorer.go b/backend/internal/features/restores/restoring/restorer.go index 2947aa7..092df7a 100644 --- a/backend/internal/features/restores/restoring/restorer.go +++ b/backend/internal/features/restores/restoring/restorer.go @@ -13,7 +13,7 @@ import ( "github.com/google/uuid" "databasus-backend/internal/config" - "databasus-backend/internal/features/backups/backups" + backups_services "databasus-backend/internal/features/backups/backups/services" backups_config "databasus-backend/internal/features/backups/config" "databasus-backend/internal/features/databases" restores_core "databasus-backend/internal/features/restores/core" @@ -32,7 +32,7 @@ type RestorerNode struct { nodeID uuid.UUID databaseService *databases.DatabaseService - backupService *backups.BackupService + backupService *backups_services.BackupService fieldEncryptor util_encryption.FieldEncryptor restoreRepository *restores_core.RestoreRepository backupConfigService *backups_config.BackupConfigService diff --git a/backend/internal/features/restores/restoring/restorer_test.go b/backend/internal/features/restores/restoring/restorer_test.go index 65137f1..ce315c4 100644 --- a/backend/internal/features/restores/restoring/restorer_test.go +++ b/backend/internal/features/restores/restoring/restorer_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/assert" "databasus-backend/internal/config" - "databasus-backend/internal/features/backups/backups" + backups_controllers "databasus-backend/internal/features/backups/backups/controllers" backups_core "databasus-backend/internal/features/backups/backups/core" backups_config "databasus-backend/internal/features/backups/config" "databasus-backend/internal/features/databases" @@ -58,7 +58,7 @@ func Test_MakeRestore_WhenCacheMissed_RestoreFails(t *testing.T) { cache_utils.ClearAllCache() }() - backup := backups.CreateTestBackup(database.ID, storage.ID) + backup := backups_controllers.CreateTestBackup(database.ID, storage.ID) // Create restore but DON'T cache DB credentials // Also don't set embedded DB fields to avoid schema issues @@ -126,7 +126,7 @@ func Test_MakeRestore_WhenTaskStarts_CacheDeletedImmediately(t *testing.T) { cache_utils.ClearAllCache() }() - backup := backups.CreateTestBackup(database.ID, storage.ID) + backup := backups_controllers.CreateTestBackup(database.ID, storage.ID) // Create restore with cached DB credentials // Don't set embedded DB fields in the restore model itself diff --git a/backend/internal/features/restores/restoring/scheduler.go b/backend/internal/features/restores/restoring/scheduler.go index 361b3a1..58faf62 100644 --- a/backend/internal/features/restores/restoring/scheduler.go +++ b/backend/internal/features/restores/restoring/scheduler.go @@ -11,7 +11,7 @@ import ( "github.com/google/uuid" "databasus-backend/internal/config" - "databasus-backend/internal/features/backups/backups" + backups_services "databasus-backend/internal/features/backups/backups/services" backups_config "databasus-backend/internal/features/backups/config" restores_core "databasus-backend/internal/features/restores/core" "databasus-backend/internal/features/storages" @@ -26,7 +26,7 @@ const ( type RestoresScheduler struct { restoreRepository *restores_core.RestoreRepository - backupService *backups.BackupService + backupService 
*backups_services.BackupService storageService *storages.StorageService backupConfigService *backups_config.BackupConfigService restoreNodesRegistry *RestoreNodesRegistry diff --git a/backend/internal/features/restores/restoring/scheduler_test.go b/backend/internal/features/restores/restoring/scheduler_test.go index 3b6f353..66899cf 100644 --- a/backend/internal/features/restores/restoring/scheduler_test.go +++ b/backend/internal/features/restores/restoring/scheduler_test.go @@ -5,7 +5,7 @@ import ( "time" "databasus-backend/internal/config" - "databasus-backend/internal/features/backups/backups" + backups_controllers "databasus-backend/internal/features/backups/backups/controllers" backups_core "databasus-backend/internal/features/backups/backups/core" backups_config "databasus-backend/internal/features/backups/config" "databasus-backend/internal/features/databases" @@ -68,7 +68,7 @@ func Test_CheckDeadNodesAndFailRestores_NodeDies_FailsRestoreAndCleansUpRegistry backups_config.EnableBackupsForTestDatabase(database.ID, storage) // Create a test backup - backup := backups.CreateTestBackup(database.ID, storage.ID) + backup := backups_controllers.CreateTestBackup(database.ID, storage.ID) var err error // Register mock node without subscribing to restores (simulates node crash after registration) @@ -171,7 +171,7 @@ func Test_OnRestoreCompleted_TaskIsNotRestore_SkipsProcessing(t *testing.T) { backups_config.EnableBackupsForTestDatabase(database.ID, storage) // Create a test backup - backup := backups.CreateTestBackup(database.ID, storage.ID) + backup := backups_controllers.CreateTestBackup(database.ID, storage.ID) // Register mock node mockNodeID = uuid.New() @@ -357,7 +357,7 @@ func Test_FailRestoresInProgress_SchedulerStarts_UpdatesStatus(t *testing.T) { backups_config.EnableBackupsForTestDatabase(database.ID, storage) // Create a test backup - backup := backups.CreateTestBackup(database.ID, storage.ID) + backup := backups_controllers.CreateTestBackup(database.ID, storage.ID) // Create two in-progress restores that should be failed on scheduler restart restore1 := &restores_core.Restore{ @@ -465,7 +465,7 @@ func Test_StartRestore_RestoreCompletes_DecrementsActiveTaskCount(t *testing.T) backups_config.EnableBackupsForTestDatabase(database.ID, storage) // Create a test backup - backup := backups.CreateTestBackup(database.ID, storage.ID) + backup := backups_controllers.CreateTestBackup(database.ID, storage.ID) // Get initial active task count stats, err := restoreNodesRegistry.GetRestoreNodesStats() @@ -566,7 +566,7 @@ func Test_StartRestore_RestoreFails_DecrementsActiveTaskCount(t *testing.T) { backups_config.EnableBackupsForTestDatabase(database.ID, storage) // Create a test backup - backup := backups.CreateTestBackup(database.ID, storage.ID) + backup := backups_controllers.CreateTestBackup(database.ID, storage.ID) // Get initial active task count stats, err := restoreNodesRegistry.GetRestoreNodesStats() @@ -664,7 +664,7 @@ func Test_StartRestore_CredentialsStoredEncryptedInCache(t *testing.T) { backups_config.EnableBackupsForTestDatabase(database.ID, storage) // Create a test backup - backup := backups.CreateTestBackup(database.ID, storage.ID) + backup := backups_controllers.CreateTestBackup(database.ID, storage.ID) // Register mock node so scheduler can assign restore to it mockNodeID = uuid.New() @@ -779,7 +779,7 @@ func Test_StartRestore_CredentialsRemovedAfterRestoreStarts(t *testing.T) { backups_config.EnableBackupsForTestDatabase(database.ID, storage) // Create a test backup - backup 
:= backups.CreateTestBackup(database.ID, storage.ID) + backup := backups_controllers.CreateTestBackup(database.ID, storage.ID) // Create restore with credentials plaintextPassword := "test_password_456" diff --git a/backend/internal/features/restores/restoring/testing.go b/backend/internal/features/restores/restoring/testing.go index d37c6f4..2db3ce3 100644 --- a/backend/internal/features/restores/restoring/testing.go +++ b/backend/internal/features/restores/restoring/testing.go @@ -12,8 +12,8 @@ import ( "github.com/google/uuid" "databasus-backend/internal/config" - "databasus-backend/internal/features/backups/backups" backups_core "databasus-backend/internal/features/backups/backups/core" + backups_services "databasus-backend/internal/features/backups/backups/services" backups_config "databasus-backend/internal/features/backups/config" "databasus-backend/internal/features/databases" "databasus-backend/internal/features/databases/databases/postgresql" @@ -40,48 +40,48 @@ func CreateTestRouter() *gin.Engine { func CreateTestRestorerNode() *RestorerNode { return &RestorerNode{ - nodeID: uuid.New(), - databaseService: databases.GetDatabaseService(), - backupService: backups.GetBackupService(), - fieldEncryptor: encryption.GetFieldEncryptor(), - restoreRepository: restoreRepository, - backupConfigService: backups_config.GetBackupConfigService(), - storageService: storages.GetStorageService(), - restoreNodesRegistry: restoreNodesRegistry, - logger: logger.GetLogger(), - restoreBackupUsecase: usecases.GetRestoreBackupUsecase(), - cacheUtil: restoreDatabaseCache, - restoreCancelManager: tasks_cancellation.GetTaskCancelManager(), - lastHeartbeat: time.Time{}, - runOnce: sync.Once{}, - hasRun: atomic.Bool{}, + uuid.New(), + databases.GetDatabaseService(), + backups_services.GetBackupService(), + encryption.GetFieldEncryptor(), + restoreRepository, + backups_config.GetBackupConfigService(), + storages.GetStorageService(), + restoreNodesRegistry, + logger.GetLogger(), + usecases.GetRestoreBackupUsecase(), + restoreDatabaseCache, + tasks_cancellation.GetTaskCancelManager(), + time.Time{}, + sync.Once{}, + atomic.Bool{}, } } func CreateTestRestorerNodeWithUsecase(usecase restores_core.RestoreBackupUsecase) *RestorerNode { return &RestorerNode{ - nodeID: uuid.New(), - databaseService: databases.GetDatabaseService(), - backupService: backups.GetBackupService(), - fieldEncryptor: encryption.GetFieldEncryptor(), - restoreRepository: restoreRepository, - backupConfigService: backups_config.GetBackupConfigService(), - storageService: storages.GetStorageService(), - restoreNodesRegistry: restoreNodesRegistry, - logger: logger.GetLogger(), - restoreBackupUsecase: usecase, - cacheUtil: restoreDatabaseCache, - restoreCancelManager: tasks_cancellation.GetTaskCancelManager(), - lastHeartbeat: time.Time{}, - runOnce: sync.Once{}, - hasRun: atomic.Bool{}, + uuid.New(), + databases.GetDatabaseService(), + backups_services.GetBackupService(), + encryption.GetFieldEncryptor(), + restoreRepository, + backups_config.GetBackupConfigService(), + storages.GetStorageService(), + restoreNodesRegistry, + logger.GetLogger(), + usecase, + restoreDatabaseCache, + tasks_cancellation.GetTaskCancelManager(), + time.Time{}, + sync.Once{}, + atomic.Bool{}, } } func CreateTestRestoresScheduler() *RestoresScheduler { return &RestoresScheduler{ restoreRepository, - backups.GetBackupService(), + backups_services.GetBackupService(), storages.GetStorageService(), backups_config.GetBackupConfigService(), restoreNodesRegistry, diff --git 
a/backend/internal/features/restores/service.go b/backend/internal/features/restores/service.go index 1b99233..139eaad 100644 --- a/backend/internal/features/restores/service.go +++ b/backend/internal/features/restores/service.go @@ -3,8 +3,8 @@ package restores import ( "databasus-backend/internal/config" audit_logs "databasus-backend/internal/features/audit_logs" - "databasus-backend/internal/features/backups/backups" backups_core "databasus-backend/internal/features/backups/backups/core" + backups_services "databasus-backend/internal/features/backups/backups/services" backups_config "databasus-backend/internal/features/backups/config" "databasus-backend/internal/features/databases" "databasus-backend/internal/features/disk" @@ -26,7 +26,7 @@ import ( ) type RestoreService struct { - backupService *backups.BackupService + backupService *backups_services.BackupService restoreRepository *restores_core.RestoreRepository storageService *storages.StorageService backupConfigService *backups_config.BackupConfigService diff --git a/backend/internal/features/restores/testing.go b/backend/internal/features/restores/testing.go index d962dce..04b15a8 100644 --- a/backend/internal/features/restores/testing.go +++ b/backend/internal/features/restores/testing.go @@ -8,7 +8,7 @@ import ( "github.com/gin-gonic/gin" "github.com/google/uuid" - "databasus-backend/internal/features/backups/backups" + backups_controllers "databasus-backend/internal/features/backups/backups/controllers" backups_config "databasus-backend/internal/features/backups/config" "databasus-backend/internal/features/databases" "databasus-backend/internal/features/restores/restoring" @@ -22,12 +22,12 @@ func CreateTestRouter() *gin.Engine { workspaces_controllers.GetMembershipController(), databases.GetDatabaseController(), backups_config.GetBackupConfigController(), - backups.GetBackupController(), + backups_controllers.GetBackupController(), GetRestoreController(), ) v1 := router.Group("/api/v1") - backups.GetBackupController().RegisterPublicRoutes(v1) + backups_controllers.GetBackupController().RegisterPublicRoutes(v1) return router } diff --git a/backend/internal/features/storages/models/local/model.go b/backend/internal/features/storages/models/local/model.go index 761f979..a08da63 100644 --- a/backend/internal/features/storages/models/local/model.go +++ b/backend/internal/features/storages/models/local/model.go @@ -47,14 +47,15 @@ func (l *LocalStorage) SaveFile( logger.Info("Starting to save file to local storage", "fileName", fileName) + tempFilePath := filepath.Join(config.GetEnv().TempFolder, fileName) + err := files_utils.EnsureDirectories([]string{ config.GetEnv().TempFolder, + filepath.Dir(tempFilePath), }) if err != nil { return fmt.Errorf("failed to ensure directories: %w", err) } - - tempFilePath := filepath.Join(config.GetEnv().TempFolder, fileName) logger.Debug("Creating temp file", "fileName", fileName, "tempPath", tempFilePath) tempFile, err := os.Create(tempFilePath) @@ -101,6 +102,10 @@ func (l *LocalStorage) SaveFile( finalPath, ) + if err = files_utils.EnsureDirectories([]string{filepath.Dir(finalPath)}); err != nil { + return fmt.Errorf("failed to ensure final directory: %w", err) + } + // Move the file from temp to backups directory if err = os.Rename(tempFilePath, finalPath); err != nil { logger.Error( diff --git a/backend/internal/features/test_once_protection.go b/backend/internal/features/test_once_protection.go index 7b096f8..d88a42e 100644 --- a/backend/internal/features/test_once_protection.go +++ 
b/backend/internal/features/test_once_protection.go @@ -8,8 +8,8 @@ import ( "time" "databasus-backend/internal/features/audit_logs" - "databasus-backend/internal/features/backups/backups" "databasus-backend/internal/features/backups/backups/backuping" + backups_services "databasus-backend/internal/features/backups/backups/services" backups_config "databasus-backend/internal/features/backups/config" "databasus-backend/internal/features/databases" healthcheck_config "databasus-backend/internal/features/healthcheck/config" @@ -26,8 +26,8 @@ func Test_SetupDependencies_CalledTwice_LogsWarning(t *testing.T) { audit_logs.SetupDependencies() audit_logs.SetupDependencies() - backups.SetupDependencies() - backups.SetupDependencies() + backups_services.SetupDependencies() + backups_services.SetupDependencies() backups_config.SetupDependencies() backups_config.SetupDependencies() diff --git a/backend/internal/features/tests/postgresql_backup_restore_test.go b/backend/internal/features/tests/postgresql_backup_restore_test.go index 47522c8..5e78a3d 100644 --- a/backend/internal/features/tests/postgresql_backup_restore_test.go +++ b/backend/internal/features/tests/postgresql_backup_restore_test.go @@ -17,8 +17,9 @@ import ( "github.com/stretchr/testify/assert" "databasus-backend/internal/config" - "databasus-backend/internal/features/backups/backups" + backups_controllers "databasus-backend/internal/features/backups/backups/controllers" backups_core "databasus-backend/internal/features/backups/backups/core" + backups_dto "databasus-backend/internal/features/backups/backups/dto" backups_config "databasus-backend/internal/features/backups/config" "databasus-backend/internal/features/databases" pgtypes "databasus-backend/internal/features/databases/databases/postgresql" @@ -1234,7 +1235,7 @@ func createTestRouter() *gin.Engine { workspaces_controllers.GetMembershipController(), databases.GetDatabaseController(), backups_config.GetBackupConfigController(), - backups.GetBackupController(), + backups_controllers.GetBackupController(), restores.GetRestoreController(), ) return router @@ -1255,7 +1256,7 @@ func waitForBackupCompletion( t.Fatalf("Timeout waiting for backup completion after %v", timeout) } - var response backups.GetBackupsResponse + var response backups_dto.GetBackupsResponse test_utils.MakeGetRequestAndUnmarshal( t, router, @@ -1431,7 +1432,7 @@ func createBackupViaAPI( databaseID uuid.UUID, token string, ) { - request := backups.MakeBackupRequest{DatabaseID: databaseID} + request := backups_dto.MakeBackupRequest{DatabaseID: databaseID} test_utils.MakePostRequest( t, router, diff --git a/backend/internal/util/wal/calculator.go b/backend/internal/util/wal/calculator.go new file mode 100644 index 0000000..e71ded3 --- /dev/null +++ b/backend/internal/util/wal/calculator.go @@ -0,0 +1,125 @@ +package wal + +import ( + "encoding/hex" + "errors" + "fmt" + "strings" +) + +const ( + segmentNameLen = 24 + timelineLen = 8 + logLen = 8 + segLen = 8 + defaultSegmentSizeBytes = 16 * 1024 * 1024 // 16 MB +) + +// WalCalculator performs WAL segment name arithmetic for a given WAL segment size. 
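+// With the default 16 MB segment size, 0x100 (256) segments fit in one log
+// file, so "0000000100000001000000FF" is followed by "000000010000000200000000".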
+// +// WAL segment name format: TTTTTTTTLLLLLLLLSSSSSSSS (24 hex chars) +// - TT: timeline (8 hex digits) +// - LL: log file / XLogId (8 hex digits) +// - SS: segment within log file (8 hex digits) +// +// segmentsPerXLogId = 0x100000000 / segmentSizeBytes +// Increment SS; if SS >= segmentsPerXLogId → SS = 0, LL++ +type WalCalculator struct { + segmentSizeBytes int64 + segmentsPerXLogId uint64 +} + +// NewWalCalculator creates a WalCalculator for the given WAL segment size in bytes. +// Pass 0 or a negative value to use the PostgreSQL default of 16 MB. +func NewWalCalculator(segmentSizeBytes int64) *WalCalculator { + if segmentSizeBytes <= 0 { + segmentSizeBytes = defaultSegmentSizeBytes + } + + return &WalCalculator{ + segmentSizeBytes: segmentSizeBytes, + segmentsPerXLogId: uint64(0x100000000) / uint64(segmentSizeBytes), + } +} + +// NextSegment computes the next WAL segment name after current. +// Returns an error if current is not a valid 24-character hex WAL segment name. +func (c *WalCalculator) NextSegment(current string) (string, error) { + if !c.IsValidSegmentName(current) { + return "", fmt.Errorf("invalid WAL segment name: %q", current) + } + + timeline := current[:timelineLen] + logHex := current[timelineLen : timelineLen+logLen] + segHex := current[timelineLen+logLen:] + + logVal, err := parseHex32(logHex) + if err != nil { + return "", fmt.Errorf("parse log part of %q: %w", current, err) + } + + segVal, err := parseHex32(segHex) + if err != nil { + return "", fmt.Errorf("parse seg part of %q: %w", current, err) + } + + segVal++ + if uint64(segVal) >= c.segmentsPerXLogId { + segVal = 0 + logVal++ + } + + return fmt.Sprintf("%s%08X%08X", strings.ToUpper(timeline), logVal, segVal), nil +} + +// IsValidSegmentName returns true if name is a 24-character uppercase (or lowercase) hex string +// representing a valid WAL segment name. +func (c *WalCalculator) IsValidSegmentName(name string) bool { + if len(name) != segmentNameLen { + return false + } + + _, err := hex.DecodeString(name) + return err == nil +} + +// Compare compares two WAL segment names a and b by their numeric value. +// Returns -1 if a < b, 0 if a == b, 1 if a > b. +// Both names must be valid; returns an error otherwise. +// Timeline is compared first, then log file, then segment number. +func (c *WalCalculator) Compare(a, b string) (int, error) { + if !c.IsValidSegmentName(a) { + return 0, fmt.Errorf("invalid WAL segment name: %q", a) + } + + if !c.IsValidSegmentName(b) { + return 0, fmt.Errorf("invalid WAL segment name: %q", b) + } + + // Fixed-width uppercase hex: lexicographic order equals numeric order. 
+	aUpper := strings.ToUpper(a)
+	bUpper := strings.ToUpper(b)
+
+	if aUpper < bUpper {
+		return -1, nil
+	}
+
+	if aUpper > bUpper {
+		return 1, nil
+	}
+
+	return 0, nil
+}
+
+func parseHex32(s string) (uint32, error) {
+	if len(s) != 8 {
+		return 0, errors.New("expected 8 hex characters")
+	}
+
+	b, err := hex.DecodeString(s)
+	if err != nil {
+		return 0, err
+	}
+
+	return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3]), nil
+}
diff --git a/backend/internal/util/wal/calculator_test.go b/backend/internal/util/wal/calculator_test.go
new file mode 100644
index 0000000..c19a6ce
--- /dev/null
+++ b/backend/internal/util/wal/calculator_test.go
@@ -0,0 +1,221 @@
+package wal
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+const (
+	mb1  = 1 * 1024 * 1024
+	mb16 = 16 * 1024 * 1024
+	mb64 = 64 * 1024 * 1024
+)
+
+// NextSegment — no wrap
+
+func Test_WalCalculator_NextSegment_DefaultSegmentSize_NoWrap(t *testing.T) {
+	calc := NewWalCalculator(mb16)
+
+	cases := []struct {
+		input    string
+		expected string
+	}{
+		{"000000010000000100000000", "000000010000000100000001"},
+		{"000000010000000100000027", "000000010000000100000028"},
+		{"0000000100000001000000AB", "0000000100000001000000AC"},
+		{"0000000100000001000000FE", "0000000100000001000000FF"},
+	}
+
+	for _, tc := range cases {
+		result, err := calc.NextSegment(tc.input)
+		require.NoError(t, err)
+		assert.Equal(t, tc.expected, result, "input=%s", tc.input)
+	}
+}
+
+// NextSegment — wrap at 0x100 (256) for 16 MB segments
+
+func Test_WalCalculator_NextSegment_DefaultSegmentSize_WrapsAt256(t *testing.T) {
+	calc := NewWalCalculator(mb16)
+
+	// SS=0xFF → SS=0x00, LL++
+	result, err := calc.NextSegment("0000000100000001000000FF")
+	require.NoError(t, err)
+	assert.Equal(t, "000000010000000200000000", result)
+
+	result, err = calc.NextSegment("0000000200000005000000FF")
+	require.NoError(t, err)
+	assert.Equal(t, "000000020000000600000000", result)
+}
+
+// NextSegment — wrap at 0x1000 (4096) for 1 MB segments
+
+func Test_WalCalculator_NextSegment_1MbSegmentSize_WrapsAt4096(t *testing.T) {
+	calc := NewWalCalculator(mb1)
+
+	// segmentsPerXLogId = 0x100000000 / (1*1024*1024) = 4096 = 0x1000
+	// Last valid SS = 0x00000FFF
+
+	result, err := calc.NextSegment("000000010000000100000FFE")
+	require.NoError(t, err)
+	assert.Equal(t, "000000010000000100000FFF", result)
+
+	// wrap: SS=0x0FFF → SS=0, LL++
+	result, err = calc.NextSegment("000000010000000100000FFF")
+	require.NoError(t, err)
+	assert.Equal(t, "000000010000000200000000", result)
+}
+
+// NextSegment — wrap at 0x40 (64) for 64 MB segments
+
+func Test_WalCalculator_NextSegment_64MbSegmentSize_WrapsAt64(t *testing.T) {
+	calc := NewWalCalculator(mb64)
+
+	// segmentsPerXLogId = 0x100000000 / (64*1024*1024) = 64 = 0x40
+	// Last valid SS = 0x0000003F
+
+	result, err := calc.NextSegment("00000001000000010000003E")
+	require.NoError(t, err)
+	assert.Equal(t, "00000001000000010000003F", result)
+
+	// wrap: SS=0x3F → SS=0, LL++
+	result, err = calc.NextSegment("00000001000000010000003F")
+	require.NoError(t, err)
+	assert.Equal(t, "000000010000000200000000", result)
+}
+
+// NextSegment — log file increment on wrap
+
+func Test_WalCalculator_NextSegment_IncrementsLog_OnSegmentWrap(t *testing.T) {
+	calc := NewWalCalculator(mb16)
+
+	// LL=0x00000001, wraps to LL=0x00000002
+	result, err := calc.NextSegment("0000000100000001000000FF")
+	require.NoError(t, err)
+	assert.Equal(t, "000000010000000200000000", result)
+
+	// LL=0x000000FF, wraps to LL=0x00000100
+	result, err = calc.NextSegment("00000001000000FF000000FF")
+	require.NoError(t, err)
+	assert.Equal(t, "000000010000010000000000", result)
+}
+
+// NextSegment — timeline is preserved
+
+func Test_WalCalculator_NextSegment_TimelinePreserved(t *testing.T) {
+	calc := NewWalCalculator(mb16)
+
+	result, err := calc.NextSegment("000000030000000100000005")
+	require.NoError(t, err)
+	assert.Equal(t, "000000030000000100000006", result)
+}
+
+// NextSegment — invalid input
+
+func Test_WalCalculator_NextSegment_InvalidName_ReturnsError(t *testing.T) {
+	calc := NewWalCalculator(mb16)
+
+	cases := []string{
+		"",
+		"00000001000000010000000",   // 23 chars
+		"0000000100000001000000001", // 25 chars
+		"00000001000000010000000G",  // non-hex char
+		"short",
+	}
+
+	for _, name := range cases {
+		_, err := calc.NextSegment(name)
+		assert.Error(t, err, "expected error for input %q", name)
+	}
+}
+
+// IsValidSegmentName
+
+func Test_WalCalculator_IsValidSegmentName_ValidName_ReturnsTrue(t *testing.T) {
+	calc := NewWalCalculator(mb16)
+
+	valid := []string{
+		"000000010000000100000000",
+		"0000000100000001000000FF",
+		"FFFFFFFFFFFFFFFFFFFFFFFF",
+		"000000010000000100000027",
+		"0000000200000005000000AB",
+	}
+
+	for _, name := range valid {
+		assert.True(t, calc.IsValidSegmentName(name), "expected valid for %q", name)
+	}
+}
+
+func Test_WalCalculator_IsValidSegmentName_TooShort_ReturnsFalse(t *testing.T) {
+	calc := NewWalCalculator(mb16)
+
+	assert.False(t, calc.IsValidSegmentName("00000001000000010000000")) // 23 chars
+	assert.False(t, calc.IsValidSegmentName(""))
+}
+
+func Test_WalCalculator_IsValidSegmentName_TooLong_ReturnsFalse(t *testing.T) {
+	calc := NewWalCalculator(mb16)
+
+	assert.False(t, calc.IsValidSegmentName("0000000100000001000000001")) // 25 chars
+}
+
+func Test_WalCalculator_IsValidSegmentName_NonHex_ReturnsFalse(t *testing.T) {
+	calc := NewWalCalculator(mb16)
+
+	assert.False(t, calc.IsValidSegmentName("00000001000000010000000G"))
+	assert.False(t, calc.IsValidSegmentName("00000001000000010000000Z"))
+	assert.False(t, calc.IsValidSegmentName("000000010000000100000 00"))
+}
+
+// Compare
+
+func Test_WalCalculator_Compare_ReturnsCorrectOrdering(t *testing.T) {
+	calc := NewWalCalculator(mb16)
+
+	cases := []struct {
+		a        string
+		b        string
+		expected int
+	}{
+		// equal
+		{"000000010000000100000001", "000000010000000100000001", 0},
+		// segment ordering
+		{"000000010000000100000001", "000000010000000100000002", -1},
+		{"000000010000000100000002", "000000010000000100000001", 1},
+		// log ordering
+		{"000000010000000100000000", "000000010000000200000000", -1},
+		{"000000010000000200000000", "000000010000000100000000", 1},
+		// timeline ordering
+		{"000000010000000100000000", "000000020000000100000000", -1},
+		{"000000020000000100000000", "000000010000000100000000", 1},
+		// across wrap boundary: log 1, seg 255 < log 2, seg 0
+		{"0000000100000001000000FF", "000000010000000200000000", -1},
+	}
+
+	for _, tc := range cases {
+		result, err := calc.Compare(tc.a, tc.b)
+		require.NoError(t, err)
+		assert.Equal(t, tc.expected, result, "Compare(%s, %s)", tc.a, tc.b)
+	}
+}
+
+func Test_WalCalculator_Compare_InvalidInput_ReturnsError(t *testing.T) {
+	calc := NewWalCalculator(mb16)
+
+	_, err := calc.Compare("invalid", "000000010000000100000001")
+	assert.Error(t, err)
+
+	_, err = calc.Compare("000000010000000100000001", "invalid")
+	assert.Error(t, err)
+}
+
+// Default segment size via zero input
+
+func Test_WalCalculator_NewWalCalculator_ZeroSize_UsesDefault16MB(t *testing.T) {
+	calc := NewWalCalculator(0)
+	assert.Equal(t, int64(mb16), calc.segmentSizeBytes)
+	assert.Equal(t, uint64(256), calc.segmentsPerXLogId)
+}
diff --git a/backend/migrations/20260306045548_add_wal_properties.sql b/backend/migrations/20260306045548_add_wal_properties.sql
new file mode 100644
index 0000000..519e367
--- /dev/null
+++ b/backend/migrations/20260306045548_add_wal_properties.sql
@@ -0,0 +1,72 @@
+-- +goose Up
+-- +goose StatementBegin
+ALTER TABLE databases
+    ADD COLUMN agent_token TEXT,
+    ADD COLUMN is_agent_token_generated BOOLEAN NOT NULL DEFAULT FALSE;
+
+CREATE UNIQUE INDEX idx_databases_agent_token ON databases (agent_token) WHERE agent_token IS NOT NULL;
+
+ALTER TABLE postgresql_databases
+    ADD COLUMN backup_type TEXT NOT NULL DEFAULT 'PG_DUMP';
+
+ALTER TABLE postgresql_databases
+    ALTER COLUMN host DROP NOT NULL,
+    ALTER COLUMN port DROP NOT NULL,
+    ALTER COLUMN username DROP NOT NULL,
+    ALTER COLUMN password DROP NOT NULL;
+
+ALTER TABLE backups
+    ADD COLUMN pg_wal_backup_type TEXT,
+    ADD COLUMN pg_wal_start_segment TEXT,
+    ADD COLUMN pg_wal_stop_segment TEXT,
+    ADD COLUMN pg_version TEXT,
+    ADD COLUMN pg_wal_segment_name TEXT;
+
+CREATE INDEX idx_backups_pg_wal_segment_name
+    ON backups (database_id, pg_wal_segment_name)
+    WHERE pg_wal_segment_name IS NOT NULL;
+
+CREATE INDEX idx_backups_pg_wal_backup_type_created
+    ON backups (database_id, pg_wal_backup_type, created_at DESC)
+    WHERE pg_wal_backup_type IS NOT NULL;
+-- +goose StatementEnd
+
+-- +goose Down
+-- +goose StatementBegin
+DROP INDEX IF EXISTS idx_backups_pg_wal_segment_name;
+DROP INDEX IF EXISTS idx_backups_pg_wal_backup_type_created;
+
+ALTER TABLE backups
+    DROP COLUMN pg_wal_backup_type,
+    DROP COLUMN pg_wal_start_segment,
+    DROP COLUMN pg_wal_stop_segment,
+    DROP COLUMN pg_version,
+    DROP COLUMN pg_wal_segment_name;
+
+UPDATE postgresql_databases
+SET host = 'localhost' WHERE host IS NULL OR host = '';
+
+UPDATE postgresql_databases
+SET port = 5432 WHERE port IS NULL OR port = 0;
+
+UPDATE postgresql_databases
+SET username = 'postgres' WHERE username IS NULL OR username = '';
+
+UPDATE postgresql_databases
+SET password = 'stubpassword' WHERE password IS NULL OR password = '';
+
+ALTER TABLE postgresql_databases
+    DROP COLUMN backup_type;
+
+ALTER TABLE postgresql_databases
+    ALTER COLUMN host SET NOT NULL,
+    ALTER COLUMN port SET NOT NULL,
+    ALTER COLUMN username SET NOT NULL,
+    ALTER COLUMN password SET NOT NULL;
+
+DROP INDEX IF EXISTS idx_databases_agent_token;
+
+ALTER TABLE databases
+    DROP COLUMN agent_token,
+    DROP COLUMN is_agent_token_generated;
+-- +goose StatementEnd
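
For reference, a minimal usage sketch of the new calculator. This assumes a
hypothetical main package inside the same Go module (internal/ packages are
only importable from within it); the segment names are the same ones exercised
by the tests above.

    package main

    import (
    	"fmt"

    	"databasus-backend/internal/util/wal"
    )

    func main() {
    	// Default 16 MB segments: 0x100000000 / 0x1000000 = 256 segments per XLogId.
    	calc := wal.NewWalCalculator(16 * 1024 * 1024)

    	// Plain increment within a log file: SS goes 0x27 -> 0x28.
    	next, err := calc.NextSegment("000000010000000100000027")
    	if err != nil {
    		panic(err)
    	}
    	fmt.Println(next) // 000000010000000100000028

    	// Wrap: SS reaches segmentsPerXLogId (0x100), so SS resets and LL increments.
    	next, _ = calc.NextSegment("0000000100000001000000FF")
    	fmt.Println(next) // 000000010000000200000000

    	// Ordering respects the wrap boundary as well.
    	cmp, _ := calc.Compare("0000000100000001000000FF", "000000010000000200000000")
    	fmt.Println(cmp) // -1
    }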