FEATURE (backups): Add WAL API

This commit is contained in:
Rostislav Dugin
2026-03-03 12:20:23 +03:00
parent 91f35a3e17
commit 230cc27ea6
49 changed files with 3941 additions and 372 deletions

1
.gitignore vendored
View File

@@ -12,3 +12,4 @@ node_modules/
.DS_Store
/scripts
.vscode/settings.json
.claude

1
CLAUDE.md Normal file
View File

@@ -0,0 +1 @@
Look at @AGENTS.md

View File

@@ -14,9 +14,10 @@ import (
"databasus-backend/internal/config"
"databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/features/backups/backups/backuping"
backups_controllers "databasus-backend/internal/features/backups/backups/controllers"
backups_download "databasus-backend/internal/features/backups/backups/download"
backups_services "databasus-backend/internal/features/backups/backups/services"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/disk"
@@ -209,7 +210,9 @@ func setUpRoutes(r *gin.Engine) {
userController := users_controllers.GetUserController()
userController.RegisterRoutes(v1)
system_healthcheck.GetHealthcheckController().RegisterRoutes(v1)
backups.GetBackupController().RegisterPublicRoutes(v1)
backups_controllers.GetBackupController().RegisterPublicRoutes(v1)
backups_controllers.GetPostgresWalBackupController().RegisterRoutes(v1)
databases.GetDatabaseController().RegisterPublicRoutes(v1)
// Setup auth middleware
userService := users_services.GetUserService()
@@ -226,7 +229,7 @@ func setUpRoutes(r *gin.Engine) {
notifiers.GetNotifierController().RegisterRoutes(protected)
storages.GetStorageController().RegisterRoutes(protected)
databases.GetDatabaseController().RegisterRoutes(protected)
backups.GetBackupController().RegisterRoutes(protected)
backups_controllers.GetBackupController().RegisterRoutes(protected)
restores.GetRestoreController().RegisterRoutes(protected)
healthcheck_config.GetHealthcheckConfigController().RegisterRoutes(protected)
healthcheck_attempt.GetHealthcheckAttemptController().RegisterRoutes(protected)
@@ -238,7 +241,7 @@ func setUpRoutes(r *gin.Engine) {
func setUpDependencies() {
databases.SetupDependencies()
backups.SetupDependencies()
backups_services.SetupDependencies()
restores.SetupDependencies()
healthcheck_config.SetupDependencies()
audit_logs.SetupDependencies()

View File

@@ -15,7 +15,6 @@ import (
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
files_utils "databasus-backend/internal/util/files"
)
const (
@@ -171,13 +170,7 @@ func (s *BackupsScheduler) StartBackup(database *databases.Database, isCallNotif
timestamp := time.Now().UTC()
backup := &backups_core.Backup{
ID: backupID,
FileName: fmt.Sprintf(
"%s-%s-%s",
files_utils.SanitizeFilename(database.Name),
timestamp.Format("20060102-150405"),
backupID.String(),
),
ID: backupID,
DatabaseID: backupConfig.DatabaseID,
StorageID: *backupConfig.StorageID,
Status: backups_core.BackupStatusInProgress,
@@ -185,6 +178,8 @@ func (s *BackupsScheduler) StartBackup(database *databases.Database, isCallNotif
CreatedAt: timestamp,
}
backup.GenerateFilename(database.Name)
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error(
"Failed to save backup",

View File

@@ -1,9 +1,11 @@
package backups
package backups_controllers
import (
"context"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
backups_dto "databasus-backend/internal/features/backups/backups/dto"
backups_services "databasus-backend/internal/features/backups/backups/services"
"databasus-backend/internal/features/databases"
users_middleware "databasus-backend/internal/features/users/middleware"
files_utils "databasus-backend/internal/util/files"
@@ -17,7 +19,7 @@ import (
)
type BackupController struct {
backupService *BackupService
backupService *backups_services.BackupService
}
func (c *BackupController) RegisterRoutes(router *gin.RouterGroup) {
@@ -42,7 +44,7 @@ func (c *BackupController) RegisterPublicRoutes(router *gin.RouterGroup) {
// @Param database_id query string true "Database ID"
// @Param limit query int false "Number of items per page" default(10)
// @Param offset query int false "Offset for pagination" default(0)
// @Success 200 {object} GetBackupsResponse
// @Success 200 {object} backups_dto.GetBackupsResponse
// @Failure 400
// @Failure 401
// @Failure 500
@@ -54,7 +56,7 @@ func (c *BackupController) GetBackups(ctx *gin.Context) {
return
}
var request GetBackupsRequest
var request backups_dto.GetBackupsRequest
if err := ctx.ShouldBindQuery(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -81,7 +83,7 @@ func (c *BackupController) GetBackups(ctx *gin.Context) {
// @Tags backups
// @Accept json
// @Produce json
// @Param request body MakeBackupRequest true "Backup creation data"
// @Param request body backups_dto.MakeBackupRequest true "Backup creation data"
// @Success 200 {object} map[string]string
// @Failure 400
// @Failure 401
@@ -94,7 +96,7 @@ func (c *BackupController) MakeBackup(ctx *gin.Context) {
return
}
var request MakeBackupRequest
var request backups_dto.MakeBackupRequest
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -310,10 +312,6 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
c.backupService.WriteAuditLogForDownload(downloadToken.UserID, backup, database)
}
type MakeBackupRequest struct {
DatabaseID uuid.UUID `json:"database_id" binding:"required"`
}
func (c *BackupController) generateBackupFilename(
backup *backups_core.Backup,
database *databases.Database,

View File

@@ -1,4 +1,4 @@
package backups
package backups_controllers
import (
"context"
@@ -24,11 +24,14 @@ import (
backups_common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
backups_dto "databasus-backend/internal/features/backups/backups/dto"
backups_services "databasus-backend/internal/features/backups/backups/services"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/storages"
local_storage "databasus-backend/internal/features/storages/models/local"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
users_dto "databasus-backend/internal/features/users/dto"
users_enums "databasus-backend/internal/features/users/enums"
users_services "databasus-backend/internal/features/users/services"
@@ -119,7 +122,7 @@ func Test_GetBackups_PermissionsEnforced(t *testing.T) {
)
if tt.expectSuccess {
var response GetBackupsResponse
var response backups_dto.GetBackupsResponse
err := json.Unmarshal(testResp.Body, &response)
assert.NoError(t, err)
assert.GreaterOrEqual(t, len(response.Backups), 1)
@@ -214,7 +217,7 @@ func Test_CreateBackup_PermissionsEnforced(t *testing.T) {
testUserToken = nonMember.Token
}
request := MakeBackupRequest{DatabaseID: database.ID}
request := backups_dto.MakeBackupRequest{DatabaseID: database.ID}
testResp := test_utils.MakePostRequest(
t,
router,
@@ -245,7 +248,7 @@ func Test_CreateBackup_AuditLogWritten(t *testing.T) {
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
enableBackupForDatabase(database.ID)
request := MakeBackupRequest{DatabaseID: database.ID}
request := backups_dto.MakeBackupRequest{DatabaseID: database.ID}
test_utils.MakePostRequest(
t,
router,
@@ -373,7 +376,7 @@ func Test_DeleteBackup_PermissionsEnforced(t *testing.T) {
ownerUser, err := userService.GetUserFromToken(owner.Token)
assert.NoError(t, err)
response, err := GetBackupService().GetBackups(ownerUser, database.ID, 10, 0)
response, err := backups_services.GetBackupService().GetBackups(ownerUser, database.ID, 10, 0)
assert.NoError(t, err)
assert.Equal(t, 0, len(response.Backups))
}
@@ -999,7 +1002,7 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
assert.NoError(t, err)
// Register a cancellable context for the backup
GetBackupService().taskCancelManager.RegisterTask(backup.ID, func() {})
task_cancellation.GetTaskCancelManager().RegisterTask(backup.ID, func() {})
resp := test_utils.MakePostRequest(
t,
@@ -1091,7 +1094,7 @@ func Test_ConcurrentDownloadPrevention(t *testing.T) {
time.Sleep(50 * time.Millisecond)
service := GetBackupService()
service := backups_services.GetBackupService()
if !service.IsDownloadInProgress(owner.UserID) {
t.Log("Warning: First download completed before we could test concurrency")
<-downloadComplete
@@ -1192,7 +1195,7 @@ func Test_GenerateDownloadToken_BlockedWhenDownloadInProgress(t *testing.T) {
time.Sleep(50 * time.Millisecond)
service := GetBackupService()
service := backups_services.GetBackupService()
if !service.IsDownloadInProgress(owner.UserID) {
t.Log("Warning: First download completed before we could test token generation blocking")
<-downloadComplete
@@ -1268,7 +1271,7 @@ func Test_MakeBackup_VerifyBackupAndMetadataFilesExistInStorage(t *testing.T) {
initialBackups, err := backupRepo.FindByDatabaseID(database.ID)
assert.NoError(t, err)
request := MakeBackupRequest{DatabaseID: database.ID}
request := backups_dto.MakeBackupRequest{DatabaseID: database.ID}
test_utils.MakePostRequest(
t,
router,
@@ -1502,7 +1505,7 @@ func createTestBackup(
}
func createExpiredDownloadToken(backupID, userID uuid.UUID) string {
tokenService := GetBackupService().downloadTokenService
tokenService := backups_download.GetDownloadTokenService()
token, err := tokenService.Generate(backupID, userID)
if err != nil {
panic(fmt.Sprintf("Failed to generate download token: %v", err))
@@ -1843,7 +1846,7 @@ func Test_DeleteBackup_RemovesBackupAndMetadataFilesFromDisk(t *testing.T) {
initialBackups, err := backupRepo.FindByDatabaseID(database.ID)
assert.NoError(t, err)
request := MakeBackupRequest{DatabaseID: database.ID}
request := backups_dto.MakeBackupRequest{DatabaseID: database.ID}
test_utils.MakePostRequest(
t,
router,

View File

@@ -0,0 +1,23 @@
package backups_controllers
import (
backups_services "databasus-backend/internal/features/backups/backups/services"
"databasus-backend/internal/features/databases"
)
var backupController = &BackupController{
backups_services.GetBackupService(),
}
func GetBackupController() *BackupController {
return backupController
}
var postgresWalBackupController = &PostgreWalBackupController{
databases.GetDatabaseService(),
backups_services.GetWalService(),
}
func GetPostgresWalBackupController() *PostgreWalBackupController {
return postgresWalBackupController
}

View File

@@ -0,0 +1,291 @@
package backups_controllers
import (
"io"
"net/http"
"strconv"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_dto "databasus-backend/internal/features/backups/backups/dto"
backups_services "databasus-backend/internal/features/backups/backups/services"
"databasus-backend/internal/features/databases"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// PostgreWalBackupController handles WAL backup endpoints used by the databasus-cli agent.
// Authentication is via a plain agent token in the Authorization header (no Bearer prefix).
type PostgreWalBackupController struct {
databaseService *databases.DatabaseService
walService *backups_services.PostgreWalBackupService
}
func (c *PostgreWalBackupController) RegisterRoutes(router *gin.RouterGroup) {
walRoutes := router.Group("/backups/postgres/wal")
walRoutes.GET("/next-full-backup-time", c.GetNextFullBackupTime)
walRoutes.POST("/error", c.ReportError)
walRoutes.POST("/upload", c.Upload)
walRoutes.GET("/restore/plan", c.GetRestorePlan)
walRoutes.GET("/restore/download", c.DownloadBackupFile)
}
// GetNextFullBackupTime
// @Summary Get next full backup time
// @Description Returns the next scheduled full basebackup time for the authenticated database
// @Tags backups-wal
// @Produce json
// @Security AgentToken
// @Success 200 {object} backups_dto.GetNextFullBackupTimeResponse
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /backups/postgres/wal/next-full-backup-time [get]
func (c *PostgreWalBackupController) GetNextFullBackupTime(ctx *gin.Context) {
database, err := c.getDatabase(ctx)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
return
}
response, err := c.walService.GetNextFullBackupTime(database)
if err != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, response)
}
// ReportError
// @Summary Report agent error
// @Description Records a fatal error from the agent against the database record and marks it as errored
// @Tags backups-wal
// @Accept json
// @Security AgentToken
// @Param request body backups_dto.ReportErrorRequest true "Error details"
// @Success 200
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /backups/postgres/wal/error [post]
func (c *PostgreWalBackupController) ReportError(ctx *gin.Context) {
database, err := c.getDatabase(ctx)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
return
}
var request backups_dto.ReportErrorRequest
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := c.walService.ReportError(database, request.Error); err != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
ctx.Status(http.StatusOK)
}
// Upload
// @Summary Stream upload a basebackup or WAL segment
// @Description Accepts a zstd-compressed binary stream and stores it in the database's configured storage.
// The server generates the storage filename; agents do not control the destination path.
// For WAL segment uploads the server validates the WAL chain and returns 409 if a gap is detected
// or 400 if no full backup exists yet (agent should trigger a full basebackup in both cases).
// @Tags backups-wal
// @Accept application/octet-stream
// @Produce json
// @Security AgentToken
// @Param X-Upload-Type header string true "Upload type" Enums(basebackup, wal)
// @Param X-Wal-Segment-Name header string false "24-hex WAL segment identifier (required for wal uploads, e.g. 0000000100000001000000AB)"
// @Param X-Wal-Segment-Size header int false "WAL segment size in bytes reported by the PostgreSQL instance (default: 16777216)"
// @Param fullBackupWalStartSegment query string false "First WAL segment needed to make the basebackup consistent (required for basebackup uploads)"
// @Param fullBackupWalStopSegment query string false "Last WAL segment included in the basebackup (required for basebackup uploads)"
// @Success 204
// @Failure 400 {object} backups_dto.UploadGapResponse "No full backup exists (error: no_full_backup)"
// @Failure 401 {object} map[string]string
// @Failure 409 {object} backups_dto.UploadGapResponse "WAL chain gap detected (error: gap_detected)"
// @Failure 500 {object} map[string]string
// @Router /backups/postgres/wal/upload [post]
func (c *PostgreWalBackupController) Upload(ctx *gin.Context) {
database, err := c.getDatabase(ctx)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
return
}
uploadType := backups_core.PgWalUploadType(ctx.GetHeader("X-Upload-Type"))
if uploadType != backups_core.PgWalUploadTypeBasebackup &&
uploadType != backups_core.PgWalUploadTypeWal {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "X-Upload-Type must be 'basebackup' or 'wal'"},
)
return
}
walSegmentName := ""
if uploadType == backups_core.PgWalUploadTypeWal {
walSegmentName = ctx.GetHeader("X-Wal-Segment-Name")
if walSegmentName == "" {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "X-Wal-Segment-Name is required for wal uploads"},
)
return
}
}
if uploadType == backups_core.PgWalUploadTypeBasebackup {
if ctx.Query("fullBackupWalStartSegment") == "" ||
ctx.Query("fullBackupWalStopSegment") == "" {
ctx.JSON(
http.StatusBadRequest,
gin.H{
"error": "fullBackupWalStartSegment and fullBackupWalStopSegment are required for basebackup uploads",
},
)
return
}
}
walSegmentSizeBytes := int64(0)
if raw := ctx.GetHeader("X-Wal-Segment-Size"); raw != "" {
parsed, parseErr := strconv.ParseInt(raw, 10, 64)
if parseErr != nil || parsed <= 0 {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "X-Wal-Segment-Size must be a positive integer"},
)
return
}
walSegmentSizeBytes = parsed
}
gapResp, uploadErr := c.walService.UploadWal(
ctx.Request.Context(),
database,
uploadType,
walSegmentName,
ctx.Query("fullBackupWalStartSegment"),
ctx.Query("fullBackupWalStopSegment"),
walSegmentSizeBytes,
ctx.Request.Body,
)
if uploadErr != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": uploadErr.Error()})
return
}
if gapResp != nil {
if gapResp.Error == "no_full_backup" {
ctx.JSON(http.StatusBadRequest, gapResp)
return
}
ctx.JSON(http.StatusConflict, gapResp)
return
}
ctx.Status(http.StatusNoContent)
}
// GetRestorePlan
// @Summary Get restore plan
// @Description Resolves the full backup and all required WAL segments needed for recovery. Validates the WAL chain is continuous.
// @Tags backups-wal
// @Produce json
// @Security AgentToken
// @Param backupId query string false "UUID of a specific full backup to restore from; defaults to the most recent"
// @Success 200 {object} backups_dto.GetRestorePlanResponse
// @Failure 400 {object} map[string]string "Broken WAL chain or no backups available"
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /backups/postgres/wal/restore/plan [get]
func (c *PostgreWalBackupController) GetRestorePlan(ctx *gin.Context) {
database, err := c.getDatabase(ctx)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
return
}
var backupID *uuid.UUID
if raw := ctx.Query("backupId"); raw != "" {
parsed, parseErr := uuid.Parse(raw)
if parseErr != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backupId format"})
return
}
backupID = &parsed
}
response, planErr, err := c.walService.GetRestorePlan(database, backupID)
if err != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if planErr != nil {
ctx.JSON(http.StatusBadRequest, planErr)
return
}
ctx.JSON(http.StatusOK, response)
}
// DownloadBackupFile
// @Summary Download a backup or WAL segment file for restore
// @Description Retrieves the backup file by ID (validated against the authenticated database), decrypts it server-side if encrypted, and streams the zstd-compressed result to the agent
// @Tags backups-wal
// @Produce application/octet-stream
// @Security AgentToken
// @Param backupId query string true "Backup ID from the restore plan response"
// @Success 200 {file} file
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Router /backups/postgres/wal/restore/download [get]
func (c *PostgreWalBackupController) DownloadBackupFile(ctx *gin.Context) {
database, err := c.getDatabase(ctx)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
return
}
backupIDRaw := ctx.Query("backupId")
if backupIDRaw == "" {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "backupId is required"})
return
}
backupID, err := uuid.Parse(backupIDRaw)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backupId format"})
return
}
reader, err := c.walService.DownloadBackupFile(database, backupID)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
defer func() { _ = reader.Close() }()
ctx.Header("Content-Type", "application/octet-stream")
ctx.Status(http.StatusOK)
_, _ = io.Copy(ctx.Writer, reader)
}
func (c *PostgreWalBackupController) getDatabase(
ctx *gin.Context,
) (*databases.Database, error) {
token := ctx.GetHeader("Authorization")
return c.databaseService.GetDatabaseByAgentToken(token)
}

View File

@@ -1,4 +1,4 @@
package backups
package backups_controllers
import (
"testing"
@@ -41,7 +41,7 @@ func WaitForBackupCompletion(
deadline := time.Now().UTC().Add(timeout)
for time.Now().UTC().Before(deadline) {
backups, err := backupRepository.FindByDatabaseID(databaseID)
backups, err := backups_core.GetBackupRepository().FindByDatabaseID(databaseID)
if err != nil {
t.Logf("WaitForBackupCompletion: error finding backups: %v", err)
time.Sleep(50 * time.Millisecond)

View File

@@ -0,0 +1,7 @@
package backups_core
var backupRepository = &BackupRepository{}
func GetBackupRepository() *BackupRepository {
return backupRepository
}

View File

@@ -8,3 +8,10 @@ const (
BackupStatusFailed BackupStatus = "FAILED"
BackupStatusCanceled BackupStatus = "CANCELED"
)
type PgWalUploadType string
const (
PgWalUploadTypeBasebackup PgWalUploadType = "basebackup"
PgWalUploadTypeWal PgWalUploadType = "wal"
)

View File

@@ -1,12 +1,22 @@
package backups_core
import (
backups_config "databasus-backend/internal/features/backups/config"
"fmt"
"time"
backups_config "databasus-backend/internal/features/backups/config"
files_utils "databasus-backend/internal/util/files"
"github.com/google/uuid"
)
type PgWalBackupType string
const (
PgWalBackupTypeFullBackup PgWalBackupType = "PG_FULL_BACKUP"
PgWalBackupTypeWalSegment PgWalBackupType = "PG_WAL_SEGMENT"
)
type Backup struct {
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
FileName string `json:"fileName" gorm:"column:file_name;type:text;not null"`
@@ -26,5 +36,23 @@ type Backup struct {
EncryptionIV *string `json:"-" gorm:"column:encryption_iv"`
Encryption backups_config.BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
// Postgres WAL backup specific fields
PgWalBackupType *PgWalBackupType `json:"pgWalBackupType" gorm:"column:pg_wal_backup_type;type:text"`
PgFullBackupWalStartSegmentName *string `json:"pgFullBackupWalStartSegmentName" gorm:"column:pg_wal_start_segment;type:text"`
PgFullBackupWalStopSegmentName *string `json:"pgFullBackupWalStopSegmentName" gorm:"column:pg_wal_stop_segment;type:text"`
PgVersion *string `json:"pgVersion" gorm:"column:pg_version;type:text"`
PgWalSegmentName *string `json:"pgWalSegmentName" gorm:"column:pg_wal_segment_name;type:text"`
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
}
func (b *Backup) GenerateFilename(dbName string) {
timestamp := time.Now().UTC()
b.FileName = fmt.Sprintf(
"%s-%s-%s",
files_utils.SanitizeFilename(dbName),
timestamp.Format("20060102-150405"),
b.ID.String(),
)
}

View File

@@ -245,3 +245,134 @@ func (r *BackupRepository) FindOldestByDatabaseExcludingInProgress(
return backups, nil
}
func (r *BackupRepository) FindCompletedFullWalBackupByID(
databaseID uuid.UUID,
backupID uuid.UUID,
) (*Backup, error) {
var backup Backup
err := storage.
GetDb().
Where(
"database_id = ? AND id = ? AND pg_wal_backup_type = ? AND status = ?",
databaseID,
backupID,
PgWalBackupTypeFullBackup,
BackupStatusCompleted,
).
First(&backup).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &backup, nil
}
func (r *BackupRepository) FindCompletedWalSegmentsAfter(
databaseID uuid.UUID,
afterSegmentName string,
) ([]*Backup, error) {
var backups []*Backup
err := storage.
GetDb().
Where(
"database_id = ? AND pg_wal_backup_type = ? AND pg_wal_segment_name >= ? AND status = ?",
databaseID,
PgWalBackupTypeWalSegment,
afterSegmentName,
BackupStatusCompleted,
).
Order("pg_wal_segment_name ASC").
Find(&backups).Error
if err != nil {
return nil, err
}
return backups, nil
}
func (r *BackupRepository) FindLastCompletedFullWalBackupByDatabaseID(
databaseID uuid.UUID,
) (*Backup, error) {
var backup Backup
err := storage.
GetDb().
Where(
"database_id = ? AND pg_wal_backup_type = ? AND status = ?",
databaseID,
PgWalBackupTypeFullBackup,
BackupStatusCompleted,
).
Order("created_at DESC").
First(&backup).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &backup, nil
}
func (r *BackupRepository) FindWalSegmentByName(
databaseID uuid.UUID,
segmentName string,
) (*Backup, error) {
var backup Backup
err := storage.
GetDb().
Where(
"database_id = ? AND pg_wal_backup_type = ? AND pg_wal_segment_name = ?",
databaseID,
PgWalBackupTypeWalSegment,
segmentName,
).
First(&backup).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &backup, nil
}
func (r *BackupRepository) FindLastWalSegmentAfter(
databaseID uuid.UUID,
afterSegmentName string,
) (*Backup, error) {
var backup Backup
err := storage.
GetDb().
Where(
"database_id = ? AND pg_wal_backup_type = ? AND pg_wal_segment_name > ? AND status = ?",
databaseID,
PgWalBackupTypeWalSegment,
afterSegmentName,
BackupStatusCompleted,
).
Order("pg_wal_segment_name DESC").
First(&backup).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &backup, nil
}

View File

@@ -1,29 +0,0 @@
package backups
import (
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/encryption"
"io"
)
type GetBackupsRequest struct {
DatabaseID string `form:"database_id" binding:"required"`
Limit int `form:"limit"`
Offset int `form:"offset"`
}
type GetBackupsResponse struct {
Backups []*backups_core.Backup `json:"backups"`
Total int64 `json:"total"`
Limit int `json:"limit"`
Offset int `json:"offset"`
}
type DecryptionReaderCloser struct {
*encryption.DecryptionReader
BaseReader io.ReadCloser
}
func (r *DecryptionReaderCloser) Close() error {
return r.BaseReader.Close()
}

View File

@@ -0,0 +1,78 @@
package backups_dto
import (
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/encryption"
"io"
"time"
"github.com/google/uuid"
)
type GetBackupsRequest struct {
DatabaseID string `form:"database_id" binding:"required"`
Limit int `form:"limit"`
Offset int `form:"offset"`
}
type GetBackupsResponse struct {
Backups []*backups_core.Backup `json:"backups"`
Total int64 `json:"total"`
Limit int `json:"limit"`
Offset int `json:"offset"`
}
type DecryptionReaderCloser struct {
*encryption.DecryptionReader
BaseReader io.ReadCloser
}
func (r *DecryptionReaderCloser) Close() error {
return r.BaseReader.Close()
}
type MakeBackupRequest struct {
DatabaseID uuid.UUID `json:"database_id" binding:"required"`
}
type GetNextFullBackupTimeResponse struct {
NextFullBackupTime *time.Time `json:"nextFullBackupTime"`
}
type ReportErrorRequest struct {
Error string `json:"error" binding:"required"`
}
type UploadGapResponse struct {
Error string `json:"error"`
ExpectedSegmentName string `json:"expectedSegmentName"`
ReceivedSegmentName string `json:"receivedSegmentName"`
}
type RestorePlanFullBackup struct {
BackupID uuid.UUID `json:"id"`
FullBackupWalStartSegment string `json:"fullBackupWalStartSegment"`
FullBackupWalStopSegment string `json:"fullBackupWalStopSegment"`
PgVersion string `json:"pgVersion"`
CreatedAt time.Time `json:"createdAt"`
SizeBytes int64 `json:"sizeBytes"`
}
type RestorePlanWalSegment struct {
BackupID uuid.UUID `json:"backupId"`
SegmentName string `json:"segmentName"`
SizeBytes int64 `json:"sizeBytes"`
}
type GetRestorePlanErrorResponse struct {
Error string `json:"error"`
Message string `json:"message"`
LastContiguousSegment string `json:"lastContiguousSegment,omitempty"`
}
type GetRestorePlanResponse struct {
FullBackup RestorePlanFullBackup `json:"fullBackup"`
WalSegments []RestorePlanWalSegment `json:"walSegments"`
TotalSizeBytes int64 `json:"totalSizeBytes"`
LatestAvailableSegment string `json:"latestAvailableSegment"`
}

View File

@@ -0,0 +1,45 @@
package encryption
import (
"encoding/base64"
"fmt"
"io"
"github.com/google/uuid"
)
// EncryptionSetup holds the result of setting up encryption for a backup stream.
type EncryptionSetup struct {
Writer *EncryptionWriter
SaltBase64 string
NonceBase64 string
}
// SetupEncryptionWriter generates salt/nonce, creates an EncryptionWriter, and
// returns the base64-encoded salt and nonce for storage on the backup record.
func SetupEncryptionWriter(
baseWriter io.Writer,
masterKey string,
backupID uuid.UUID,
) (*EncryptionSetup, error) {
salt, err := GenerateSalt()
if err != nil {
return nil, fmt.Errorf("failed to generate salt: %w", err)
}
nonce, err := GenerateNonce()
if err != nil {
return nil, fmt.Errorf("failed to generate nonce: %w", err)
}
encWriter, err := NewEncryptionWriter(baseWriter, masterKey, backupID, salt, nonce)
if err != nil {
return nil, fmt.Errorf("failed to create encryption writer: %w", err)
}
return &EncryptionSetup{
Writer: encWriter,
SaltBase64: base64.StdEncoding.EncodeToString(salt),
NonceBase64: base64.StdEncoding.EncodeToString(nonce),
}, nil
}

View File

@@ -1,9 +1,6 @@
package backups
package backups_services
import (
"sync"
"sync/atomic"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/backuping"
backups_core "databasus-backend/internal/features/backups/backups/core"
@@ -18,16 +15,16 @@ import (
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
"sync"
"sync/atomic"
)
var backupRepository = &backups_core.BackupRepository{}
var taskCancelManager = task_cancellation.GetTaskCancelManager()
var backupService = &BackupService{
databases.GetDatabaseService(),
storages.GetStorageService(),
backupRepository,
backups_core.GetBackupRepository(),
notifiers.GetNotifierService(),
notifiers.GetNotifierService(),
backups_config.GetBackupConfigService(),
@@ -44,16 +41,21 @@ var backupService = &BackupService{
backuping.GetBackupCleaner(),
}
var backupController = &BackupController{
backupService: backupService,
}
func GetBackupService() *BackupService {
return backupService
}
func GetBackupController() *BackupController {
return backupController
var walService = &PostgreWalBackupService{
backups_config.GetBackupConfigService(),
backups_core.GetBackupRepository(),
encryption.GetFieldEncryptor(),
encryption_secrets.GetSecretKeyService(),
logger.GetLogger(),
backupService,
}
func GetWalService() *PostgreWalBackupService {
return walService
}
var (

View File

@@ -0,0 +1,613 @@
package backups_services
import (
"context"
"fmt"
"io"
"log/slog"
"time"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_dto "databasus-backend/internal/features/backups/backups/dto"
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/postgresql"
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
util_encryption "databasus-backend/internal/util/encryption"
util_wal "databasus-backend/internal/util/wal"
"github.com/google/uuid"
)
// PostgreWalBackupService handles WAL segment and basebackup uploads from the databasus-cli agent.
type PostgreWalBackupService struct {
backupConfigService *backups_config.BackupConfigService
backupRepository *backups_core.BackupRepository
fieldEncryptor util_encryption.FieldEncryptor
secretKeyService *encryption_secrets.SecretKeyService
logger *slog.Logger
backupService *BackupService
}
// UploadWal accepts a streaming WAL segment or basebackup upload from the agent.
// For WAL segments it validates the WAL chain before accepting. Returns an UploadGapResponse
// (409) when the chain is broken so the agent knows to trigger a full basebackup.
func (s *PostgreWalBackupService) UploadWal(
ctx context.Context,
database *databases.Database,
uploadType backups_core.PgWalUploadType,
walSegmentName string,
fullBackupWalStartSegment string,
fullBackupWalStopSegment string,
walSegmentSizeBytes int64,
body io.Reader,
) (*backups_dto.UploadGapResponse, error) {
if err := s.validateWalBackupType(database); err != nil {
return nil, err
}
if uploadType == backups_core.PgWalUploadTypeBasebackup {
if fullBackupWalStartSegment == "" || fullBackupWalStopSegment == "" {
return nil, fmt.Errorf(
"fullBackupWalStartSegment and fullBackupWalStopSegment are required for basebackup uploads",
)
}
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
if err != nil {
return nil, fmt.Errorf("failed to get backup config: %w", err)
}
if backupConfig.Storage == nil {
return nil, fmt.Errorf("no storage configured for database %s", database.ID)
}
if uploadType == backups_core.PgWalUploadTypeWal {
// Idempotency: check before chain validation so a successful re-upload is
// not misidentified as a gap.
existing, err := s.backupRepository.FindWalSegmentByName(database.ID, walSegmentName)
if err != nil {
return nil, fmt.Errorf("failed to check for duplicate WAL segment: %w", err)
}
if existing != nil {
return nil, nil
}
gapResp, err := s.validateWalChain(database.ID, walSegmentName, walSegmentSizeBytes)
if err != nil {
return nil, err
}
if gapResp != nil {
return gapResp, nil
}
}
backup := s.createBackupRecord(
database.ID,
backupConfig.Storage.ID,
uploadType,
database.Name,
walSegmentName,
fullBackupWalStartSegment,
fullBackupWalStopSegment,
backupConfig.Encryption,
)
if err := s.backupRepository.Save(backup); err != nil {
return nil, fmt.Errorf("failed to create backup record: %w", err)
}
sizeBytes, streamErr := s.streamToStorage(ctx, backup, backupConfig, body)
if streamErr != nil {
errMsg := streamErr.Error()
s.markFailed(backup, errMsg)
return nil, fmt.Errorf("upload failed: %w", streamErr)
}
s.markCompleted(backup, sizeBytes)
return nil, nil
}
func (s *PostgreWalBackupService) GetRestorePlan(
database *databases.Database,
backupID *uuid.UUID,
) (*backups_dto.GetRestorePlanResponse, *backups_dto.GetRestorePlanErrorResponse, error) {
if err := s.validateWalBackupType(database); err != nil {
return nil, nil, err
}
fullBackup, err := s.resolveFullBackup(database.ID, backupID)
if err != nil {
return nil, nil, err
}
if fullBackup == nil {
msg := "no full backups available for this database"
if backupID != nil {
msg = fmt.Sprintf("full backup %s not found or not completed", backupID)
}
return nil, &backups_dto.GetRestorePlanErrorResponse{
Error: "no_backups",
Message: msg,
}, nil
}
startSegment := ""
if fullBackup.PgFullBackupWalStartSegmentName != nil {
startSegment = *fullBackup.PgFullBackupWalStartSegmentName
}
walSegments, err := s.backupRepository.FindCompletedWalSegmentsAfter(database.ID, startSegment)
if err != nil {
return nil, nil, fmt.Errorf("failed to query WAL segments: %w", err)
}
chainErr := s.validateRestoreWalChain(fullBackup, walSegments)
if chainErr != nil {
return nil, chainErr, nil
}
fullBackupSizeBytes := int64(fullBackup.BackupSizeMb * 1024 * 1024)
pgVersion := ""
if fullBackup.PgVersion != nil {
pgVersion = *fullBackup.PgVersion
}
stopSegment := ""
if fullBackup.PgFullBackupWalStopSegmentName != nil {
stopSegment = *fullBackup.PgFullBackupWalStopSegmentName
}
response := &backups_dto.GetRestorePlanResponse{
FullBackup: backups_dto.RestorePlanFullBackup{
BackupID: fullBackup.ID,
FullBackupWalStartSegment: startSegment,
FullBackupWalStopSegment: stopSegment,
PgVersion: pgVersion,
CreatedAt: fullBackup.CreatedAt,
SizeBytes: fullBackupSizeBytes,
},
TotalSizeBytes: fullBackupSizeBytes,
}
for _, seg := range walSegments {
segName := ""
if seg.PgWalSegmentName != nil {
segName = *seg.PgWalSegmentName
}
segSizeBytes := int64(seg.BackupSizeMb * 1024 * 1024)
response.WalSegments = append(response.WalSegments, backups_dto.RestorePlanWalSegment{
BackupID: seg.ID,
SegmentName: segName,
SizeBytes: segSizeBytes,
})
response.TotalSizeBytes += segSizeBytes
response.LatestAvailableSegment = segName
}
return response, nil, nil
}
// DownloadBackupFile returns a reader for a backup file belonging to the given database.
// Decryption is handled transparently if the backup is encrypted.
func (s *PostgreWalBackupService) DownloadBackupFile(
database *databases.Database,
backupID uuid.UUID,
) (io.ReadCloser, error) {
if err := s.validateWalBackupType(database); err != nil {
return nil, err
}
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
return nil, fmt.Errorf("backup not found: %w", err)
}
if backup.DatabaseID != database.ID {
return nil, fmt.Errorf("backup does not belong to this database")
}
if backup.Status != backups_core.BackupStatusCompleted {
return nil, fmt.Errorf("backup is not completed")
}
return s.backupService.GetBackupReader(backupID)
}
func (s *PostgreWalBackupService) validateWalChain(
databaseID uuid.UUID,
incomingSegment string,
walSegmentSizeBytes int64,
) (*backups_dto.UploadGapResponse, error) {
fullBackup, err := s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(databaseID)
if err != nil {
return nil, fmt.Errorf("failed to query full backup: %w", err)
}
// No full backup exists yet: cannot accept WAL segments without a chain anchor.
if fullBackup == nil || fullBackup.PgFullBackupWalStopSegmentName == nil {
return &backups_dto.UploadGapResponse{
Error: "no_full_backup",
ExpectedSegmentName: "",
ReceivedSegmentName: incomingSegment,
}, nil
}
stopSegment := *fullBackup.PgFullBackupWalStopSegmentName
lastWal, err := s.backupRepository.FindLastWalSegmentAfter(databaseID, stopSegment)
if err != nil {
return nil, fmt.Errorf("failed to query last WAL segment: %w", err)
}
walCalculator := util_wal.NewWalCalculator(walSegmentSizeBytes)
var chainTail string
if lastWal != nil && lastWal.PgWalSegmentName != nil {
chainTail = *lastWal.PgWalSegmentName
} else {
chainTail = stopSegment
}
expectedNext, err := walCalculator.NextSegment(chainTail)
if err != nil {
return nil, fmt.Errorf("WAL arithmetic failed for %q: %w", chainTail, err)
}
if incomingSegment != expectedNext {
return &backups_dto.UploadGapResponse{
Error: "gap_detected",
ExpectedSegmentName: expectedNext,
ReceivedSegmentName: incomingSegment,
}, nil
}
return nil, nil
}
func (s *PostgreWalBackupService) createBackupRecord(
databaseID uuid.UUID,
storageID uuid.UUID,
uploadType backups_core.PgWalUploadType,
dbName string,
walSegmentName string,
fullBackupWalStartSegment string,
fullBackupWalStopSegment string,
encryption backups_config.BackupEncryption,
) *backups_core.Backup {
now := time.Now().UTC()
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: databaseID,
StorageID: storageID,
Status: backups_core.BackupStatusInProgress,
Encryption: encryption,
CreatedAt: now,
}
backup.GenerateFilename(dbName)
if uploadType == backups_core.PgWalUploadTypeBasebackup {
walBackupType := backups_core.PgWalBackupTypeFullBackup
backup.PgWalBackupType = &walBackupType
if fullBackupWalStartSegment != "" {
backup.PgFullBackupWalStartSegmentName = &fullBackupWalStartSegment
}
if fullBackupWalStopSegment != "" {
backup.PgFullBackupWalStopSegmentName = &fullBackupWalStopSegment
}
} else {
walBackupType := backups_core.PgWalBackupTypeWalSegment
backup.PgWalBackupType = &walBackupType
backup.PgWalSegmentName = &walSegmentName
}
return backup
}
func (s *PostgreWalBackupService) streamToStorage(
ctx context.Context,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
body io.Reader,
) (int64, error) {
if backupConfig.Encryption == backups_config.BackupEncryptionEncrypted {
return s.streamEncrypted(ctx, backup, backupConfig, body, backup.FileName)
}
return s.streamDirect(ctx, backupConfig, body, backup.FileName)
}
func (s *PostgreWalBackupService) streamDirect(
ctx context.Context,
backupConfig *backups_config.BackupConfig,
body io.Reader,
fileName string,
) (int64, error) {
cr := &countingReader{r: body}
if err := backupConfig.Storage.SaveFile(ctx, s.fieldEncryptor, s.logger, fileName, cr); err != nil {
return 0, err
}
return cr.n, nil
}
func (s *PostgreWalBackupService) streamEncrypted(
ctx context.Context,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
body io.Reader,
fileName string,
) (int64, error) {
masterKey, err := s.secretKeyService.GetSecretKey()
if err != nil {
return 0, fmt.Errorf("failed to get master encryption key: %w", err)
}
pipeReader, pipeWriter := io.Pipe()
encryptionSetup, err := backup_encryption.SetupEncryptionWriter(
pipeWriter,
masterKey,
backup.ID,
)
if err != nil {
_ = pipeWriter.Close()
return 0, err
}
copyErrCh := make(chan error, 1)
go func() {
_, copyErr := io.Copy(encryptionSetup.Writer, body)
if copyErr != nil {
_ = encryptionSetup.Writer.Close()
_ = pipeWriter.CloseWithError(copyErr)
copyErrCh <- copyErr
return
}
if closeErr := encryptionSetup.Writer.Close(); closeErr != nil {
_ = pipeWriter.CloseWithError(closeErr)
copyErrCh <- closeErr
return
}
copyErrCh <- pipeWriter.Close()
}()
cr := &countingReader{r: pipeReader}
saveErr := backupConfig.Storage.SaveFile(ctx, s.fieldEncryptor, s.logger, fileName, cr)
copyErr := <-copyErrCh
if copyErr != nil {
return 0, copyErr
}
if saveErr != nil {
return 0, saveErr
}
backup.EncryptionSalt = &encryptionSetup.SaltBase64
backup.EncryptionIV = &encryptionSetup.NonceBase64
return cr.n, nil
}
func (s *PostgreWalBackupService) markCompleted(backup *backups_core.Backup, sizeBytes int64) {
backup.Status = backups_core.BackupStatusCompleted
backup.BackupSizeMb = float64(sizeBytes) / (1024 * 1024)
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error(
"failed to mark WAL backup as completed",
"backupId",
backup.ID,
"error",
err,
)
}
}
func (s *PostgreWalBackupService) markFailed(backup *backups_core.Backup, errMsg string) {
backup.Status = backups_core.BackupStatusFailed
backup.FailMessage = &errMsg
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("failed to mark WAL backup as failed", "backupId", backup.ID, "error", err)
}
}
func (s *PostgreWalBackupService) GetNextFullBackupTime(
database *databases.Database,
) (*backups_dto.GetNextFullBackupTimeResponse, error) {
if err := s.validateWalBackupType(database); err != nil {
return nil, err
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
if err != nil {
return nil, fmt.Errorf("failed to get backup config: %w", err)
}
if backupConfig.BackupInterval == nil {
return nil, fmt.Errorf("no backup interval configured for database %s", database.ID)
}
lastFullBackup, err := s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(
database.ID,
)
if err != nil {
return nil, fmt.Errorf("failed to query last full backup: %w", err)
}
var lastBackupTime *time.Time
if lastFullBackup != nil {
lastBackupTime = &lastFullBackup.CreatedAt
}
now := time.Now().UTC()
nextTime := backupConfig.BackupInterval.NextTriggerTime(now, lastBackupTime)
return &backups_dto.GetNextFullBackupTimeResponse{
NextFullBackupTime: nextTime,
}, nil
}
// ReportError creates a FAILED backup record with the agent's error message.
func (s *PostgreWalBackupService) ReportError(
database *databases.Database,
errorMsg string,
) error {
if err := s.validateWalBackupType(database); err != nil {
return err
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
if err != nil {
return fmt.Errorf("failed to get backup config: %w", err)
}
if backupConfig.Storage == nil {
return fmt.Errorf("no storage configured for database %s", database.ID)
}
now := time.Now().UTC()
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: backupConfig.Storage.ID,
Status: backups_core.BackupStatusFailed,
FailMessage: &errorMsg,
Encryption: backupConfig.Encryption,
CreatedAt: now,
}
backup.GenerateFilename(database.Name)
if err := s.backupRepository.Save(backup); err != nil {
return fmt.Errorf("failed to save error backup record: %w", err)
}
return nil
}
func (s *PostgreWalBackupService) resolveFullBackup(
databaseID uuid.UUID,
backupID *uuid.UUID,
) (*backups_core.Backup, error) {
if backupID != nil {
return s.backupRepository.FindCompletedFullWalBackupByID(databaseID, *backupID)
}
return s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(databaseID)
}
func (s *PostgreWalBackupService) validateRestoreWalChain(
fullBackup *backups_core.Backup,
walSegments []*backups_core.Backup,
) *backups_dto.GetRestorePlanErrorResponse {
if len(walSegments) == 0 {
return nil
}
stopSegment := ""
if fullBackup.PgFullBackupWalStopSegmentName != nil {
stopSegment = *fullBackup.PgFullBackupWalStopSegmentName
}
walCalculator := util_wal.NewWalCalculator(0)
expectedNext, err := walCalculator.NextSegment(stopSegment)
if err != nil {
return nil
}
for _, seg := range walSegments {
segName := ""
if seg.PgWalSegmentName != nil {
segName = *seg.PgWalSegmentName
}
cmp, cmpErr := walCalculator.Compare(segName, stopSegment)
if cmpErr != nil {
return nil
}
// Skip segments that are <= stopSegment (they are part of the basebackup range)
if cmp <= 0 {
continue
}
if segName != expectedNext {
lastContiguous := stopSegment
// Walk back to find the segment before the gap
for _, prev := range walSegments {
prevName := ""
if prev.PgWalSegmentName != nil {
prevName = *prev.PgWalSegmentName
}
prevCmp, _ := walCalculator.Compare(prevName, stopSegment)
if prevCmp <= 0 {
continue
}
if prevName == segName {
break
}
lastContiguous = prevName
}
return &backups_dto.GetRestorePlanErrorResponse{
Error: "wal_chain_broken",
Message: fmt.Sprintf(
"WAL chain has a gap after segment %s. Recovery is only possible up to this segment.",
lastContiguous,
),
LastContiguousSegment: lastContiguous,
}
}
expectedNext, err = walCalculator.NextSegment(segName)
if err != nil {
return nil
}
}
return nil
}
func (s *PostgreWalBackupService) validateWalBackupType(database *databases.Database) error {
if database.Postgresql == nil ||
database.Postgresql.BackupType != postgresql.PostgresBackupTypeWalV1 {
return fmt.Errorf("database %s is not configured for WAL backups", database.ID)
}
return nil
}
type countingReader struct {
r io.Reader
n int64
}
func (cr *countingReader) Read(p []byte) (n int, err error) {
n, err = cr.r.Read(p)
cr.n += int64(n)
return
}

View File

@@ -1,4 +1,4 @@
package backups
package backups_services
import (
"encoding/base64"
@@ -11,6 +11,7 @@ import (
"databasus-backend/internal/features/backups/backups/backuping"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
backups_dto "databasus-backend/internal/features/backups/backups/dto"
"databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -108,7 +109,7 @@ func (s *BackupService) GetBackups(
user *users_models.User,
databaseID uuid.UUID,
limit, offset int,
) (*GetBackupsResponse, error) {
) (*backups_dto.GetBackupsResponse, error) {
database, err := s.databaseService.GetDatabaseByID(databaseID)
if err != nil {
return nil, err
@@ -143,7 +144,7 @@ func (s *BackupService) GetBackups(
return nil, err
}
return &GetBackupsResponse{
return &backups_dto.GetBackupsResponse{
Backups: backups,
Total: total,
Limit: limit,
@@ -274,7 +275,7 @@ func (s *BackupService) GetBackupFile(
database.WorkspaceID,
)
reader, err := s.getBackupReader(backupID)
reader, err := s.GetBackupReader(backupID)
if err != nil {
return nil, nil, nil, err
}
@@ -282,39 +283,9 @@ func (s *BackupService) GetBackupFile(
return reader, backup, database, nil
}
func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
dbBackupsInProgress, err := s.backupRepository.FindByDatabaseIdAndStatus(
databaseID,
backups_core.BackupStatusInProgress,
)
if err != nil {
return err
}
if len(dbBackupsInProgress) > 0 {
return errors.New("backup is in progress, storage cannot be removed")
}
dbBackups, err := s.backupRepository.FindByDatabaseID(
databaseID,
)
if err != nil {
return err
}
for _, dbBackup := range dbBackups {
err := s.backupCleaner.DeleteBackup(dbBackup)
if err != nil {
return err
}
}
return nil
}
// GetBackupReader returns a reader for the backup file
// If encrypted, wraps with DecryptionReader
func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, error) {
// GetBackupReader returns a reader for the backup file.
// If encrypted, wraps with DecryptionReader.
func (s *BackupService) GetBackupReader(backupID uuid.UUID) (io.ReadCloser, error) {
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
return nil, fmt.Errorf("failed to find backup: %w", err)
@@ -394,7 +365,7 @@ func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, erro
s.logger.Info("Returning encrypted backup with decryption", "backupId", backupID)
return &DecryptionReaderCloser{
return &backups_dto.DecryptionReaderCloser{
DecryptionReader: decryptionReader,
BaseReader: fileReader,
}, nil
@@ -465,7 +436,7 @@ func (s *BackupService) GetBackupFileWithoutAuth(
return nil, nil, nil, err
}
reader, err := s.getBackupReader(backupID)
reader, err := s.GetBackupReader(backupID)
if err != nil {
return nil, nil, nil, err
}
@@ -501,6 +472,36 @@ func (s *BackupService) UnregisterDownload(userID uuid.UUID) {
s.downloadTokenService.UnregisterDownload(userID)
}
func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
dbBackupsInProgress, err := s.backupRepository.FindByDatabaseIdAndStatus(
databaseID,
backups_core.BackupStatusInProgress,
)
if err != nil {
return err
}
if len(dbBackupsInProgress) > 0 {
return errors.New("backup is in progress, storage cannot be removed")
}
dbBackups, err := s.backupRepository.FindByDatabaseID(
databaseID,
)
if err != nil {
return err
}
for _, dbBackup := range dbBackups {
err := s.backupCleaner.DeleteBackup(dbBackup)
if err != nil {
return err
}
}
return nil
}
func (s *BackupService) generateBackupFilename(
backup *backups_core.Backup,
database *databases.Database,

View File

@@ -2,7 +2,6 @@ package usecases_mariadb
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
@@ -437,40 +436,22 @@ func (uc *CreateMariadbBackupUsecase) setupBackupEncryption(
return storageWriter, nil, metadata, nil
}
salt, err := backup_encryption.GenerateSalt()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to generate salt: %w", err)
}
nonce, err := backup_encryption.GenerateNonce()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to generate nonce: %w", err)
}
masterKey, err := uc.secretKeyService.GetSecretKey()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to get master key: %w", err)
}
encWriter, err := backup_encryption.NewEncryptionWriter(
storageWriter,
masterKey,
backupID,
salt,
nonce,
)
encSetup, err := backup_encryption.SetupEncryptionWriter(storageWriter, masterKey, backupID)
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to create encrypting writer: %w", err)
return nil, nil, metadata, err
}
saltBase64 := base64.StdEncoding.EncodeToString(salt)
nonceBase64 := base64.StdEncoding.EncodeToString(nonce)
metadata.EncryptionSalt = &saltBase64
metadata.EncryptionIV = &nonceBase64
metadata.EncryptionSalt = &encSetup.SaltBase64
metadata.EncryptionIV = &encSetup.NonceBase64
metadata.Encryption = backups_config.BackupEncryptionEncrypted
uc.logger.Info("Encryption enabled for backup", "backupId", backupID)
return encWriter, encWriter, metadata, nil
return encSetup.Writer, encSetup.Writer, metadata, nil
}
func (uc *CreateMariadbBackupUsecase) cleanupOnCancellation(

View File

@@ -2,7 +2,6 @@ package usecases_mongodb
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
@@ -277,41 +276,21 @@ func (uc *CreateMongodbBackupUsecase) setupBackupEncryption(
return storageWriter, nil, backupMetadata, nil
}
salt, err := backup_encryption.GenerateSalt()
if err != nil {
return nil, nil, backupMetadata, fmt.Errorf("failed to generate salt: %w", err)
}
nonce, err := backup_encryption.GenerateNonce()
if err != nil {
return nil, nil, backupMetadata, fmt.Errorf("failed to generate nonce: %w", err)
}
masterKey, err := uc.secretKeyService.GetSecretKey()
if err != nil {
return nil, nil, backupMetadata, fmt.Errorf("failed to get master key: %w", err)
}
encryptionWriter, err := backup_encryption.NewEncryptionWriter(
storageWriter,
masterKey,
backupID,
salt,
nonce,
)
encSetup, err := backup_encryption.SetupEncryptionWriter(storageWriter, masterKey, backupID)
if err != nil {
return nil, nil, backupMetadata, fmt.Errorf("failed to create encryption writer: %w", err)
return nil, nil, backupMetadata, err
}
saltBase64 := base64.StdEncoding.EncodeToString(salt)
nonceBase64 := base64.StdEncoding.EncodeToString(nonce)
backupMetadata.BackupID = backupID
backupMetadata.Encryption = backups_config.BackupEncryptionEncrypted
backupMetadata.EncryptionSalt = &saltBase64
backupMetadata.EncryptionIV = &nonceBase64
backupMetadata.EncryptionSalt = &encSetup.SaltBase64
backupMetadata.EncryptionIV = &encSetup.NonceBase64
return encryptionWriter, encryptionWriter, backupMetadata, nil
return encSetup.Writer, encSetup.Writer, backupMetadata, nil
}
func (uc *CreateMongodbBackupUsecase) copyWithShutdownCheck(

View File

@@ -2,7 +2,6 @@ package usecases_mysql
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
@@ -448,40 +447,22 @@ func (uc *CreateMysqlBackupUsecase) setupBackupEncryption(
return storageWriter, nil, metadata, nil
}
salt, err := backup_encryption.GenerateSalt()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to generate salt: %w", err)
}
nonce, err := backup_encryption.GenerateNonce()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to generate nonce: %w", err)
}
masterKey, err := uc.secretKeyService.GetSecretKey()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to get master key: %w", err)
}
encWriter, err := backup_encryption.NewEncryptionWriter(
storageWriter,
masterKey,
backupID,
salt,
nonce,
)
encSetup, err := backup_encryption.SetupEncryptionWriter(storageWriter, masterKey, backupID)
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to create encrypting writer: %w", err)
return nil, nil, metadata, err
}
saltBase64 := base64.StdEncoding.EncodeToString(salt)
nonceBase64 := base64.StdEncoding.EncodeToString(nonce)
metadata.EncryptionSalt = &saltBase64
metadata.EncryptionIV = &nonceBase64
metadata.EncryptionSalt = &encSetup.SaltBase64
metadata.EncryptionIV = &encSetup.NonceBase64
metadata.Encryption = backups_config.BackupEncryptionEncrypted
uc.logger.Info("Encryption enabled for backup", "backupId", backupID)
return encWriter, encWriter, metadata, nil
return encSetup.Writer, encSetup.Writer, metadata, nil
}
func (uc *CreateMysqlBackupUsecase) cleanupOnCancellation(

View File

@@ -2,7 +2,6 @@ package usecases_postgresql
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
@@ -492,40 +491,22 @@ func (uc *CreatePostgresqlBackupUsecase) setupBackupEncryption(
return storageWriter, nil, metadata, nil
}
salt, err := backup_encryption.GenerateSalt()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to generate salt: %w", err)
}
nonce, err := backup_encryption.GenerateNonce()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to generate nonce: %w", err)
}
masterKey, err := uc.secretKeyService.GetSecretKey()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to get master key: %w", err)
}
encWriter, err := backup_encryption.NewEncryptionWriter(
storageWriter,
masterKey,
backupID,
salt,
nonce,
)
encSetup, err := backup_encryption.SetupEncryptionWriter(storageWriter, masterKey, backupID)
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to create encrypting writer: %w", err)
return nil, nil, metadata, err
}
saltBase64 := base64.StdEncoding.EncodeToString(salt)
nonceBase64 := base64.StdEncoding.EncodeToString(nonce)
metadata.EncryptionSalt = &saltBase64
metadata.EncryptionIV = &nonceBase64
metadata.EncryptionSalt = &encSetup.SaltBase64
metadata.EncryptionIV = &encSetup.NonceBase64
metadata.Encryption = backups_config.BackupEncryptionEncrypted
uc.logger.Info("Encryption enabled for backup", "backupId", backupID)
return encWriter, encWriter, metadata, nil
return encSetup.Writer, encSetup.Writer, metadata, nil
}
func (uc *CreatePostgresqlBackupUsecase) cleanupOnCancellation(

View File

@@ -29,6 +29,11 @@ func (c *DatabaseController) RegisterRoutes(router *gin.RouterGroup) {
router.GET("/databases/notifier/:id/databases-count", c.CountDatabasesByNotifier)
router.POST("/databases/is-readonly", c.IsUserReadOnly)
router.POST("/databases/create-readonly-user", c.CreateReadOnlyUser)
router.POST("/databases/:id/regenerate-token", c.RegenerateAgentToken)
}
func (c *DatabaseController) RegisterPublicRoutes(router *gin.RouterGroup) {
router.POST("/databases/verify-token", c.VerifyAgentToken)
}
// CreateDatabase
@@ -438,3 +443,61 @@ func (c *DatabaseController) CreateReadOnlyUser(ctx *gin.Context) {
Password: password,
})
}
// RegenerateAgentToken
// @Summary Regenerate agent token for a database
// @Description Generate a new agent token for the database. The token is returned once and stored as a hash.
// @Tags databases
// @Produce json
// @Security BearerAuth
// @Param id path string true "Database ID"
// @Success 200 {object} map[string]string
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Router /databases/{id}/regenerate-token [post]
func (c *DatabaseController) RegenerateAgentToken(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
id, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
return
}
token, err := c.databaseService.RegenerateAgentToken(user, id)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, gin.H{"token": token})
}
// VerifyAgentToken
// @Summary Verify agent token
// @Description Verify that a given agent token is valid for any database
// @Tags databases
// @Accept json
// @Produce json
// @Param request body VerifyAgentTokenRequest true "Token to verify"
// @Success 200 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Router /databases/verify-token [post]
func (c *DatabaseController) VerifyAgentToken(ctx *gin.Context) {
var request VerifyAgentTokenRequest
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := c.databaseService.VerifyAgentToken(request.Token); err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
ctx.JSON(http.StatusOK, gin.H{"message": "token is valid"})
}

View File

@@ -13,10 +13,13 @@ import (
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
"databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/databases/databases/mariadb"
"databasus-backend/internal/features/databases/databases/mongodb"
"databasus-backend/internal/features/databases/databases/postgresql"
users_enums "databasus-backend/internal/features/users/enums"
users_middleware "databasus-backend/internal/features/users/middleware"
users_services "databasus-backend/internal/features/users/services"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
@@ -144,6 +147,66 @@ func Test_CreateDatabase_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testin
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
func Test_CreateDatabase_WalV1Type_NoConnectionFieldsRequired(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
request := Database{
Name: "Test WAL Database",
WorkspaceID: &workspace.ID,
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
BackupType: postgresql.PostgresBackupTypeWalV1,
CpuCount: 1,
},
}
var response Database
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/databases/create",
"Bearer "+owner.Token,
request,
http.StatusCreated,
&response,
)
defer RemoveTestDatabase(&response)
assert.Equal(t, "Test WAL Database", response.Name)
assert.NotEqual(t, uuid.Nil, response.ID)
}
func Test_CreateDatabase_PgDumpType_ConnectionFieldsRequired(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
request := Database{
Name: "Test PG_DUMP Database",
WorkspaceID: &workspace.ID,
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
BackupType: postgresql.PostgresBackupTypePgDump,
CpuCount: 1,
},
}
testResp := test_utils.MakePostRequest(
t,
router,
"/api/v1/databases/create",
"Bearer "+owner.Token,
request,
http.StatusBadRequest,
)
assert.Contains(t, string(testResp.Body), "host is required")
}
func Test_UpdateDatabase_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
@@ -256,6 +319,52 @@ func Test_UpdateDatabase_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testin
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
func Test_UpdateDatabase_WhenDatabaseTypeChanged_ReturnsBadRequest(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(database)
database.Type = DatabaseTypeMysql
testResp := test_utils.MakePostRequest(
t,
router,
"/api/v1/databases/update",
"Bearer "+owner.Token,
database,
http.StatusBadRequest,
)
assert.Contains(t, string(testResp.Body), "database type cannot be changed")
}
func Test_UpdateDatabase_WhenBackupTypeChanged_ReturnsBadRequest(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(database)
database.Postgresql.BackupType = postgresql.PostgresBackupTypeWalV1
testResp := test_utils.MakePostRequest(
t,
router,
"/api/v1/databases/update",
"Bearer "+owner.Token,
database,
http.StatusBadRequest,
)
assert.Contains(t, string(testResp.Body), "backup type cannot be changed")
}
func Test_DeleteDatabase_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
@@ -1050,6 +1159,87 @@ func Test_TestConnection_PermissionsEnforced(t *testing.T) {
}
}
func Test_RegenerateAgentToken_ReturnsToken(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(database)
var response map[string]string
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/databases/"+database.ID.String()+"/regenerate-token",
"Bearer "+owner.Token,
nil,
http.StatusOK,
&response,
)
assert.NotEmpty(t, response["token"])
assert.Len(t, response["token"], 32)
var updatedDatabase Database
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/databases/"+database.ID.String(),
"Bearer "+owner.Token,
http.StatusOK,
&updatedDatabase,
)
assert.True(t, updatedDatabase.IsAgentTokenGenerated)
}
func Test_VerifyAgentToken_WithValidToken_Succeeds(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(database)
var regenerateResponse map[string]string
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/databases/"+database.ID.String()+"/regenerate-token",
"Bearer "+owner.Token,
nil,
http.StatusOK,
&regenerateResponse,
)
token := regenerateResponse["token"]
assert.NotEmpty(t, token)
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/verify-token",
"",
VerifyAgentTokenRequest{Token: token},
)
assert.Equal(t, http.StatusOK, w.Code)
}
func Test_VerifyAgentToken_WithInvalidToken_Returns401(t *testing.T) {
router := createTestRouter()
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/verify-token",
"",
VerifyAgentTokenRequest{Token: "invalidtoken00000000000000000000"},
)
assert.Equal(t, http.StatusUnauthorized, w.Code)
}
func createTestDatabaseViaAPI(
name string,
workspaceID uuid.UUID,
@@ -1101,11 +1291,20 @@ func createTestDatabaseViaAPI(
}
func createTestRouter() *gin.Engine {
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
GetDatabaseController(),
)
gin.SetMode(gin.TestMode)
router := gin.New()
v1 := router.Group("/api/v1")
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
workspaces_controllers.GetWorkspaceController().RegisterRoutes(protected.(*gin.RouterGroup))
workspaces_controllers.GetMembershipController().RegisterRoutes(protected.(*gin.RouterGroup))
GetDatabaseController().RegisterRoutes(protected.(*gin.RouterGroup))
GetDatabaseController().RegisterPublicRoutes(v1)
audit_logs.SetupDependencies()
return router
}
@@ -1118,13 +1317,14 @@ func getTestPostgresConfig() *postgresql.PostgresqlDatabase {
testDbName := "testdb"
return &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
BackupType: postgresql.PostgresBackupTypePgDump,
Version: tools.PostgresqlVersion16,
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
}
}

View File

@@ -2,6 +2,7 @@ package postgresql
import (
"context"
"databasus-backend/internal/config"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/tools"
"errors"
@@ -17,6 +18,13 @@ import (
"gorm.io/gorm"
)
type PostgresBackupType string
const (
PostgresBackupTypePgDump PostgresBackupType = "PG_DUMP"
PostgresBackupTypeWalV1 PostgresBackupType = "WAL_V1"
)
type PostgresqlDatabase struct {
ID uuid.UUID `json:"id" gorm:"primaryKey;type:uuid;default:gen_random_uuid()"`
@@ -24,11 +32,13 @@ type PostgresqlDatabase struct {
Version tools.PostgresqlVersion `json:"version" gorm:"type:text;not null"`
// connection data
Host string `json:"host" gorm:"type:text;not null"`
Port int `json:"port" gorm:"type:int;not null"`
Username string `json:"username" gorm:"type:text;not null"`
Password string `json:"password" gorm:"type:text;not null"`
BackupType PostgresBackupType `json:"backupType" gorm:"column:backup_type;type:text;not null;default:'PG_DUMP'"`
// connection data — required for PG_DUMP, optional for WAL_V1
Host string `json:"host" gorm:"type:text"`
Port int `json:"port" gorm:"type:int"`
Username string `json:"username" gorm:"type:text"`
Password string `json:"password" gorm:"type:text"`
Database *string `json:"database" gorm:"type:text"`
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
@@ -66,20 +76,30 @@ func (p *PostgresqlDatabase) AfterFind(_ *gorm.DB) error {
}
func (p *PostgresqlDatabase) Validate() error {
if p.Host == "" {
return errors.New("host is required")
if p.BackupType == "" {
p.BackupType = PostgresBackupTypePgDump
}
if p.Port == 0 {
return errors.New("port is required")
if p.BackupType == PostgresBackupTypePgDump && config.GetEnv().IsCloud {
return errors.New("PG_DUMP backup type is not supported in cloud mode")
}
if p.Username == "" {
return errors.New("username is required")
}
if p.BackupType == PostgresBackupTypePgDump {
if p.Host == "" {
return errors.New("host is required")
}
if p.Password == "" {
return errors.New("password is required")
if p.Port == 0 {
return errors.New("port is required")
}
if p.Username == "" {
return errors.New("username is required")
}
if p.Password == "" {
return errors.New("password is required")
}
}
if p.CpuCount <= 0 {
@@ -90,7 +110,7 @@ func (p *PostgresqlDatabase) Validate() error {
// Databasus runs an internal PostgreSQL instance that should not be backed up through the UI
// because it would expose internal metadata to non-system administrators.
// To properly backup Databasus, see: https://databasus.com/faq#backup-databasus
if p.Database != nil && *p.Database != "" {
if p.BackupType == PostgresBackupTypePgDump && p.Database != nil && *p.Database != "" {
localhostHosts := []string{
"localhost",
"127.0.0.1",
@@ -130,6 +150,10 @@ func (p *PostgresqlDatabase) TestConnection(
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error {
if p.BackupType == PostgresBackupTypeWalV1 {
return errors.New("test connection is not supported for WAL backup type")
}
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
@@ -144,7 +168,21 @@ func (p *PostgresqlDatabase) HideSensitiveData() {
p.Password = ""
}
func (p *PostgresqlDatabase) ValidateUpdate(old *PostgresqlDatabase) error {
// BackupType cannot be changed after creation — the full backup structure
// (WAL hierarchy, storage files, cleanup logic) is built around
// the type chosen at creation time. Automatically migrating this state is
// error-prone; it is safer for the user to create a new database and
// remove the old one.
if old.BackupType != p.BackupType {
return errors.New("backup type cannot be changed; create a new database instead")
}
return nil
}
func (p *PostgresqlDatabase) Update(incoming *PostgresqlDatabase) {
p.BackupType = incoming.BackupType
p.Version = incoming.Version
p.Host = incoming.Host
p.Port = incoming.Port
@@ -181,6 +219,10 @@ func (p *PostgresqlDatabase) PopulateDbData(
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error {
if p.BackupType == PostgresBackupTypeWalV1 {
return nil
}
return p.PopulateVersion(logger, encryptor, databaseID)
}
@@ -243,6 +285,10 @@ func (p *PostgresqlDatabase) IsUserReadOnly(
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) (bool, []string, error) {
if p.BackupType == PostgresBackupTypeWalV1 {
return false, nil, errors.New("read-only check is not supported for WAL backup type")
}
password, err := decryptPasswordIfNeeded(p.Password, encryptor, databaseID)
if err != nil {
return false, nil, fmt.Errorf("failed to decrypt password: %w", err)
@@ -415,6 +461,10 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) (string, string, error) {
if p.BackupType == PostgresBackupTypeWalV1 {
return "", "", errors.New("read-only user creation is not supported for WAL backup type")
}
password, err := decryptPasswordIfNeeded(p.Password, encryptor, databaseID)
if err != nil {
return "", "", fmt.Errorf("failed to decrypt password: %w", err)

View File

@@ -9,3 +9,7 @@ type IsReadOnlyResponse struct {
IsReadOnly bool `json:"isReadOnly"`
Privileges []string `json:"privileges"`
}
type VerifyAgentTokenRequest struct {
Token string `json:"token" binding:"required"`
}

View File

@@ -37,6 +37,9 @@ type Database struct {
LastBackupErrorMessage *string `json:"lastBackupErrorMessage,omitempty" gorm:"column:last_backup_error_message;type:text"`
HealthStatus *HealthStatus `json:"healthStatus" gorm:"column:health_status;type:text;not null"`
AgentToken *string `json:"-" gorm:"column:agent_token;type:text"`
IsAgentTokenGenerated bool `json:"isAgentTokenGenerated" gorm:"column:is_agent_token_generated;not null;default:false"`
}
func (d *Database) Validate() error {
@@ -71,8 +74,19 @@ func (d *Database) Validate() error {
}
func (d *Database) ValidateUpdate(old, new Database) error {
// Database type cannot be changed after creation — the entire backup
// structure (storage files, schedulers, WAL hierarchy, etc.) is tied to
// the type at creation time. Recreating that state automatically is
// error-prone; it is safer for the user to create a new database and
// remove the old one.
if old.Type != new.Type {
return errors.New("database type is not allowed to change")
return errors.New("database type cannot be changed; create a new database instead")
}
if old.Type == DatabaseTypePostgres && old.Postgresql != nil && new.Postgresql != nil {
if err := new.Postgresql.ValidateUpdate(old.Postgresql); err != nil {
return err
}
}
return nil

View File

@@ -244,6 +244,18 @@ func (r *DatabaseRepository) GetAllDatabases() ([]*Database, error) {
return databases, nil
}
func (r *DatabaseRepository) FindByAgentTokenHash(hash string) (*Database, error) {
var database Database
if err := storage.GetDb().
Where("agent_token = ?", hash).
First(&database).Error; err != nil {
return nil, err
}
return &database, nil
}
func (r *DatabaseRepository) GetDatabasesIDsByNotifierID(
notifierID uuid.UUID,
) ([]uuid.UUID, error) {

View File

@@ -2,9 +2,11 @@ package databases
import (
"context"
"crypto/sha256"
"errors"
"fmt"
"log/slog"
"strings"
"time"
"databasus-backend/internal/config"
@@ -87,21 +89,8 @@ func (s *DatabaseService) CreateDatabase(
return nil, fmt.Errorf("failed to auto-detect database data: %w", err)
}
if config.GetEnv().IsCloud {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
isReadOnly, permissions, err := database.IsUserReadOnly(ctx, s.logger, s.fieldEncryptor)
if err != nil {
return nil, fmt.Errorf("failed to verify user permissions: %w", err)
}
if !isReadOnly {
return nil, fmt.Errorf(
"in cloud mode, only read-only database users are allowed (user has permissions: %v)",
permissions,
)
}
if err := s.verifyReadOnlyUserIfNeeded(database); err != nil {
return nil, err
}
if err := database.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
@@ -171,25 +160,8 @@ func (s *DatabaseService) UpdateDatabase(
return fmt.Errorf("failed to auto-detect database data: %w", err)
}
if config.GetEnv().IsCloud {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
isReadOnly, permissions, err := existingDatabase.IsUserReadOnly(
ctx,
s.logger,
s.fieldEncryptor,
)
if err != nil {
return fmt.Errorf("failed to verify user permissions: %w", err)
}
if !isReadOnly {
return fmt.Errorf(
"in cloud mode, only read-only database users are allowed (user has permissions: %v)",
permissions,
)
}
if err := s.verifyReadOnlyUserIfNeeded(existingDatabase); err != nil {
return err
}
oldName := existingDatabase.Name
@@ -485,6 +457,7 @@ func (s *DatabaseService) CopyDatabase(
newDatabase.Postgresql = &postgresql.PostgresqlDatabase{
ID: uuid.Nil,
DatabaseID: nil,
BackupType: existingDatabase.Postgresql.BackupType,
Version: existingDatabase.Postgresql.Version,
Host: existingDatabase.Postgresql.Host,
Port: existingDatabase.Postgresql.Port,
@@ -638,6 +611,71 @@ func (s *DatabaseService) SetHealthStatus(
return nil
}
func (s *DatabaseService) RegenerateAgentToken(
user *users_models.User,
databaseID uuid.UUID,
) (string, error) {
database, err := s.dbRepository.FindByID(databaseID)
if err != nil {
return "", err
}
if database.WorkspaceID == nil {
return "", errors.New("cannot regenerate token for database without workspace")
}
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
if err != nil {
return "", err
}
if !canManage {
return "", errors.New(
"insufficient permissions to regenerate agent token for this database",
)
}
plainToken := strings.ReplaceAll(uuid.New().String(), "-", "")
tokenHash := hashAgentToken(plainToken)
database.AgentToken = &tokenHash
database.IsAgentTokenGenerated = true
_, err = s.dbRepository.Save(database)
if err != nil {
return "", err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Agent token regenerated for database: %s", database.Name),
&user.ID,
database.WorkspaceID,
)
return plainToken, nil
}
func (s *DatabaseService) VerifyAgentToken(token string) error {
hash := hashAgentToken(token)
_, err := s.dbRepository.FindByAgentTokenHash(hash)
if err != nil {
return errors.New("invalid token")
}
return nil
}
func (s *DatabaseService) GetDatabaseByAgentToken(token string) (*Database, error) {
hash := hashAgentToken(token)
partial, err := s.dbRepository.FindByAgentTokenHash(hash)
if err != nil {
return nil, errors.New("invalid agent token")
}
return s.dbRepository.FindByID(partial.ID)
}
func (s *DatabaseService) OnBeforeWorkspaceDeletion(workspaceID uuid.UUID) error {
databases, err := s.dbRepository.FindByWorkspaceID(workspaceID)
if err != nil {
@@ -809,3 +847,36 @@ func (s *DatabaseService) CreateReadOnlyUser(
return username, password, nil
}
func (s *DatabaseService) verifyReadOnlyUserIfNeeded(database *Database) error {
if !config.GetEnv().IsCloud {
return nil
}
if database.Postgresql != nil &&
database.Postgresql.BackupType == postgresql.PostgresBackupTypeWalV1 {
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
isReadOnly, permissions, err := database.IsUserReadOnly(ctx, s.logger, s.fieldEncryptor)
if err != nil {
return fmt.Errorf("failed to verify user permissions: %w", err)
}
if !isReadOnly {
return fmt.Errorf(
"in cloud mode, only read-only database users are allowed (user has permissions: %v)",
permissions,
)
}
return nil
}
func hashAgentToken(token string) string {
hash := sha256.Sum256([]byte(token))
return fmt.Sprintf("%x", hash)
}

View File

@@ -79,6 +79,38 @@ func (i *Interval) ShouldTriggerBackup(now time.Time, lastBackupTime *time.Time)
}
}
// NextTriggerTime computes the next time a backup should trigger based on the interval and last backup time.
// Returns nil when a backup is due immediately (no previous backup exists).
func (i *Interval) NextTriggerTime(now time.Time, lastBackupTime *time.Time) *time.Time {
if lastBackupTime == nil {
return nil
}
switch i.Interval {
case IntervalHourly:
next := lastBackupTime.Add(time.Hour)
return &next
case IntervalDaily:
next := i.nextDailyTrigger(now)
return &next
case IntervalWeekly:
next := i.nextWeeklyTrigger(now)
return &next
case IntervalMonthly:
next := i.nextMonthlyTrigger(now)
return &next
case IntervalCron:
return i.nextCronTrigger(*lastBackupTime)
default:
return nil
}
}
func (i *Interval) Copy() *Interval {
return &Interval{
ID: uuid.Nil,
@@ -240,6 +272,99 @@ func (i *Interval) shouldTriggerCron(now, lastBackup time.Time) bool {
return now.After(nextAfterLastBackup) || now.Equal(nextAfterLastBackup)
}
func (i *Interval) nextDailyTrigger(now time.Time) time.Time {
t, err := time.Parse("15:04", *i.TimeOfDay)
if err != nil {
return now
}
todaySlot := time.Date(
now.Year(), now.Month(), now.Day(),
t.Hour(), t.Minute(), 0, 0, now.Location(),
)
if now.Before(todaySlot) {
return todaySlot
}
return todaySlot.AddDate(0, 0, 1)
}
func (i *Interval) nextWeeklyTrigger(now time.Time) time.Time {
targetWd := time.Weekday(0)
if i.Weekday != nil {
targetWd = time.Weekday(*i.Weekday)
}
startOfWeek := getStartOfWeek(now)
var daysFromMonday int
if targetWd == time.Sunday {
daysFromMonday = 6
} else {
daysFromMonday = int(targetWd) - 1
}
targetThisWeek := startOfWeek.AddDate(0, 0, daysFromMonday)
if i.TimeOfDay != nil {
t, err := time.Parse("15:04", *i.TimeOfDay)
if err == nil {
targetThisWeek = time.Date(
targetThisWeek.Year(), targetThisWeek.Month(), targetThisWeek.Day(),
t.Hour(), t.Minute(), 0, 0, targetThisWeek.Location(),
)
}
}
if now.Before(targetThisWeek) {
return targetThisWeek
}
return targetThisWeek.AddDate(0, 0, 7)
}
func (i *Interval) nextMonthlyTrigger(now time.Time) time.Time {
day := 1
if i.DayOfMonth != nil {
day = *i.DayOfMonth
}
targetThisMonth := time.Date(now.Year(), now.Month(), day, 0, 0, 0, 0, now.Location())
if i.TimeOfDay != nil {
t, err := time.Parse("15:04", *i.TimeOfDay)
if err == nil {
targetThisMonth = time.Date(
targetThisMonth.Year(), targetThisMonth.Month(), targetThisMonth.Day(),
t.Hour(), t.Minute(), 0, 0, targetThisMonth.Location(),
)
}
}
if now.Before(targetThisMonth) {
return targetThisMonth
}
return targetThisMonth.AddDate(0, 1, 0)
}
func (i *Interval) nextCronTrigger(lastBackup time.Time) *time.Time {
if i.CronExpression == nil || *i.CronExpression == "" {
return nil
}
parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
schedule, err := parser.Parse(*i.CronExpression)
if err != nil {
return nil
}
next := schedule.Next(lastBackup)
return &next
}
func (i *Interval) validateCronExpression(expr string) error {
parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
_, err := parser.Parse(expr)

View File

@@ -721,3 +721,265 @@ func TestInterval_Validate(t *testing.T) {
assert.NoError(t, err)
})
}
func TestInterval_NextTriggerTime_NilLastBackup(t *testing.T) {
now := time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC)
t.Run("Hourly with nil lastBackup returns nil", func(t *testing.T) {
interval := &Interval{ID: uuid.New(), Interval: IntervalHourly}
result := interval.NextTriggerTime(now, nil)
assert.Nil(t, result)
})
t.Run("Daily with nil lastBackup returns nil", func(t *testing.T) {
timeOfDay := "09:00"
interval := &Interval{ID: uuid.New(), Interval: IntervalDaily, TimeOfDay: &timeOfDay}
result := interval.NextTriggerTime(now, nil)
assert.Nil(t, result)
})
t.Run("Weekly with nil lastBackup returns nil", func(t *testing.T) {
timeOfDay := "15:00"
weekday := 3
interval := &Interval{
ID: uuid.New(),
Interval: IntervalWeekly,
TimeOfDay: &timeOfDay,
Weekday: &weekday,
}
result := interval.NextTriggerTime(now, nil)
assert.Nil(t, result)
})
t.Run("Monthly with nil lastBackup returns nil", func(t *testing.T) {
timeOfDay := "08:00"
dayOfMonth := 10
interval := &Interval{
ID: uuid.New(),
Interval: IntervalMonthly,
TimeOfDay: &timeOfDay,
DayOfMonth: &dayOfMonth,
}
result := interval.NextTriggerTime(now, nil)
assert.Nil(t, result)
})
t.Run("Cron with nil lastBackup returns nil", func(t *testing.T) {
cronExpr := "0 2 * * *"
interval := &Interval{ID: uuid.New(), Interval: IntervalCron, CronExpression: &cronExpr}
result := interval.NextTriggerTime(now, nil)
assert.Nil(t, result)
})
}
func TestInterval_NextTriggerTime_Hourly(t *testing.T) {
interval := &Interval{ID: uuid.New(), Interval: IntervalHourly}
now := time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC)
t.Run("Returns lastBackup + 1 hour", func(t *testing.T) {
lastBackup := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.NotNil(t, result)
assert.Equal(t, time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC), *result)
})
t.Run("Returns future time when last backup was recent", func(t *testing.T) {
lastBackup := time.Date(2024, 1, 15, 11, 30, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.NotNil(t, result)
assert.Equal(t, time.Date(2024, 1, 15, 12, 30, 0, 0, time.UTC), *result)
})
}
func TestInterval_NextTriggerTime_Daily(t *testing.T) {
timeOfDay := "09:00"
interval := &Interval{ID: uuid.New(), Interval: IntervalDaily, TimeOfDay: &timeOfDay}
t.Run("Before today's slot: returns today's slot", func(t *testing.T) {
now := time.Date(2024, 1, 15, 8, 0, 0, 0, time.UTC)
lastBackup := time.Date(2024, 1, 14, 9, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.NotNil(t, result)
assert.Equal(t, time.Date(2024, 1, 15, 9, 0, 0, 0, time.UTC), *result)
})
t.Run("After today's slot: returns tomorrow's slot", func(t *testing.T) {
now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC)
lastBackup := time.Date(2024, 1, 15, 9, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.NotNil(t, result)
assert.Equal(t, time.Date(2024, 1, 16, 9, 0, 0, 0, time.UTC), *result)
})
t.Run("Exactly at today's slot: returns tomorrow's slot", func(t *testing.T) {
now := time.Date(2024, 1, 15, 9, 0, 0, 0, time.UTC)
lastBackup := time.Date(2024, 1, 14, 9, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.NotNil(t, result)
assert.Equal(t, time.Date(2024, 1, 16, 9, 0, 0, 0, time.UTC), *result)
})
}
func TestInterval_NextTriggerTime_Weekly(t *testing.T) {
timeOfDay := "15:00"
weekday := 3 // Wednesday
interval := &Interval{
ID: uuid.New(),
Interval: IntervalWeekly,
TimeOfDay: &timeOfDay,
Weekday: &weekday,
}
t.Run("Before this week's target: returns this week's target", func(t *testing.T) {
// Tuesday Jan 16, 2024
now := time.Date(2024, 1, 16, 10, 0, 0, 0, time.UTC)
lastBackup := time.Date(2024, 1, 10, 15, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.NotNil(t, result)
// Wednesday Jan 17 at 15:00
assert.Equal(t, time.Date(2024, 1, 17, 15, 0, 0, 0, time.UTC), *result)
})
t.Run("After this week's target: returns next week's target", func(t *testing.T) {
// Thursday Jan 18, 2024
now := time.Date(2024, 1, 18, 10, 0, 0, 0, time.UTC)
lastBackup := time.Date(2024, 1, 17, 15, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.NotNil(t, result)
// Next Wednesday Jan 24 at 15:00
assert.Equal(t, time.Date(2024, 1, 24, 15, 0, 0, 0, time.UTC), *result)
})
t.Run("Friday interval: returns correct target", func(t *testing.T) {
fridayTimeOfDay := "00:00"
fridayWeekday := 5 // Friday
fridayInterval := &Interval{
ID: uuid.New(),
Interval: IntervalWeekly,
TimeOfDay: &fridayTimeOfDay,
Weekday: &fridayWeekday,
}
// Wednesday Jan 17, 2024
now := time.Date(2024, 1, 17, 10, 0, 0, 0, time.UTC)
lastBackup := time.Date(2024, 1, 12, 0, 0, 0, 0, time.UTC)
result := fridayInterval.NextTriggerTime(now, &lastBackup)
assert.NotNil(t, result)
// Friday Jan 19 at 00:00
assert.Equal(t, time.Date(2024, 1, 19, 0, 0, 0, 0, time.UTC), *result)
})
}
func TestInterval_NextTriggerTime_Monthly(t *testing.T) {
timeOfDay := "08:00"
dayOfMonth := 10
interval := &Interval{
ID: uuid.New(),
Interval: IntervalMonthly,
TimeOfDay: &timeOfDay,
DayOfMonth: &dayOfMonth,
}
t.Run("Before this month's target: returns this month's target", func(t *testing.T) {
now := time.Date(2024, 1, 5, 10, 0, 0, 0, time.UTC)
lastBackup := time.Date(2023, 12, 10, 8, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.NotNil(t, result)
assert.Equal(t, time.Date(2024, 1, 10, 8, 0, 0, 0, time.UTC), *result)
})
t.Run("After this month's target: returns next month's target", func(t *testing.T) {
now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC)
lastBackup := time.Date(2024, 1, 10, 8, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.NotNil(t, result)
assert.Equal(t, time.Date(2024, 2, 10, 8, 0, 0, 0, time.UTC), *result)
})
t.Run("Exactly at this month's target: returns next month's target", func(t *testing.T) {
now := time.Date(2024, 1, 10, 8, 0, 0, 0, time.UTC)
lastBackup := time.Date(2023, 12, 10, 8, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.NotNil(t, result)
assert.Equal(t, time.Date(2024, 2, 10, 8, 0, 0, 0, time.UTC), *result)
})
}
func TestInterval_NextTriggerTime_Cron(t *testing.T) {
t.Run("Daily cron: returns next trigger after lastBackup", func(t *testing.T) {
cronExpr := "0 2 * * *" // Daily at 2:00 AM
interval := &Interval{ID: uuid.New(), Interval: IntervalCron, CronExpression: &cronExpr}
now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC)
lastBackup := time.Date(2024, 1, 14, 2, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.NotNil(t, result)
assert.Equal(t, time.Date(2024, 1, 15, 2, 0, 0, 0, time.UTC), *result)
})
t.Run("Complex cron: 1st and 15th at 4:30", func(t *testing.T) {
cronExpr := "30 4 1,15 * *"
interval := &Interval{ID: uuid.New(), Interval: IntervalCron, CronExpression: &cronExpr}
now := time.Date(2024, 1, 10, 10, 0, 0, 0, time.UTC)
lastBackup := time.Date(2024, 1, 1, 4, 30, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.NotNil(t, result)
assert.Equal(t, time.Date(2024, 1, 15, 4, 30, 0, 0, time.UTC), *result)
})
t.Run("Invalid cron expression returns nil", func(t *testing.T) {
invalidCron := "invalid cron"
interval := &Interval{ID: uuid.New(), Interval: IntervalCron, CronExpression: &invalidCron}
now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC)
lastBackup := time.Date(2024, 1, 14, 10, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.Nil(t, result)
})
t.Run("Empty cron expression returns nil", func(t *testing.T) {
emptyCron := ""
interval := &Interval{ID: uuid.New(), Interval: IntervalCron, CronExpression: &emptyCron}
now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC)
lastBackup := time.Date(2024, 1, 14, 10, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.Nil(t, result)
})
t.Run("Nil cron expression returns nil", func(t *testing.T) {
interval := &Interval{ID: uuid.New(), Interval: IntervalCron, CronExpression: nil}
now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC)
lastBackup := time.Date(2024, 1, 14, 10, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.Nil(t, result)
})
}
func TestInterval_NextTriggerTime_UnknownInterval(t *testing.T) {
interval := &Interval{ID: uuid.New(), Interval: IntervalType("UNKNOWN")}
now := time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC)
lastBackup := time.Date(2024, 1, 14, 12, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.Nil(t, result)
}

View File

@@ -18,7 +18,7 @@ import (
env_config "databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
backups_controllers "databasus-backend/internal/features/backups/backups/controllers"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -440,7 +440,7 @@ func Test_CancelRestore_InProgressRestore_SuccessfullyCancelled(t *testing.T) {
}()
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backup := backups.CreateTestBackup(database.ID, storage.ID)
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
mockUsecase := &restoring.MockBlockingRestoreUsecase{
StartedChan: make(chan bool, 1),

View File

@@ -5,8 +5,8 @@ import (
"sync/atomic"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/features/backups/backups/backuping"
backups_services "databasus-backend/internal/features/backups/backups/services"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/disk"
@@ -21,7 +21,7 @@ import (
var restoreRepository = &restores_core.RestoreRepository{}
var restoreService = &RestoreService{
backups.GetBackupService(),
backups_services.GetBackupService(),
restoreRepository,
storages.GetStorageService(),
backups_config.GetBackupConfigService(),
@@ -51,7 +51,7 @@ func SetupDependencies() {
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
backups.GetBackupService().AddBackupRemoveListener(restoreService)
backups_services.GetBackupService().AddBackupRemoveListener(restoreService)
backuping.GetBackupCleaner().AddBackupRemoveListener(restoreService)
isSetup.Store(true)

View File

@@ -7,7 +7,7 @@ import (
"github.com/google/uuid"
"databasus-backend/internal/features/backups/backups"
backups_services "databasus-backend/internal/features/backups/backups/services"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
restores_core "databasus-backend/internal/features/restores/core"
@@ -39,37 +39,37 @@ var restoreDatabaseCache = cache_utils.NewCacheUtil[RestoreDatabaseCache](
var restoreCancelManager = tasks_cancellation.GetTaskCancelManager()
var restorerNode = &RestorerNode{
nodeID: uuid.New(),
databaseService: databases.GetDatabaseService(),
backupService: backups.GetBackupService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
restoreRepository: restoreRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
restoreNodesRegistry: restoreNodesRegistry,
logger: logger.GetLogger(),
restoreBackupUsecase: usecases.GetRestoreBackupUsecase(),
cacheUtil: restoreDatabaseCache,
restoreCancelManager: restoreCancelManager,
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
uuid.New(),
databases.GetDatabaseService(),
backups_services.GetBackupService(),
encryption.GetFieldEncryptor(),
restoreRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
restoreNodesRegistry,
logger.GetLogger(),
usecases.GetRestoreBackupUsecase(),
restoreDatabaseCache,
restoreCancelManager,
time.Time{},
sync.Once{},
atomic.Bool{},
}
var restoresScheduler = &RestoresScheduler{
restoreRepository: restoreRepository,
backupService: backups.GetBackupService(),
storageService: storages.GetStorageService(),
backupConfigService: backups_config.GetBackupConfigService(),
restoreNodesRegistry: restoreNodesRegistry,
lastCheckTime: time.Now().UTC(),
logger: logger.GetLogger(),
restoreToNodeRelations: make(map[uuid.UUID]RestoreToNodeRelation),
restorerNode: restorerNode,
cacheUtil: restoreDatabaseCache,
completionSubscriptionID: uuid.Nil,
runOnce: sync.Once{},
hasRun: atomic.Bool{},
restoreRepository,
backups_services.GetBackupService(),
storages.GetStorageService(),
backups_config.GetBackupConfigService(),
restoreNodesRegistry,
time.Now().UTC(),
logger.GetLogger(),
make(map[uuid.UUID]RestoreToNodeRelation),
restorerNode,
restoreDatabaseCache,
uuid.Nil,
sync.Once{},
atomic.Bool{},
}
func GetRestoresScheduler() *RestoresScheduler {

View File

@@ -13,7 +13,7 @@ import (
"github.com/google/uuid"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_services "databasus-backend/internal/features/backups/backups/services"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
restores_core "databasus-backend/internal/features/restores/core"
@@ -32,7 +32,7 @@ type RestorerNode struct {
nodeID uuid.UUID
databaseService *databases.DatabaseService
backupService *backups.BackupService
backupService *backups_services.BackupService
fieldEncryptor util_encryption.FieldEncryptor
restoreRepository *restores_core.RestoreRepository
backupConfigService *backups_config.BackupConfigService

View File

@@ -7,7 +7,7 @@ import (
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_controllers "databasus-backend/internal/features/backups/backups/controllers"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -58,7 +58,7 @@ func Test_MakeRestore_WhenCacheMissed_RestoreFails(t *testing.T) {
cache_utils.ClearAllCache()
}()
backup := backups.CreateTestBackup(database.ID, storage.ID)
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
// Create restore but DON'T cache DB credentials
// Also don't set embedded DB fields to avoid schema issues
@@ -126,7 +126,7 @@ func Test_MakeRestore_WhenTaskStarts_CacheDeletedImmediately(t *testing.T) {
cache_utils.ClearAllCache()
}()
backup := backups.CreateTestBackup(database.ID, storage.ID)
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
// Create restore with cached DB credentials
// Don't set embedded DB fields in the restore model itself

View File

@@ -11,7 +11,7 @@ import (
"github.com/google/uuid"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_services "databasus-backend/internal/features/backups/backups/services"
backups_config "databasus-backend/internal/features/backups/config"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
@@ -26,7 +26,7 @@ const (
type RestoresScheduler struct {
restoreRepository *restores_core.RestoreRepository
backupService *backups.BackupService
backupService *backups_services.BackupService
storageService *storages.StorageService
backupConfigService *backups_config.BackupConfigService
restoreNodesRegistry *RestoreNodesRegistry

View File

@@ -5,7 +5,7 @@ import (
"time"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_controllers "databasus-backend/internal/features/backups/backups/controllers"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -68,7 +68,7 @@ func Test_CheckDeadNodesAndFailRestores_NodeDies_FailsRestoreAndCleansUpRegistry
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
// Create a test backup
backup := backups.CreateTestBackup(database.ID, storage.ID)
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
var err error
// Register mock node without subscribing to restores (simulates node crash after registration)
@@ -171,7 +171,7 @@ func Test_OnRestoreCompleted_TaskIsNotRestore_SkipsProcessing(t *testing.T) {
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
// Create a test backup
backup := backups.CreateTestBackup(database.ID, storage.ID)
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
// Register mock node
mockNodeID = uuid.New()
@@ -357,7 +357,7 @@ func Test_FailRestoresInProgress_SchedulerStarts_UpdatesStatus(t *testing.T) {
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
// Create a test backup
backup := backups.CreateTestBackup(database.ID, storage.ID)
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
// Create two in-progress restores that should be failed on scheduler restart
restore1 := &restores_core.Restore{
@@ -465,7 +465,7 @@ func Test_StartRestore_RestoreCompletes_DecrementsActiveTaskCount(t *testing.T)
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
// Create a test backup
backup := backups.CreateTestBackup(database.ID, storage.ID)
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
// Get initial active task count
stats, err := restoreNodesRegistry.GetRestoreNodesStats()
@@ -566,7 +566,7 @@ func Test_StartRestore_RestoreFails_DecrementsActiveTaskCount(t *testing.T) {
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
// Create a test backup
backup := backups.CreateTestBackup(database.ID, storage.ID)
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
// Get initial active task count
stats, err := restoreNodesRegistry.GetRestoreNodesStats()
@@ -664,7 +664,7 @@ func Test_StartRestore_CredentialsStoredEncryptedInCache(t *testing.T) {
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
// Create a test backup
backup := backups.CreateTestBackup(database.ID, storage.ID)
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
// Register mock node so scheduler can assign restore to it
mockNodeID = uuid.New()
@@ -779,7 +779,7 @@ func Test_StartRestore_CredentialsRemovedAfterRestoreStarts(t *testing.T) {
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
// Create a test backup
backup := backups.CreateTestBackup(database.ID, storage.ID)
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
// Create restore with credentials
plaintextPassword := "test_password_456"

View File

@@ -12,8 +12,8 @@ import (
"github.com/google/uuid"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_services "databasus-backend/internal/features/backups/backups/services"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/postgresql"
@@ -40,48 +40,48 @@ func CreateTestRouter() *gin.Engine {
func CreateTestRestorerNode() *RestorerNode {
return &RestorerNode{
nodeID: uuid.New(),
databaseService: databases.GetDatabaseService(),
backupService: backups.GetBackupService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
restoreRepository: restoreRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
restoreNodesRegistry: restoreNodesRegistry,
logger: logger.GetLogger(),
restoreBackupUsecase: usecases.GetRestoreBackupUsecase(),
cacheUtil: restoreDatabaseCache,
restoreCancelManager: tasks_cancellation.GetTaskCancelManager(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
uuid.New(),
databases.GetDatabaseService(),
backups_services.GetBackupService(),
encryption.GetFieldEncryptor(),
restoreRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
restoreNodesRegistry,
logger.GetLogger(),
usecases.GetRestoreBackupUsecase(),
restoreDatabaseCache,
tasks_cancellation.GetTaskCancelManager(),
time.Time{},
sync.Once{},
atomic.Bool{},
}
}
func CreateTestRestorerNodeWithUsecase(usecase restores_core.RestoreBackupUsecase) *RestorerNode {
return &RestorerNode{
nodeID: uuid.New(),
databaseService: databases.GetDatabaseService(),
backupService: backups.GetBackupService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
restoreRepository: restoreRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
restoreNodesRegistry: restoreNodesRegistry,
logger: logger.GetLogger(),
restoreBackupUsecase: usecase,
cacheUtil: restoreDatabaseCache,
restoreCancelManager: tasks_cancellation.GetTaskCancelManager(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
uuid.New(),
databases.GetDatabaseService(),
backups_services.GetBackupService(),
encryption.GetFieldEncryptor(),
restoreRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
restoreNodesRegistry,
logger.GetLogger(),
usecase,
restoreDatabaseCache,
tasks_cancellation.GetTaskCancelManager(),
time.Time{},
sync.Once{},
atomic.Bool{},
}
}
func CreateTestRestoresScheduler() *RestoresScheduler {
return &RestoresScheduler{
restoreRepository,
backups.GetBackupService(),
backups_services.GetBackupService(),
storages.GetStorageService(),
backups_config.GetBackupConfigService(),
restoreNodesRegistry,

View File

@@ -3,8 +3,8 @@ package restores
import (
"databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_services "databasus-backend/internal/features/backups/backups/services"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/disk"
@@ -26,7 +26,7 @@ import (
)
type RestoreService struct {
backupService *backups.BackupService
backupService *backups_services.BackupService
restoreRepository *restores_core.RestoreRepository
storageService *storages.StorageService
backupConfigService *backups_config.BackupConfigService

View File

@@ -8,7 +8,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"databasus-backend/internal/features/backups/backups"
backups_controllers "databasus-backend/internal/features/backups/backups/controllers"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/restores/restoring"
@@ -22,12 +22,12 @@ func CreateTestRouter() *gin.Engine {
workspaces_controllers.GetMembershipController(),
databases.GetDatabaseController(),
backups_config.GetBackupConfigController(),
backups.GetBackupController(),
backups_controllers.GetBackupController(),
GetRestoreController(),
)
v1 := router.Group("/api/v1")
backups.GetBackupController().RegisterPublicRoutes(v1)
backups_controllers.GetBackupController().RegisterPublicRoutes(v1)
return router
}

View File

@@ -47,14 +47,15 @@ func (l *LocalStorage) SaveFile(
logger.Info("Starting to save file to local storage", "fileName", fileName)
tempFilePath := filepath.Join(config.GetEnv().TempFolder, fileName)
err := files_utils.EnsureDirectories([]string{
config.GetEnv().TempFolder,
filepath.Dir(tempFilePath),
})
if err != nil {
return fmt.Errorf("failed to ensure directories: %w", err)
}
tempFilePath := filepath.Join(config.GetEnv().TempFolder, fileName)
logger.Debug("Creating temp file", "fileName", fileName, "tempPath", tempFilePath)
tempFile, err := os.Create(tempFilePath)
@@ -101,6 +102,10 @@ func (l *LocalStorage) SaveFile(
finalPath,
)
if err = files_utils.EnsureDirectories([]string{filepath.Dir(finalPath)}); err != nil {
return fmt.Errorf("failed to ensure final directory: %w", err)
}
// Move the file from temp to backups directory
if err = os.Rename(tempFilePath, finalPath); err != nil {
logger.Error(

View File

@@ -8,8 +8,8 @@ import (
"time"
"databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/features/backups/backups/backuping"
backups_services "databasus-backend/internal/features/backups/backups/services"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
@@ -26,8 +26,8 @@ func Test_SetupDependencies_CalledTwice_LogsWarning(t *testing.T) {
audit_logs.SetupDependencies()
audit_logs.SetupDependencies()
backups.SetupDependencies()
backups.SetupDependencies()
backups_services.SetupDependencies()
backups_services.SetupDependencies()
backups_config.SetupDependencies()
backups_config.SetupDependencies()

View File

@@ -17,8 +17,9 @@ import (
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_controllers "databasus-backend/internal/features/backups/backups/controllers"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_dto "databasus-backend/internal/features/backups/backups/dto"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
pgtypes "databasus-backend/internal/features/databases/databases/postgresql"
@@ -1234,7 +1235,7 @@ func createTestRouter() *gin.Engine {
workspaces_controllers.GetMembershipController(),
databases.GetDatabaseController(),
backups_config.GetBackupConfigController(),
backups.GetBackupController(),
backups_controllers.GetBackupController(),
restores.GetRestoreController(),
)
return router
@@ -1255,7 +1256,7 @@ func waitForBackupCompletion(
t.Fatalf("Timeout waiting for backup completion after %v", timeout)
}
var response backups.GetBackupsResponse
var response backups_dto.GetBackupsResponse
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
@@ -1431,7 +1432,7 @@ func createBackupViaAPI(
databaseID uuid.UUID,
token string,
) {
request := backups.MakeBackupRequest{DatabaseID: databaseID}
request := backups_dto.MakeBackupRequest{DatabaseID: databaseID}
test_utils.MakePostRequest(
t,
router,

View File

@@ -0,0 +1,125 @@
package wal
import (
"encoding/hex"
"errors"
"fmt"
"strings"
)
const (
segmentNameLen = 24
timelineLen = 8
logLen = 8
segLen = 8
defaultSegmentSizeBytes = 16 * 1024 * 1024 // 16 MB
)
// WalCalculator performs WAL segment name arithmetic for a given WAL segment size.
//
// WAL segment name format: TTTTTTTTLLLLLLLLSSSSSSSS (24 hex chars)
// - TT: timeline (8 hex digits)
// - LL: log file / XLogId (8 hex digits)
// - SS: segment within log file (8 hex digits)
//
// segmentsPerXLogId = 0x100000000 / segmentSizeBytes
// Increment SS; if SS >= segmentsPerXLogId → SS = 0, LL++
type WalCalculator struct {
segmentSizeBytes int64
segmentsPerXLogId uint64
}
// NewWalCalculator creates a WalCalculator for the given WAL segment size in bytes.
// Pass 0 or a negative value to use the PostgreSQL default of 16 MB.
func NewWalCalculator(segmentSizeBytes int64) *WalCalculator {
if segmentSizeBytes <= 0 {
segmentSizeBytes = defaultSegmentSizeBytes
}
return &WalCalculator{
segmentSizeBytes: segmentSizeBytes,
segmentsPerXLogId: uint64(0x100000000) / uint64(segmentSizeBytes),
}
}
// NextSegment computes the next WAL segment name after current.
// Returns an error if current is not a valid 24-character hex WAL segment name.
func (c *WalCalculator) NextSegment(current string) (string, error) {
if !c.IsValidSegmentName(current) {
return "", fmt.Errorf("invalid WAL segment name: %q", current)
}
timeline := current[:timelineLen]
logHex := current[timelineLen : timelineLen+logLen]
segHex := current[timelineLen+logLen:]
logVal, err := parseHex32(logHex)
if err != nil {
return "", fmt.Errorf("parse log part of %q: %w", current, err)
}
segVal, err := parseHex32(segHex)
if err != nil {
return "", fmt.Errorf("parse seg part of %q: %w", current, err)
}
segVal++
if uint64(segVal) >= c.segmentsPerXLogId {
segVal = 0
logVal++
}
return fmt.Sprintf("%s%08X%08X", strings.ToUpper(timeline), logVal, segVal), nil
}
// IsValidSegmentName returns true if name is a 24-character uppercase (or lowercase) hex string
// representing a valid WAL segment name.
func (c *WalCalculator) IsValidSegmentName(name string) bool {
if len(name) != segmentNameLen {
return false
}
_, err := hex.DecodeString(name)
return err == nil
}
// Compare compares two WAL segment names a and b by their numeric value.
// Returns -1 if a < b, 0 if a == b, 1 if a > b.
// Both names must be valid; returns an error otherwise.
// Timeline is compared first, then log file, then segment number.
func (c *WalCalculator) Compare(a, b string) (int, error) {
if !c.IsValidSegmentName(a) {
return 0, fmt.Errorf("invalid WAL segment name: %q", a)
}
if !c.IsValidSegmentName(b) {
return 0, fmt.Errorf("invalid WAL segment name: %q", b)
}
// Fixed-width uppercase hex: lexicographic order equals numeric order.
aUpper := strings.ToUpper(a)
bUpper := strings.ToUpper(b)
if aUpper < bUpper {
return -1, nil
}
if aUpper > bUpper {
return 1, nil
}
return 0, nil
}
func parseHex32(s string) (uint32, error) {
if len(s) != 8 {
return 0, errors.New("expected 8 hex characters")
}
b, err := hex.DecodeString(s)
if err != nil {
return 0, err
}
return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3]), nil
}

View File

@@ -0,0 +1,221 @@
package wal
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
mb1 = 1 * 1024 * 1024
mb16 = 16 * 1024 * 1024
mb64 = 64 * 1024 * 1024
)
// NextSegment — no wrap
func Test_WalCalculator_NextSegment_DefaultSegmentSize_NoWrap(t *testing.T) {
calc := NewWalCalculator(mb16)
cases := []struct {
input string
expected string
}{
{"000000010000000100000000", "000000010000000100000001"},
{"000000010000000100000027", "000000010000000100000028"},
{"0000000100000001000000AB", "0000000100000001000000AC"},
{"0000000100000001000000FE", "0000000100000001000000FF"},
}
for _, tc := range cases {
result, err := calc.NextSegment(tc.input)
require.NoError(t, err)
assert.Equal(t, tc.expected, result, "input=%s", tc.input)
}
}
// NextSegment — wrap at 0x100 (256) for 16 MB segments
func Test_WalCalculator_NextSegment_DefaultSegmentSize_WrapsAt256(t *testing.T) {
calc := NewWalCalculator(mb16)
// SS=0xFF → SS=0x00, LL++
result, err := calc.NextSegment("0000000100000001000000FF")
require.NoError(t, err)
assert.Equal(t, "000000010000000200000000", result)
result, err = calc.NextSegment("0000000200000005000000FF")
require.NoError(t, err)
assert.Equal(t, "000000020000000600000000", result)
}
// NextSegment — wrap at 0x1000 (4096) for 1 MB segments
func Test_WalCalculator_NextSegment_1MbSegmentSize_WrapsAt4096(t *testing.T) {
calc := NewWalCalculator(mb1)
// segmentsPerXLogId = 0x100000000 / (1*1024*1024) = 4096 = 0x1000
// Last valid SS = 0x00000FFF
result, err := calc.NextSegment("000000010000000100000FFE")
require.NoError(t, err)
assert.Equal(t, "000000010000000100000FFF", result)
// wrap: SS=0x0FFF → SS=0, LL++
result, err = calc.NextSegment("000000010000000100000FFF")
require.NoError(t, err)
assert.Equal(t, "000000010000000200000000", result)
}
// NextSegment — wrap at 0x40 (64) for 64 MB segments
func Test_WalCalculator_NextSegment_64MbSegmentSize_WrapsAt64(t *testing.T) {
calc := NewWalCalculator(mb64)
// segmentsPerXLogId = 0x100000000 / (64*1024*1024) = 64 = 0x40
// Last valid SS = 0x0000003F
result, err := calc.NextSegment("00000001000000010000003E")
require.NoError(t, err)
assert.Equal(t, "00000001000000010000003F", result)
// wrap: SS=0x3F → SS=0, LL++
result, err = calc.NextSegment("00000001000000010000003F")
require.NoError(t, err)
assert.Equal(t, "000000010000000200000000", result)
}
// NextSegment — log file increment on wrap
func Test_WalCalculator_NextSegment_IncrementsLog_OnSegmentWrap(t *testing.T) {
calc := NewWalCalculator(mb16)
// LL=0x00000001, wraps to LL=0x00000002
result, err := calc.NextSegment("0000000100000001000000FF")
require.NoError(t, err)
assert.Equal(t, "000000010000000200000000", result)
// LL=0x0000000F, wraps to LL=0x00000010
result, err = calc.NextSegment("00000001000000FF000000FF")
require.NoError(t, err)
assert.Equal(t, "000000010000010000000000", result)
}
// NextSegment — timeline is preserved
func Test_WalCalculator_NextSegment_TimelinePreserved(t *testing.T) {
calc := NewWalCalculator(mb16)
result, err := calc.NextSegment("000000030000000100000005")
require.NoError(t, err)
assert.Equal(t, "000000030000000100000006", result)
}
// NextSegment — invalid input
func Test_WalCalculator_NextSegment_InvalidName_ReturnsError(t *testing.T) {
calc := NewWalCalculator(mb16)
cases := []string{
"",
"00000001000000010000000", // 23 chars
"0000000100000001000000001", // 25 chars
"00000001000000010000000G", // non-hex char
"short",
}
for _, name := range cases {
_, err := calc.NextSegment(name)
assert.Error(t, err, "expected error for input %q", name)
}
}
// IsValidSegmentName
func Test_WalCalculator_IsValidSegmentName_ValidName_ReturnsTrue(t *testing.T) {
calc := NewWalCalculator(mb16)
valid := []string{
"000000010000000100000000",
"0000000100000001000000FF",
"FFFFFFFFFFFFFFFFFFFFFFFF",
"000000010000000100000027",
"0000000200000005000000AB",
}
for _, name := range valid {
assert.True(t, calc.IsValidSegmentName(name), "expected valid for %q", name)
}
}
func Test_WalCalculator_IsValidSegmentName_TooShort_ReturnsFalse(t *testing.T) {
calc := NewWalCalculator(mb16)
assert.False(t, calc.IsValidSegmentName("00000001000000010000000")) // 23 chars
assert.False(t, calc.IsValidSegmentName(""))
}
func Test_WalCalculator_IsValidSegmentName_TooLong_ReturnsFalse(t *testing.T) {
calc := NewWalCalculator(mb16)
assert.False(t, calc.IsValidSegmentName("0000000100000001000000001")) // 25 chars
}
func Test_WalCalculator_IsValidSegmentName_NonHex_ReturnsFalse(t *testing.T) {
calc := NewWalCalculator(mb16)
assert.False(t, calc.IsValidSegmentName("00000001000000010000000G"))
assert.False(t, calc.IsValidSegmentName("00000001000000010000000Z"))
assert.False(t, calc.IsValidSegmentName("000000010000000100000 00"))
}
// Compare
func Test_WalCalculator_Compare_ReturnsCorrectOrdering(t *testing.T) {
calc := NewWalCalculator(mb16)
cases := []struct {
a string
b string
expected int
}{
// equal
{"000000010000000100000001", "000000010000000100000001", 0},
// segment ordering
{"000000010000000100000001", "000000010000000100000002", -1},
{"000000010000000100000002", "000000010000000100000001", 1},
// log ordering
{"000000010000000100000000", "000000010000000200000000", -1},
{"000000010000000200000000", "000000010000000100000000", 1},
// timeline ordering
{"000000010000000100000000", "000000020000000100000000", -1},
{"000000020000000100000000", "000000010000000100000000", 1},
// across wrap boundary: log 1, seg 255 < log 2, seg 0
{"0000000100000001000000FF", "000000010000000200000000", -1},
}
for _, tc := range cases {
result, err := calc.Compare(tc.a, tc.b)
require.NoError(t, err)
assert.Equal(t, tc.expected, result, "Compare(%s, %s)", tc.a, tc.b)
}
}
func Test_WalCalculator_Compare_InvalidInput_ReturnsError(t *testing.T) {
calc := NewWalCalculator(mb16)
_, err := calc.Compare("invalid", "000000010000000100000001")
assert.Error(t, err)
_, err = calc.Compare("000000010000000100000001", "invalid")
assert.Error(t, err)
}
// Default segment size via zero input
func Test_WalCalculator_NewWalCalculator_ZeroSize_UsesDefault16MB(t *testing.T) {
calc := NewWalCalculator(0)
assert.Equal(t, int64(mb16), calc.segmentSizeBytes)
assert.Equal(t, uint64(256), calc.segmentsPerXLogId)
}

View File

@@ -0,0 +1,72 @@
-- +goose Up
-- +goose StatementBegin
ALTER TABLE databases
ADD COLUMN agent_token TEXT,
ADD COLUMN is_agent_token_generated BOOLEAN NOT NULL DEFAULT FALSE;
CREATE UNIQUE INDEX idx_databases_agent_token ON databases (agent_token) WHERE agent_token IS NOT NULL;
ALTER TABLE postgresql_databases
ADD COLUMN backup_type TEXT NOT NULL DEFAULT 'PG_DUMP';
ALTER TABLE postgresql_databases
ALTER COLUMN host DROP NOT NULL,
ALTER COLUMN port DROP NOT NULL,
ALTER COLUMN username DROP NOT NULL,
ALTER COLUMN password DROP NOT NULL;
ALTER TABLE backups
ADD COLUMN pg_wal_backup_type TEXT,
ADD COLUMN pg_wal_start_segment TEXT,
ADD COLUMN pg_wal_stop_segment TEXT,
ADD COLUMN pg_version TEXT,
ADD COLUMN pg_wal_segment_name TEXT;
CREATE INDEX idx_backups_pg_wal_segment_name
ON backups (database_id, pg_wal_segment_name)
WHERE pg_wal_segment_name IS NOT NULL;
CREATE INDEX idx_backups_pg_wal_backup_type_created
ON backups (database_id, pg_wal_backup_type, created_at DESC)
WHERE pg_wal_backup_type IS NOT NULL;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
DROP INDEX IF EXISTS idx_backups_pg_wal_segment_name;
DROP INDEX IF EXISTS idx_backups_pg_wal_backup_type_created;
ALTER TABLE backups
DROP COLUMN pg_wal_backup_type,
DROP COLUMN pg_wal_start_segment,
DROP COLUMN pg_wal_stop_segment,
DROP COLUMN pg_version,
DROP COLUMN pg_wal_segment_name;
UPDATE postgresql_databases
SET host = 'localhost' WHERE host IS NULL OR host = '';
UPDATE postgresql_databases
SET port = 5432 WHERE port IS NULL OR port = 0;
UPDATE postgresql_databases
SET username = 'postgres' WHERE username IS NULL OR username = '';
UPDATE postgresql_databases
SET password = 'stubpassword' WHERE password IS NULL OR password = '';
ALTER TABLE postgresql_databases
DROP COLUMN backup_type;
ALTER TABLE postgresql_databases
ALTER COLUMN host SET NOT NULL,
ALTER COLUMN port SET NOT NULL,
ALTER COLUMN username SET NOT NULL,
ALTER COLUMN password SET NOT NULL;
DROP INDEX IF EXISTS idx_databases_agent_token;
ALTER TABLE databases
DROP COLUMN agent_token,
DROP COLUMN is_agent_token_generated;
-- +goose StatementEnd