mirror of
https://github.com/databasus/databasus.git
synced 2026-04-06 00:32:03 +02:00
FEATURE (backups): Add WAL API
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -12,3 +12,4 @@ node_modules/
|
|||||||
.DS_Store
|
.DS_Store
|
||||||
/scripts
|
/scripts
|
||||||
.vscode/settings.json
|
.vscode/settings.json
|
||||||
|
.claude
|
||||||
@@ -14,9 +14,10 @@ import (
|
|||||||
|
|
||||||
"databasus-backend/internal/config"
|
"databasus-backend/internal/config"
|
||||||
"databasus-backend/internal/features/audit_logs"
|
"databasus-backend/internal/features/audit_logs"
|
||||||
"databasus-backend/internal/features/backups/backups"
|
|
||||||
"databasus-backend/internal/features/backups/backups/backuping"
|
"databasus-backend/internal/features/backups/backups/backuping"
|
||||||
|
backups_controllers "databasus-backend/internal/features/backups/backups/controllers"
|
||||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||||
|
backups_services "databasus-backend/internal/features/backups/backups/services"
|
||||||
backups_config "databasus-backend/internal/features/backups/config"
|
backups_config "databasus-backend/internal/features/backups/config"
|
||||||
"databasus-backend/internal/features/databases"
|
"databasus-backend/internal/features/databases"
|
||||||
"databasus-backend/internal/features/disk"
|
"databasus-backend/internal/features/disk"
|
||||||
@@ -209,7 +210,9 @@ func setUpRoutes(r *gin.Engine) {
|
|||||||
userController := users_controllers.GetUserController()
|
userController := users_controllers.GetUserController()
|
||||||
userController.RegisterRoutes(v1)
|
userController.RegisterRoutes(v1)
|
||||||
system_healthcheck.GetHealthcheckController().RegisterRoutes(v1)
|
system_healthcheck.GetHealthcheckController().RegisterRoutes(v1)
|
||||||
backups.GetBackupController().RegisterPublicRoutes(v1)
|
backups_controllers.GetBackupController().RegisterPublicRoutes(v1)
|
||||||
|
backups_controllers.GetPostgresWalBackupController().RegisterRoutes(v1)
|
||||||
|
databases.GetDatabaseController().RegisterPublicRoutes(v1)
|
||||||
|
|
||||||
// Setup auth middleware
|
// Setup auth middleware
|
||||||
userService := users_services.GetUserService()
|
userService := users_services.GetUserService()
|
||||||
@@ -226,7 +229,7 @@ func setUpRoutes(r *gin.Engine) {
|
|||||||
notifiers.GetNotifierController().RegisterRoutes(protected)
|
notifiers.GetNotifierController().RegisterRoutes(protected)
|
||||||
storages.GetStorageController().RegisterRoutes(protected)
|
storages.GetStorageController().RegisterRoutes(protected)
|
||||||
databases.GetDatabaseController().RegisterRoutes(protected)
|
databases.GetDatabaseController().RegisterRoutes(protected)
|
||||||
backups.GetBackupController().RegisterRoutes(protected)
|
backups_controllers.GetBackupController().RegisterRoutes(protected)
|
||||||
restores.GetRestoreController().RegisterRoutes(protected)
|
restores.GetRestoreController().RegisterRoutes(protected)
|
||||||
healthcheck_config.GetHealthcheckConfigController().RegisterRoutes(protected)
|
healthcheck_config.GetHealthcheckConfigController().RegisterRoutes(protected)
|
||||||
healthcheck_attempt.GetHealthcheckAttemptController().RegisterRoutes(protected)
|
healthcheck_attempt.GetHealthcheckAttemptController().RegisterRoutes(protected)
|
||||||
@@ -238,7 +241,7 @@ func setUpRoutes(r *gin.Engine) {
|
|||||||
|
|
||||||
func setUpDependencies() {
|
func setUpDependencies() {
|
||||||
databases.SetupDependencies()
|
databases.SetupDependencies()
|
||||||
backups.SetupDependencies()
|
backups_services.SetupDependencies()
|
||||||
restores.SetupDependencies()
|
restores.SetupDependencies()
|
||||||
healthcheck_config.SetupDependencies()
|
healthcheck_config.SetupDependencies()
|
||||||
audit_logs.SetupDependencies()
|
audit_logs.SetupDependencies()
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ import (
|
|||||||
backups_config "databasus-backend/internal/features/backups/config"
|
backups_config "databasus-backend/internal/features/backups/config"
|
||||||
"databasus-backend/internal/features/databases"
|
"databasus-backend/internal/features/databases"
|
||||||
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||||
files_utils "databasus-backend/internal/util/files"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -171,13 +170,7 @@ func (s *BackupsScheduler) StartBackup(database *databases.Database, isCallNotif
|
|||||||
timestamp := time.Now().UTC()
|
timestamp := time.Now().UTC()
|
||||||
|
|
||||||
backup := &backups_core.Backup{
|
backup := &backups_core.Backup{
|
||||||
ID: backupID,
|
ID: backupID,
|
||||||
FileName: fmt.Sprintf(
|
|
||||||
"%s-%s-%s",
|
|
||||||
files_utils.SanitizeFilename(database.Name),
|
|
||||||
timestamp.Format("20060102-150405"),
|
|
||||||
backupID.String(),
|
|
||||||
),
|
|
||||||
DatabaseID: backupConfig.DatabaseID,
|
DatabaseID: backupConfig.DatabaseID,
|
||||||
StorageID: *backupConfig.StorageID,
|
StorageID: *backupConfig.StorageID,
|
||||||
Status: backups_core.BackupStatusInProgress,
|
Status: backups_core.BackupStatusInProgress,
|
||||||
@@ -185,6 +178,8 @@ func (s *BackupsScheduler) StartBackup(database *databases.Database, isCallNotif
|
|||||||
CreatedAt: timestamp,
|
CreatedAt: timestamp,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
backup.GenerateFilename(database.Name)
|
||||||
|
|
||||||
if err := s.backupRepository.Save(backup); err != nil {
|
if err := s.backupRepository.Save(backup); err != nil {
|
||||||
s.logger.Error(
|
s.logger.Error(
|
||||||
"Failed to save backup",
|
"Failed to save backup",
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
package backups
|
package backups_controllers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||||
|
backups_dto "databasus-backend/internal/features/backups/backups/dto"
|
||||||
|
backups_services "databasus-backend/internal/features/backups/backups/services"
|
||||||
"databasus-backend/internal/features/databases"
|
"databasus-backend/internal/features/databases"
|
||||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||||
files_utils "databasus-backend/internal/util/files"
|
files_utils "databasus-backend/internal/util/files"
|
||||||
@@ -17,7 +19,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type BackupController struct {
|
type BackupController struct {
|
||||||
backupService *BackupService
|
backupService *backups_services.BackupService
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *BackupController) RegisterRoutes(router *gin.RouterGroup) {
|
func (c *BackupController) RegisterRoutes(router *gin.RouterGroup) {
|
||||||
@@ -42,7 +44,7 @@ func (c *BackupController) RegisterPublicRoutes(router *gin.RouterGroup) {
|
|||||||
// @Param database_id query string true "Database ID"
|
// @Param database_id query string true "Database ID"
|
||||||
// @Param limit query int false "Number of items per page" default(10)
|
// @Param limit query int false "Number of items per page" default(10)
|
||||||
// @Param offset query int false "Offset for pagination" default(0)
|
// @Param offset query int false "Offset for pagination" default(0)
|
||||||
// @Success 200 {object} GetBackupsResponse
|
// @Success 200 {object} backups_dto.GetBackupsResponse
|
||||||
// @Failure 400
|
// @Failure 400
|
||||||
// @Failure 401
|
// @Failure 401
|
||||||
// @Failure 500
|
// @Failure 500
|
||||||
@@ -54,7 +56,7 @@ func (c *BackupController) GetBackups(ctx *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var request GetBackupsRequest
|
var request backups_dto.GetBackupsRequest
|
||||||
if err := ctx.ShouldBindQuery(&request); err != nil {
|
if err := ctx.ShouldBindQuery(&request); err != nil {
|
||||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
@@ -81,7 +83,7 @@ func (c *BackupController) GetBackups(ctx *gin.Context) {
|
|||||||
// @Tags backups
|
// @Tags backups
|
||||||
// @Accept json
|
// @Accept json
|
||||||
// @Produce json
|
// @Produce json
|
||||||
// @Param request body MakeBackupRequest true "Backup creation data"
|
// @Param request body backups_dto.MakeBackupRequest true "Backup creation data"
|
||||||
// @Success 200 {object} map[string]string
|
// @Success 200 {object} map[string]string
|
||||||
// @Failure 400
|
// @Failure 400
|
||||||
// @Failure 401
|
// @Failure 401
|
||||||
@@ -94,7 +96,7 @@ func (c *BackupController) MakeBackup(ctx *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var request MakeBackupRequest
|
var request backups_dto.MakeBackupRequest
|
||||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
@@ -310,10 +312,6 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
|
|||||||
c.backupService.WriteAuditLogForDownload(downloadToken.UserID, backup, database)
|
c.backupService.WriteAuditLogForDownload(downloadToken.UserID, backup, database)
|
||||||
}
|
}
|
||||||
|
|
||||||
type MakeBackupRequest struct {
|
|
||||||
DatabaseID uuid.UUID `json:"database_id" binding:"required"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *BackupController) generateBackupFilename(
|
func (c *BackupController) generateBackupFilename(
|
||||||
backup *backups_core.Backup,
|
backup *backups_core.Backup,
|
||||||
database *databases.Database,
|
database *databases.Database,
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package backups
|
package backups_controllers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -24,11 +24,14 @@ import (
|
|||||||
backups_common "databasus-backend/internal/features/backups/backups/common"
|
backups_common "databasus-backend/internal/features/backups/backups/common"
|
||||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||||
|
backups_dto "databasus-backend/internal/features/backups/backups/dto"
|
||||||
|
backups_services "databasus-backend/internal/features/backups/backups/services"
|
||||||
backups_config "databasus-backend/internal/features/backups/config"
|
backups_config "databasus-backend/internal/features/backups/config"
|
||||||
"databasus-backend/internal/features/databases"
|
"databasus-backend/internal/features/databases"
|
||||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||||
"databasus-backend/internal/features/storages"
|
"databasus-backend/internal/features/storages"
|
||||||
local_storage "databasus-backend/internal/features/storages/models/local"
|
local_storage "databasus-backend/internal/features/storages/models/local"
|
||||||
|
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||||
users_dto "databasus-backend/internal/features/users/dto"
|
users_dto "databasus-backend/internal/features/users/dto"
|
||||||
users_enums "databasus-backend/internal/features/users/enums"
|
users_enums "databasus-backend/internal/features/users/enums"
|
||||||
users_services "databasus-backend/internal/features/users/services"
|
users_services "databasus-backend/internal/features/users/services"
|
||||||
@@ -119,7 +122,7 @@ func Test_GetBackups_PermissionsEnforced(t *testing.T) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
if tt.expectSuccess {
|
if tt.expectSuccess {
|
||||||
var response GetBackupsResponse
|
var response backups_dto.GetBackupsResponse
|
||||||
err := json.Unmarshal(testResp.Body, &response)
|
err := json.Unmarshal(testResp.Body, &response)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.GreaterOrEqual(t, len(response.Backups), 1)
|
assert.GreaterOrEqual(t, len(response.Backups), 1)
|
||||||
@@ -214,7 +217,7 @@ func Test_CreateBackup_PermissionsEnforced(t *testing.T) {
|
|||||||
testUserToken = nonMember.Token
|
testUserToken = nonMember.Token
|
||||||
}
|
}
|
||||||
|
|
||||||
request := MakeBackupRequest{DatabaseID: database.ID}
|
request := backups_dto.MakeBackupRequest{DatabaseID: database.ID}
|
||||||
testResp := test_utils.MakePostRequest(
|
testResp := test_utils.MakePostRequest(
|
||||||
t,
|
t,
|
||||||
router,
|
router,
|
||||||
@@ -245,7 +248,7 @@ func Test_CreateBackup_AuditLogWritten(t *testing.T) {
|
|||||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||||
enableBackupForDatabase(database.ID)
|
enableBackupForDatabase(database.ID)
|
||||||
|
|
||||||
request := MakeBackupRequest{DatabaseID: database.ID}
|
request := backups_dto.MakeBackupRequest{DatabaseID: database.ID}
|
||||||
test_utils.MakePostRequest(
|
test_utils.MakePostRequest(
|
||||||
t,
|
t,
|
||||||
router,
|
router,
|
||||||
@@ -373,7 +376,7 @@ func Test_DeleteBackup_PermissionsEnforced(t *testing.T) {
|
|||||||
ownerUser, err := userService.GetUserFromToken(owner.Token)
|
ownerUser, err := userService.GetUserFromToken(owner.Token)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
response, err := GetBackupService().GetBackups(ownerUser, database.ID, 10, 0)
|
response, err := backups_services.GetBackupService().GetBackups(ownerUser, database.ID, 10, 0)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, 0, len(response.Backups))
|
assert.Equal(t, 0, len(response.Backups))
|
||||||
}
|
}
|
||||||
@@ -999,7 +1002,7 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Register a cancellable context for the backup
|
// Register a cancellable context for the backup
|
||||||
GetBackupService().taskCancelManager.RegisterTask(backup.ID, func() {})
|
task_cancellation.GetTaskCancelManager().RegisterTask(backup.ID, func() {})
|
||||||
|
|
||||||
resp := test_utils.MakePostRequest(
|
resp := test_utils.MakePostRequest(
|
||||||
t,
|
t,
|
||||||
@@ -1091,7 +1094,7 @@ func Test_ConcurrentDownloadPrevention(t *testing.T) {
|
|||||||
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
service := GetBackupService()
|
service := backups_services.GetBackupService()
|
||||||
if !service.IsDownloadInProgress(owner.UserID) {
|
if !service.IsDownloadInProgress(owner.UserID) {
|
||||||
t.Log("Warning: First download completed before we could test concurrency")
|
t.Log("Warning: First download completed before we could test concurrency")
|
||||||
<-downloadComplete
|
<-downloadComplete
|
||||||
@@ -1192,7 +1195,7 @@ func Test_GenerateDownloadToken_BlockedWhenDownloadInProgress(t *testing.T) {
|
|||||||
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
service := GetBackupService()
|
service := backups_services.GetBackupService()
|
||||||
if !service.IsDownloadInProgress(owner.UserID) {
|
if !service.IsDownloadInProgress(owner.UserID) {
|
||||||
t.Log("Warning: First download completed before we could test token generation blocking")
|
t.Log("Warning: First download completed before we could test token generation blocking")
|
||||||
<-downloadComplete
|
<-downloadComplete
|
||||||
@@ -1268,7 +1271,7 @@ func Test_MakeBackup_VerifyBackupAndMetadataFilesExistInStorage(t *testing.T) {
|
|||||||
initialBackups, err := backupRepo.FindByDatabaseID(database.ID)
|
initialBackups, err := backupRepo.FindByDatabaseID(database.ID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
request := MakeBackupRequest{DatabaseID: database.ID}
|
request := backups_dto.MakeBackupRequest{DatabaseID: database.ID}
|
||||||
test_utils.MakePostRequest(
|
test_utils.MakePostRequest(
|
||||||
t,
|
t,
|
||||||
router,
|
router,
|
||||||
@@ -1502,7 +1505,7 @@ func createTestBackup(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func createExpiredDownloadToken(backupID, userID uuid.UUID) string {
|
func createExpiredDownloadToken(backupID, userID uuid.UUID) string {
|
||||||
tokenService := GetBackupService().downloadTokenService
|
tokenService := backups_download.GetDownloadTokenService()
|
||||||
token, err := tokenService.Generate(backupID, userID)
|
token, err := tokenService.Generate(backupID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Sprintf("Failed to generate download token: %v", err))
|
panic(fmt.Sprintf("Failed to generate download token: %v", err))
|
||||||
@@ -1843,7 +1846,7 @@ func Test_DeleteBackup_RemovesBackupAndMetadataFilesFromDisk(t *testing.T) {
|
|||||||
initialBackups, err := backupRepo.FindByDatabaseID(database.ID)
|
initialBackups, err := backupRepo.FindByDatabaseID(database.ID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
request := MakeBackupRequest{DatabaseID: database.ID}
|
request := backups_dto.MakeBackupRequest{DatabaseID: database.ID}
|
||||||
test_utils.MakePostRequest(
|
test_utils.MakePostRequest(
|
||||||
t,
|
t,
|
||||||
router,
|
router,
|
||||||
23
backend/internal/features/backups/backups/controllers/di.go
Normal file
23
backend/internal/features/backups/backups/controllers/di.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package backups_controllers
|
||||||
|
|
||||||
|
import (
|
||||||
|
backups_services "databasus-backend/internal/features/backups/backups/services"
|
||||||
|
"databasus-backend/internal/features/databases"
|
||||||
|
)
|
||||||
|
|
||||||
|
var backupController = &BackupController{
|
||||||
|
backups_services.GetBackupService(),
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetBackupController() *BackupController {
|
||||||
|
return backupController
|
||||||
|
}
|
||||||
|
|
||||||
|
var postgresWalBackupController = &PostgreWalBackupController{
|
||||||
|
databases.GetDatabaseService(),
|
||||||
|
backups_services.GetWalService(),
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetPostgresWalBackupController() *PostgreWalBackupController {
|
||||||
|
return postgresWalBackupController
|
||||||
|
}
|
||||||
@@ -0,0 +1,291 @@
|
|||||||
|
package backups_controllers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||||
|
backups_dto "databasus-backend/internal/features/backups/backups/dto"
|
||||||
|
backups_services "databasus-backend/internal/features/backups/backups/services"
|
||||||
|
"databasus-backend/internal/features/databases"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PostgreWalBackupController handles WAL backup endpoints used by the databasus-cli agent.
|
||||||
|
// Authentication is via a plain agent token in the Authorization header (no Bearer prefix).
|
||||||
|
type PostgreWalBackupController struct {
|
||||||
|
databaseService *databases.DatabaseService
|
||||||
|
walService *backups_services.PostgreWalBackupService
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PostgreWalBackupController) RegisterRoutes(router *gin.RouterGroup) {
|
||||||
|
walRoutes := router.Group("/backups/postgres/wal")
|
||||||
|
|
||||||
|
walRoutes.GET("/next-full-backup-time", c.GetNextFullBackupTime)
|
||||||
|
walRoutes.POST("/error", c.ReportError)
|
||||||
|
walRoutes.POST("/upload", c.Upload)
|
||||||
|
walRoutes.GET("/restore/plan", c.GetRestorePlan)
|
||||||
|
walRoutes.GET("/restore/download", c.DownloadBackupFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNextFullBackupTime
|
||||||
|
// @Summary Get next full backup time
|
||||||
|
// @Description Returns the next scheduled full basebackup time for the authenticated database
|
||||||
|
// @Tags backups-wal
|
||||||
|
// @Produce json
|
||||||
|
// @Security AgentToken
|
||||||
|
// @Success 200 {object} backups_dto.GetNextFullBackupTimeResponse
|
||||||
|
// @Failure 401 {object} map[string]string
|
||||||
|
// @Failure 500 {object} map[string]string
|
||||||
|
// @Router /backups/postgres/wal/next-full-backup-time [get]
|
||||||
|
func (c *PostgreWalBackupController) GetNextFullBackupTime(ctx *gin.Context) {
|
||||||
|
database, err := c.getDatabase(ctx)
|
||||||
|
if err != nil {
|
||||||
|
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := c.walService.GetNextFullBackupTime(database)
|
||||||
|
if err != nil {
|
||||||
|
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.JSON(http.StatusOK, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReportError
|
||||||
|
// @Summary Report agent error
|
||||||
|
// @Description Records a fatal error from the agent against the database record and marks it as errored
|
||||||
|
// @Tags backups-wal
|
||||||
|
// @Accept json
|
||||||
|
// @Security AgentToken
|
||||||
|
// @Param request body backups_dto.ReportErrorRequest true "Error details"
|
||||||
|
// @Success 200
|
||||||
|
// @Failure 400 {object} map[string]string
|
||||||
|
// @Failure 401 {object} map[string]string
|
||||||
|
// @Failure 500 {object} map[string]string
|
||||||
|
// @Router /backups/postgres/wal/error [post]
|
||||||
|
func (c *PostgreWalBackupController) ReportError(ctx *gin.Context) {
|
||||||
|
database, err := c.getDatabase(ctx)
|
||||||
|
if err != nil {
|
||||||
|
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var request backups_dto.ReportErrorRequest
|
||||||
|
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||||
|
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.walService.ReportError(database, request.Error); err != nil {
|
||||||
|
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Status(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upload
|
||||||
|
// @Summary Stream upload a basebackup or WAL segment
|
||||||
|
// @Description Accepts a zstd-compressed binary stream and stores it in the database's configured storage.
|
||||||
|
// The server generates the storage filename; agents do not control the destination path.
|
||||||
|
// For WAL segment uploads the server validates the WAL chain and returns 409 if a gap is detected
|
||||||
|
// or 400 if no full backup exists yet (agent should trigger a full basebackup in both cases).
|
||||||
|
// @Tags backups-wal
|
||||||
|
// @Accept application/octet-stream
|
||||||
|
// @Produce json
|
||||||
|
// @Security AgentToken
|
||||||
|
// @Param X-Upload-Type header string true "Upload type" Enums(basebackup, wal)
|
||||||
|
// @Param X-Wal-Segment-Name header string false "24-hex WAL segment identifier (required for wal uploads, e.g. 0000000100000001000000AB)"
|
||||||
|
// @Param X-Wal-Segment-Size header int false "WAL segment size in bytes reported by the PostgreSQL instance (default: 16777216)"
|
||||||
|
// @Param fullBackupWalStartSegment query string false "First WAL segment needed to make the basebackup consistent (required for basebackup uploads)"
|
||||||
|
// @Param fullBackupWalStopSegment query string false "Last WAL segment included in the basebackup (required for basebackup uploads)"
|
||||||
|
// @Success 204
|
||||||
|
// @Failure 400 {object} backups_dto.UploadGapResponse "No full backup exists (error: no_full_backup)"
|
||||||
|
// @Failure 401 {object} map[string]string
|
||||||
|
// @Failure 409 {object} backups_dto.UploadGapResponse "WAL chain gap detected (error: gap_detected)"
|
||||||
|
// @Failure 500 {object} map[string]string
|
||||||
|
// @Router /backups/postgres/wal/upload [post]
|
||||||
|
func (c *PostgreWalBackupController) Upload(ctx *gin.Context) {
|
||||||
|
database, err := c.getDatabase(ctx)
|
||||||
|
if err != nil {
|
||||||
|
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
uploadType := backups_core.PgWalUploadType(ctx.GetHeader("X-Upload-Type"))
|
||||||
|
if uploadType != backups_core.PgWalUploadTypeBasebackup &&
|
||||||
|
uploadType != backups_core.PgWalUploadTypeWal {
|
||||||
|
ctx.JSON(
|
||||||
|
http.StatusBadRequest,
|
||||||
|
gin.H{"error": "X-Upload-Type must be 'basebackup' or 'wal'"},
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
walSegmentName := ""
|
||||||
|
if uploadType == backups_core.PgWalUploadTypeWal {
|
||||||
|
walSegmentName = ctx.GetHeader("X-Wal-Segment-Name")
|
||||||
|
if walSegmentName == "" {
|
||||||
|
ctx.JSON(
|
||||||
|
http.StatusBadRequest,
|
||||||
|
gin.H{"error": "X-Wal-Segment-Name is required for wal uploads"},
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if uploadType == backups_core.PgWalUploadTypeBasebackup {
|
||||||
|
if ctx.Query("fullBackupWalStartSegment") == "" ||
|
||||||
|
ctx.Query("fullBackupWalStopSegment") == "" {
|
||||||
|
ctx.JSON(
|
||||||
|
http.StatusBadRequest,
|
||||||
|
gin.H{
|
||||||
|
"error": "fullBackupWalStartSegment and fullBackupWalStopSegment are required for basebackup uploads",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
walSegmentSizeBytes := int64(0)
|
||||||
|
if raw := ctx.GetHeader("X-Wal-Segment-Size"); raw != "" {
|
||||||
|
parsed, parseErr := strconv.ParseInt(raw, 10, 64)
|
||||||
|
if parseErr != nil || parsed <= 0 {
|
||||||
|
ctx.JSON(
|
||||||
|
http.StatusBadRequest,
|
||||||
|
gin.H{"error": "X-Wal-Segment-Size must be a positive integer"},
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
walSegmentSizeBytes = parsed
|
||||||
|
}
|
||||||
|
|
||||||
|
gapResp, uploadErr := c.walService.UploadWal(
|
||||||
|
ctx.Request.Context(),
|
||||||
|
database,
|
||||||
|
uploadType,
|
||||||
|
walSegmentName,
|
||||||
|
ctx.Query("fullBackupWalStartSegment"),
|
||||||
|
ctx.Query("fullBackupWalStopSegment"),
|
||||||
|
walSegmentSizeBytes,
|
||||||
|
ctx.Request.Body,
|
||||||
|
)
|
||||||
|
|
||||||
|
if uploadErr != nil {
|
||||||
|
ctx.JSON(http.StatusInternalServerError, gin.H{"error": uploadErr.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if gapResp != nil {
|
||||||
|
if gapResp.Error == "no_full_backup" {
|
||||||
|
ctx.JSON(http.StatusBadRequest, gapResp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.JSON(http.StatusConflict, gapResp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Status(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRestorePlan
|
||||||
|
// @Summary Get restore plan
|
||||||
|
// @Description Resolves the full backup and all required WAL segments needed for recovery. Validates the WAL chain is continuous.
|
||||||
|
// @Tags backups-wal
|
||||||
|
// @Produce json
|
||||||
|
// @Security AgentToken
|
||||||
|
// @Param backupId query string false "UUID of a specific full backup to restore from; defaults to the most recent"
|
||||||
|
// @Success 200 {object} backups_dto.GetRestorePlanResponse
|
||||||
|
// @Failure 400 {object} map[string]string "Broken WAL chain or no backups available"
|
||||||
|
// @Failure 401 {object} map[string]string
|
||||||
|
// @Failure 500 {object} map[string]string
|
||||||
|
// @Router /backups/postgres/wal/restore/plan [get]
|
||||||
|
func (c *PostgreWalBackupController) GetRestorePlan(ctx *gin.Context) {
|
||||||
|
database, err := c.getDatabase(ctx)
|
||||||
|
if err != nil {
|
||||||
|
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var backupID *uuid.UUID
|
||||||
|
if raw := ctx.Query("backupId"); raw != "" {
|
||||||
|
parsed, parseErr := uuid.Parse(raw)
|
||||||
|
if parseErr != nil {
|
||||||
|
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backupId format"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
backupID = &parsed
|
||||||
|
}
|
||||||
|
|
||||||
|
response, planErr, err := c.walService.GetRestorePlan(database, backupID)
|
||||||
|
if err != nil {
|
||||||
|
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if planErr != nil {
|
||||||
|
ctx.JSON(http.StatusBadRequest, planErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.JSON(http.StatusOK, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DownloadBackupFile
|
||||||
|
// @Summary Download a backup or WAL segment file for restore
|
||||||
|
// @Description Retrieves the backup file by ID (validated against the authenticated database), decrypts it server-side if encrypted, and streams the zstd-compressed result to the agent
|
||||||
|
// @Tags backups-wal
|
||||||
|
// @Produce application/octet-stream
|
||||||
|
// @Security AgentToken
|
||||||
|
// @Param backupId query string true "Backup ID from the restore plan response"
|
||||||
|
// @Success 200 {file} file
|
||||||
|
// @Failure 400 {object} map[string]string
|
||||||
|
// @Failure 401 {object} map[string]string
|
||||||
|
// @Router /backups/postgres/wal/restore/download [get]
|
||||||
|
func (c *PostgreWalBackupController) DownloadBackupFile(ctx *gin.Context) {
|
||||||
|
database, err := c.getDatabase(ctx)
|
||||||
|
if err != nil {
|
||||||
|
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
backupIDRaw := ctx.Query("backupId")
|
||||||
|
if backupIDRaw == "" {
|
||||||
|
ctx.JSON(http.StatusBadRequest, gin.H{"error": "backupId is required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
backupID, err := uuid.Parse(backupIDRaw)
|
||||||
|
if err != nil {
|
||||||
|
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backupId format"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
reader, err := c.walService.DownloadBackupFile(database, backupID)
|
||||||
|
if err != nil {
|
||||||
|
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() { _ = reader.Close() }()
|
||||||
|
|
||||||
|
ctx.Header("Content-Type", "application/octet-stream")
|
||||||
|
ctx.Status(http.StatusOK)
|
||||||
|
|
||||||
|
_, _ = io.Copy(ctx.Writer, reader)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PostgreWalBackupController) getDatabase(
|
||||||
|
ctx *gin.Context,
|
||||||
|
) (*databases.Database, error) {
|
||||||
|
token := ctx.GetHeader("Authorization")
|
||||||
|
return c.databaseService.GetDatabaseByAgentToken(token)
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
|||||||
package backups
|
package backups_controllers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
@@ -41,7 +41,7 @@ func WaitForBackupCompletion(
|
|||||||
deadline := time.Now().UTC().Add(timeout)
|
deadline := time.Now().UTC().Add(timeout)
|
||||||
|
|
||||||
for time.Now().UTC().Before(deadline) {
|
for time.Now().UTC().Before(deadline) {
|
||||||
backups, err := backupRepository.FindByDatabaseID(databaseID)
|
backups, err := backups_core.GetBackupRepository().FindByDatabaseID(databaseID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("WaitForBackupCompletion: error finding backups: %v", err)
|
t.Logf("WaitForBackupCompletion: error finding backups: %v", err)
|
||||||
time.Sleep(50 * time.Millisecond)
|
time.Sleep(50 * time.Millisecond)
|
||||||
7
backend/internal/features/backups/backups/core/di.go
Normal file
7
backend/internal/features/backups/backups/core/di.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
package backups_core
|
||||||
|
|
||||||
|
// backupRepository is the package-level singleton shared by all consumers
// of this package.
var backupRepository = &BackupRepository{}

// GetBackupRepository returns the shared BackupRepository singleton.
func GetBackupRepository() *BackupRepository {
	return backupRepository
}
|
||||||
@@ -8,3 +8,10 @@ const (
|
|||||||
BackupStatusFailed BackupStatus = "FAILED"
|
BackupStatusFailed BackupStatus = "FAILED"
|
||||||
BackupStatusCanceled BackupStatus = "CANCELED"
|
BackupStatusCanceled BackupStatus = "CANCELED"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// PgWalUploadType distinguishes the two kinds of uploads the agent performs
// for WAL-based Postgres backups.
type PgWalUploadType string

const (
	// PgWalUploadTypeBasebackup is a full physical base backup upload.
	PgWalUploadTypeBasebackup PgWalUploadType = "basebackup"
	// PgWalUploadTypeWal is a single WAL segment upload.
	PgWalUploadTypeWal PgWalUploadType = "wal"
)
|
||||||
|
|||||||
@@ -1,12 +1,22 @@
|
|||||||
package backups_core
|
package backups_core
|
||||||
|
|
||||||
import (
|
import (
|
||||||
backups_config "databasus-backend/internal/features/backups/config"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
backups_config "databasus-backend/internal/features/backups/config"
|
||||||
|
files_utils "databasus-backend/internal/util/files"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// PgWalBackupType is the persisted discriminator for WAL-related backup records.
type PgWalBackupType string

const (
	// PgWalBackupTypeFullBackup marks a record created from a basebackup upload.
	PgWalBackupTypeFullBackup PgWalBackupType = "PG_FULL_BACKUP"
	// PgWalBackupTypeWalSegment marks a record created from a single WAL segment upload.
	PgWalBackupTypeWalSegment PgWalBackupType = "PG_WAL_SEGMENT"
)
|
||||||
|
|
||||||
type Backup struct {
|
type Backup struct {
|
||||||
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
|
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
|
||||||
FileName string `json:"fileName" gorm:"column:file_name;type:text;not null"`
|
FileName string `json:"fileName" gorm:"column:file_name;type:text;not null"`
|
||||||
@@ -26,5 +36,23 @@ type Backup struct {
|
|||||||
EncryptionIV *string `json:"-" gorm:"column:encryption_iv"`
|
EncryptionIV *string `json:"-" gorm:"column:encryption_iv"`
|
||||||
Encryption backups_config.BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
|
Encryption backups_config.BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
|
||||||
|
|
||||||
|
// Postgres WAL backup specific fields
|
||||||
|
PgWalBackupType *PgWalBackupType `json:"pgWalBackupType" gorm:"column:pg_wal_backup_type;type:text"`
|
||||||
|
PgFullBackupWalStartSegmentName *string `json:"pgFullBackupWalStartSegmentName" gorm:"column:pg_wal_start_segment;type:text"`
|
||||||
|
PgFullBackupWalStopSegmentName *string `json:"pgFullBackupWalStopSegmentName" gorm:"column:pg_wal_stop_segment;type:text"`
|
||||||
|
PgVersion *string `json:"pgVersion" gorm:"column:pg_version;type:text"`
|
||||||
|
PgWalSegmentName *string `json:"pgWalSegmentName" gorm:"column:pg_wal_segment_name;type:text"`
|
||||||
|
|
||||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
|
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *Backup) GenerateFilename(dbName string) {
|
||||||
|
timestamp := time.Now().UTC()
|
||||||
|
|
||||||
|
b.FileName = fmt.Sprintf(
|
||||||
|
"%s-%s-%s",
|
||||||
|
files_utils.SanitizeFilename(dbName),
|
||||||
|
timestamp.Format("20060102-150405"),
|
||||||
|
b.ID.String(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|||||||
@@ -245,3 +245,134 @@ func (r *BackupRepository) FindOldestByDatabaseExcludingInProgress(
|
|||||||
|
|
||||||
return backups, nil
|
return backups, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *BackupRepository) FindCompletedFullWalBackupByID(
|
||||||
|
databaseID uuid.UUID,
|
||||||
|
backupID uuid.UUID,
|
||||||
|
) (*Backup, error) {
|
||||||
|
var backup Backup
|
||||||
|
|
||||||
|
err := storage.
|
||||||
|
GetDb().
|
||||||
|
Where(
|
||||||
|
"database_id = ? AND id = ? AND pg_wal_backup_type = ? AND status = ?",
|
||||||
|
databaseID,
|
||||||
|
backupID,
|
||||||
|
PgWalBackupTypeFullBackup,
|
||||||
|
BackupStatusCompleted,
|
||||||
|
).
|
||||||
|
First(&backup).Error
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &backup, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FindCompletedWalSegmentsAfter returns all COMPLETED WAL segment backups for the
// database whose segment name sorts at or after afterSegmentName, ordered by
// segment name ascending.
//
// NOTE(review): the comparison is ">=" despite the "After" name, so the segment
// named afterSegmentName itself is included — presumably intentional so restore
// plans include the WAL written during the basebackup; confirm against callers.
func (r *BackupRepository) FindCompletedWalSegmentsAfter(
	databaseID uuid.UUID,
	afterSegmentName string,
) ([]*Backup, error) {
	var backups []*Backup

	err := storage.
		GetDb().
		Where(
			"database_id = ? AND pg_wal_backup_type = ? AND pg_wal_segment_name >= ? AND status = ?",
			databaseID,
			PgWalBackupTypeWalSegment,
			afterSegmentName,
			BackupStatusCompleted,
		).
		Order("pg_wal_segment_name ASC").
		Find(&backups).Error
	if err != nil {
		return nil, err
	}

	return backups, nil
}
|
||||||
|
|
||||||
|
func (r *BackupRepository) FindLastCompletedFullWalBackupByDatabaseID(
|
||||||
|
databaseID uuid.UUID,
|
||||||
|
) (*Backup, error) {
|
||||||
|
var backup Backup
|
||||||
|
|
||||||
|
err := storage.
|
||||||
|
GetDb().
|
||||||
|
Where(
|
||||||
|
"database_id = ? AND pg_wal_backup_type = ? AND status = ?",
|
||||||
|
databaseID,
|
||||||
|
PgWalBackupTypeFullBackup,
|
||||||
|
BackupStatusCompleted,
|
||||||
|
).
|
||||||
|
Order("created_at DESC").
|
||||||
|
First(&backup).Error
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &backup, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FindWalSegmentByName returns the WAL segment backup with the given segment
// name for the database, or (nil, nil) when none exists.
//
// NOTE(review): unlike the other finders this does NOT filter by status, so a
// FAILED or IN_PROGRESS record is also returned. Callers using this as a
// duplicate/idempotency check must decide how to treat non-completed records.
func (r *BackupRepository) FindWalSegmentByName(
	databaseID uuid.UUID,
	segmentName string,
) (*Backup, error) {
	var backup Backup

	err := storage.
		GetDb().
		Where(
			"database_id = ? AND pg_wal_backup_type = ? AND pg_wal_segment_name = ?",
			databaseID,
			PgWalBackupTypeWalSegment,
			segmentName,
		).
		First(&backup).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil
		}

		return nil, err
	}

	return &backup, nil
}
|
||||||
|
|
||||||
|
func (r *BackupRepository) FindLastWalSegmentAfter(
|
||||||
|
databaseID uuid.UUID,
|
||||||
|
afterSegmentName string,
|
||||||
|
) (*Backup, error) {
|
||||||
|
var backup Backup
|
||||||
|
|
||||||
|
err := storage.
|
||||||
|
GetDb().
|
||||||
|
Where(
|
||||||
|
"database_id = ? AND pg_wal_backup_type = ? AND pg_wal_segment_name > ? AND status = ?",
|
||||||
|
databaseID,
|
||||||
|
PgWalBackupTypeWalSegment,
|
||||||
|
afterSegmentName,
|
||||||
|
BackupStatusCompleted,
|
||||||
|
).
|
||||||
|
Order("pg_wal_segment_name DESC").
|
||||||
|
First(&backup).Error
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &backup, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,29 +0,0 @@
|
|||||||
package backups
|
|
||||||
|
|
||||||
import (
|
|
||||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
|
||||||
"databasus-backend/internal/features/backups/backups/encryption"
|
|
||||||
"io"
|
|
||||||
)
|
|
||||||
|
|
||||||
type GetBackupsRequest struct {
|
|
||||||
DatabaseID string `form:"database_id" binding:"required"`
|
|
||||||
Limit int `form:"limit"`
|
|
||||||
Offset int `form:"offset"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type GetBackupsResponse struct {
|
|
||||||
Backups []*backups_core.Backup `json:"backups"`
|
|
||||||
Total int64 `json:"total"`
|
|
||||||
Limit int `json:"limit"`
|
|
||||||
Offset int `json:"offset"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type DecryptionReaderCloser struct {
|
|
||||||
*encryption.DecryptionReader
|
|
||||||
BaseReader io.ReadCloser
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *DecryptionReaderCloser) Close() error {
|
|
||||||
return r.BaseReader.Close()
|
|
||||||
}
|
|
||||||
78
backend/internal/features/backups/backups/dto/dto.go
Normal file
78
backend/internal/features/backups/backups/dto/dto.go
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
package backups_dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||||
|
"databasus-backend/internal/features/backups/backups/encryption"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetBackupsRequest is the query-string payload for listing a database's backups.
type GetBackupsRequest struct {
	DatabaseID string `form:"database_id" binding:"required"`
	Limit      int    `form:"limit"`
	Offset     int    `form:"offset"`
}

// GetBackupsResponse is a paginated page of backup records.
type GetBackupsResponse struct {
	Backups []*backups_core.Backup `json:"backups"`
	Total   int64                  `json:"total"`
	Limit   int                    `json:"limit"`
	Offset  int                    `json:"offset"`
}

// DecryptionReaderCloser pairs a DecryptionReader with the underlying raw
// stream so that closing it releases the base reader.
type DecryptionReaderCloser struct {
	*encryption.DecryptionReader
	BaseReader io.ReadCloser
}

// Close closes the underlying base reader (the DecryptionReader itself holds
// no resources of its own).
func (r *DecryptionReaderCloser) Close() error {
	return r.BaseReader.Close()
}

// MakeBackupRequest triggers a manual backup of one database.
type MakeBackupRequest struct {
	DatabaseID uuid.UUID `json:"database_id" binding:"required"`
}

// GetNextFullBackupTimeResponse tells the agent when the next full backup is
// due; nil means no schedule could be derived.
type GetNextFullBackupTimeResponse struct {
	NextFullBackupTime *time.Time `json:"nextFullBackupTime"`
}

// ReportErrorRequest carries an agent-side error message for recording.
type ReportErrorRequest struct {
	Error string `json:"error" binding:"required"`
}

// UploadGapResponse is returned when a WAL upload cannot be accepted because
// the chain is broken (or no full backup anchors it yet).
type UploadGapResponse struct {
	Error               string `json:"error"`
	ExpectedSegmentName string `json:"expectedSegmentName"`
	ReceivedSegmentName string `json:"receivedSegmentName"`
}

// RestorePlanFullBackup describes the base backup a restore starts from.
type RestorePlanFullBackup struct {
	BackupID                  uuid.UUID `json:"id"`
	FullBackupWalStartSegment string    `json:"fullBackupWalStartSegment"`
	FullBackupWalStopSegment  string    `json:"fullBackupWalStopSegment"`
	PgVersion                 string    `json:"pgVersion"`
	CreatedAt                 time.Time `json:"createdAt"`
	SizeBytes                 int64     `json:"sizeBytes"`
}

// RestorePlanWalSegment describes one WAL segment to replay during restore.
type RestorePlanWalSegment struct {
	BackupID    uuid.UUID `json:"backupId"`
	SegmentName string    `json:"segmentName"`
	SizeBytes   int64     `json:"sizeBytes"`
}

// GetRestorePlanErrorResponse explains why a restore plan could not be built.
type GetRestorePlanErrorResponse struct {
	Error                 string `json:"error"`
	Message               string `json:"message"`
	LastContiguousSegment string `json:"lastContiguousSegment,omitempty"`
}

// GetRestorePlanResponse is the full restore recipe: one base backup plus the
// ordered WAL segments to replay, with aggregate size information.
type GetRestorePlanResponse struct {
	FullBackup             RestorePlanFullBackup   `json:"fullBackup"`
	WalSegments            []RestorePlanWalSegment `json:"walSegments"`
	TotalSizeBytes         int64                   `json:"totalSizeBytes"`
	LatestAvailableSegment string                  `json:"latestAvailableSegment"`
}
|
||||||
@@ -0,0 +1,45 @@
|
|||||||
|
package encryption
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EncryptionSetup holds the result of setting up encryption for a backup stream.
type EncryptionSetup struct {
	// Writer encrypts everything written to it into the base writer.
	Writer *EncryptionWriter
	// SaltBase64 is the base64-encoded key-derivation salt, for persistence.
	SaltBase64 string
	// NonceBase64 is the base64-encoded nonce/IV, for persistence.
	NonceBase64 string
}
|
||||||
|
|
||||||
|
// SetupEncryptionWriter generates salt/nonce, creates an EncryptionWriter, and
|
||||||
|
// returns the base64-encoded salt and nonce for storage on the backup record.
|
||||||
|
func SetupEncryptionWriter(
|
||||||
|
baseWriter io.Writer,
|
||||||
|
masterKey string,
|
||||||
|
backupID uuid.UUID,
|
||||||
|
) (*EncryptionSetup, error) {
|
||||||
|
salt, err := GenerateSalt()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate salt: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce, err := GenerateNonce()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate nonce: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
encWriter, err := NewEncryptionWriter(baseWriter, masterKey, backupID, salt, nonce)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create encryption writer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &EncryptionSetup{
|
||||||
|
Writer: encWriter,
|
||||||
|
SaltBase64: base64.StdEncoding.EncodeToString(salt),
|
||||||
|
NonceBase64: base64.StdEncoding.EncodeToString(nonce),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
@@ -1,9 +1,6 @@
|
|||||||
package backups
|
package backups_services
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
|
|
||||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||||
"databasus-backend/internal/features/backups/backups/backuping"
|
"databasus-backend/internal/features/backups/backups/backuping"
|
||||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||||
@@ -18,16 +15,16 @@ import (
|
|||||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||||
"databasus-backend/internal/util/encryption"
|
"databasus-backend/internal/util/encryption"
|
||||||
"databasus-backend/internal/util/logger"
|
"databasus-backend/internal/util/logger"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
)
|
)
|
||||||
|
|
||||||
var backupRepository = &backups_core.BackupRepository{}
|
|
||||||
|
|
||||||
var taskCancelManager = task_cancellation.GetTaskCancelManager()
|
var taskCancelManager = task_cancellation.GetTaskCancelManager()
|
||||||
|
|
||||||
var backupService = &BackupService{
|
var backupService = &BackupService{
|
||||||
databases.GetDatabaseService(),
|
databases.GetDatabaseService(),
|
||||||
storages.GetStorageService(),
|
storages.GetStorageService(),
|
||||||
backupRepository,
|
backups_core.GetBackupRepository(),
|
||||||
notifiers.GetNotifierService(),
|
notifiers.GetNotifierService(),
|
||||||
notifiers.GetNotifierService(),
|
notifiers.GetNotifierService(),
|
||||||
backups_config.GetBackupConfigService(),
|
backups_config.GetBackupConfigService(),
|
||||||
@@ -44,16 +41,21 @@ var backupService = &BackupService{
|
|||||||
backuping.GetBackupCleaner(),
|
backuping.GetBackupCleaner(),
|
||||||
}
|
}
|
||||||
|
|
||||||
var backupController = &BackupController{
|
|
||||||
backupService: backupService,
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetBackupService() *BackupService {
|
func GetBackupService() *BackupService {
|
||||||
return backupService
|
return backupService
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetBackupController() *BackupController {
|
var walService = &PostgreWalBackupService{
|
||||||
return backupController
|
backups_config.GetBackupConfigService(),
|
||||||
|
backups_core.GetBackupRepository(),
|
||||||
|
encryption.GetFieldEncryptor(),
|
||||||
|
encryption_secrets.GetSecretKeyService(),
|
||||||
|
logger.GetLogger(),
|
||||||
|
backupService,
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetWalService() *PostgreWalBackupService {
|
||||||
|
return walService
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -0,0 +1,613 @@
|
|||||||
|
package backups_services
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||||
|
backups_dto "databasus-backend/internal/features/backups/backups/dto"
|
||||||
|
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
|
||||||
|
backups_config "databasus-backend/internal/features/backups/config"
|
||||||
|
"databasus-backend/internal/features/databases"
|
||||||
|
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||||
|
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
|
||||||
|
util_encryption "databasus-backend/internal/util/encryption"
|
||||||
|
util_wal "databasus-backend/internal/util/wal"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PostgreWalBackupService handles WAL segment and basebackup uploads from the databasus-cli agent.
type PostgreWalBackupService struct {
	backupConfigService *backups_config.BackupConfigService // per-database backup settings (storage, encryption, interval)
	backupRepository    *backups_core.BackupRepository      // persistence for backup records
	fieldEncryptor      util_encryption.FieldEncryptor      // decrypts storage credentials when saving files
	secretKeyService    *encryption_secrets.SecretKeyService // source of the master encryption key
	logger              *slog.Logger
	backupService       *BackupService // reused for reading backup files back out
}
|
||||||
|
|
||||||
|
// UploadWal accepts a streaming WAL segment or basebackup upload from the agent.
|
||||||
|
// For WAL segments it validates the WAL chain before accepting. Returns an UploadGapResponse
|
||||||
|
// (409) when the chain is broken so the agent knows to trigger a full basebackup.
|
||||||
|
func (s *PostgreWalBackupService) UploadWal(
|
||||||
|
ctx context.Context,
|
||||||
|
database *databases.Database,
|
||||||
|
uploadType backups_core.PgWalUploadType,
|
||||||
|
walSegmentName string,
|
||||||
|
fullBackupWalStartSegment string,
|
||||||
|
fullBackupWalStopSegment string,
|
||||||
|
walSegmentSizeBytes int64,
|
||||||
|
body io.Reader,
|
||||||
|
) (*backups_dto.UploadGapResponse, error) {
|
||||||
|
if err := s.validateWalBackupType(database); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if uploadType == backups_core.PgWalUploadTypeBasebackup {
|
||||||
|
if fullBackupWalStartSegment == "" || fullBackupWalStopSegment == "" {
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"fullBackupWalStartSegment and fullBackupWalStopSegment are required for basebackup uploads",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get backup config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if backupConfig.Storage == nil {
|
||||||
|
return nil, fmt.Errorf("no storage configured for database %s", database.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if uploadType == backups_core.PgWalUploadTypeWal {
|
||||||
|
// Idempotency: check before chain validation so a successful re-upload is
|
||||||
|
// not misidentified as a gap.
|
||||||
|
existing, err := s.backupRepository.FindWalSegmentByName(database.ID, walSegmentName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to check for duplicate WAL segment: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if existing != nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
gapResp, err := s.validateWalChain(database.ID, walSegmentName, walSegmentSizeBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if gapResp != nil {
|
||||||
|
return gapResp, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
backup := s.createBackupRecord(
|
||||||
|
database.ID,
|
||||||
|
backupConfig.Storage.ID,
|
||||||
|
uploadType,
|
||||||
|
database.Name,
|
||||||
|
walSegmentName,
|
||||||
|
fullBackupWalStartSegment,
|
||||||
|
fullBackupWalStopSegment,
|
||||||
|
backupConfig.Encryption,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err := s.backupRepository.Save(backup); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create backup record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sizeBytes, streamErr := s.streamToStorage(ctx, backup, backupConfig, body)
|
||||||
|
if streamErr != nil {
|
||||||
|
errMsg := streamErr.Error()
|
||||||
|
s.markFailed(backup, errMsg)
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("upload failed: %w", streamErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.markCompleted(backup, sizeBytes)
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRestorePlan builds the restore recipe for a database: the base backup to
// start from (a specific one when backupID is given, otherwise the latest
// completed full backup) plus the ordered WAL segments to replay.
//
// Returns exactly one of:
//   - a plan (first value),
//   - a client-facing error payload (second value, e.g. no backups / broken chain),
//   - an internal error (third value).
func (s *PostgreWalBackupService) GetRestorePlan(
	database *databases.Database,
	backupID *uuid.UUID,
) (*backups_dto.GetRestorePlanResponse, *backups_dto.GetRestorePlanErrorResponse, error) {
	if err := s.validateWalBackupType(database); err != nil {
		return nil, nil, err
	}

	fullBackup, err := s.resolveFullBackup(database.ID, backupID)
	if err != nil {
		return nil, nil, err
	}

	if fullBackup == nil {
		msg := "no full backups available for this database"
		if backupID != nil {
			msg = fmt.Sprintf("full backup %s not found or not completed", backupID)
		}

		return nil, &backups_dto.GetRestorePlanErrorResponse{
			Error:   "no_backups",
			Message: msg,
		}, nil
	}

	startSegment := ""
	if fullBackup.PgFullBackupWalStartSegmentName != nil {
		startSegment = *fullBackup.PgFullBackupWalStartSegmentName
	}

	// Segments at or after the basebackup's start segment (>= in the repository)
	// are all candidates for replay.
	walSegments, err := s.backupRepository.FindCompletedWalSegmentsAfter(database.ID, startSegment)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to query WAL segments: %w", err)
	}

	chainErr := s.validateRestoreWalChain(fullBackup, walSegments)
	if chainErr != nil {
		return nil, chainErr, nil
	}

	fullBackupSizeBytes := int64(fullBackup.BackupSizeMb * 1024 * 1024)

	pgVersion := ""
	if fullBackup.PgVersion != nil {
		pgVersion = *fullBackup.PgVersion
	}

	stopSegment := ""
	if fullBackup.PgFullBackupWalStopSegmentName != nil {
		stopSegment = *fullBackup.PgFullBackupWalStopSegmentName
	}

	response := &backups_dto.GetRestorePlanResponse{
		FullBackup: backups_dto.RestorePlanFullBackup{
			BackupID:                  fullBackup.ID,
			FullBackupWalStartSegment: startSegment,
			FullBackupWalStopSegment:  stopSegment,
			PgVersion:                 pgVersion,
			CreatedAt:                 fullBackup.CreatedAt,
			SizeBytes:                 fullBackupSizeBytes,
		},
		TotalSizeBytes: fullBackupSizeBytes,
	}

	// walSegments is ordered ascending, so the last iteration leaves
	// LatestAvailableSegment at the newest segment name.
	for _, seg := range walSegments {
		segName := ""
		if seg.PgWalSegmentName != nil {
			segName = *seg.PgWalSegmentName
		}

		segSizeBytes := int64(seg.BackupSizeMb * 1024 * 1024)

		response.WalSegments = append(response.WalSegments, backups_dto.RestorePlanWalSegment{
			BackupID:    seg.ID,
			SegmentName: segName,
			SizeBytes:   segSizeBytes,
		})

		response.TotalSizeBytes += segSizeBytes
		response.LatestAvailableSegment = segName
	}

	return response, nil, nil
}
|
||||||
|
|
||||||
|
// DownloadBackupFile returns a reader for a backup file belonging to the given database.
|
||||||
|
// Decryption is handled transparently if the backup is encrypted.
|
||||||
|
func (s *PostgreWalBackupService) DownloadBackupFile(
|
||||||
|
database *databases.Database,
|
||||||
|
backupID uuid.UUID,
|
||||||
|
) (io.ReadCloser, error) {
|
||||||
|
if err := s.validateWalBackupType(database); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
backup, err := s.backupRepository.FindByID(backupID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("backup not found: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if backup.DatabaseID != database.ID {
|
||||||
|
return nil, fmt.Errorf("backup does not belong to this database")
|
||||||
|
}
|
||||||
|
|
||||||
|
if backup.Status != backups_core.BackupStatusCompleted {
|
||||||
|
return nil, fmt.Errorf("backup is not completed")
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.backupService.GetBackupReader(backupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateWalChain checks that incomingSegment is exactly the next segment in
// the database's WAL chain. The chain tail is the newest completed WAL segment
// after the latest full backup's stop segment, or the stop segment itself when
// no WAL has been uploaded yet. Returns a gap payload (non-nil first value)
// when the chain is broken, or (nil, nil) when the segment is accepted.
func (s *PostgreWalBackupService) validateWalChain(
	databaseID uuid.UUID,
	incomingSegment string,
	walSegmentSizeBytes int64,
) (*backups_dto.UploadGapResponse, error) {
	fullBackup, err := s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(databaseID)
	if err != nil {
		return nil, fmt.Errorf("failed to query full backup: %w", err)
	}

	// No full backup exists yet: cannot accept WAL segments without a chain anchor.
	if fullBackup == nil || fullBackup.PgFullBackupWalStopSegmentName == nil {
		return &backups_dto.UploadGapResponse{
			Error:               "no_full_backup",
			ExpectedSegmentName: "",
			ReceivedSegmentName: incomingSegment,
		}, nil
	}

	stopSegment := *fullBackup.PgFullBackupWalStopSegmentName

	lastWal, err := s.backupRepository.FindLastWalSegmentAfter(databaseID, stopSegment)
	if err != nil {
		return nil, fmt.Errorf("failed to query last WAL segment: %w", err)
	}

	// Segment-name arithmetic depends on the segment size the server uses.
	walCalculator := util_wal.NewWalCalculator(walSegmentSizeBytes)

	var chainTail string
	if lastWal != nil && lastWal.PgWalSegmentName != nil {
		chainTail = *lastWal.PgWalSegmentName
	} else {
		chainTail = stopSegment
	}

	expectedNext, err := walCalculator.NextSegment(chainTail)
	if err != nil {
		return nil, fmt.Errorf("WAL arithmetic failed for %q: %w", chainTail, err)
	}

	if incomingSegment != expectedNext {
		return &backups_dto.UploadGapResponse{
			Error:               "gap_detected",
			ExpectedSegmentName: expectedNext,
			ReceivedSegmentName: incomingSegment,
		}, nil
	}

	return nil, nil
}
|
||||||
|
|
||||||
|
func (s *PostgreWalBackupService) createBackupRecord(
|
||||||
|
databaseID uuid.UUID,
|
||||||
|
storageID uuid.UUID,
|
||||||
|
uploadType backups_core.PgWalUploadType,
|
||||||
|
dbName string,
|
||||||
|
walSegmentName string,
|
||||||
|
fullBackupWalStartSegment string,
|
||||||
|
fullBackupWalStopSegment string,
|
||||||
|
encryption backups_config.BackupEncryption,
|
||||||
|
) *backups_core.Backup {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
|
||||||
|
backup := &backups_core.Backup{
|
||||||
|
ID: uuid.New(),
|
||||||
|
DatabaseID: databaseID,
|
||||||
|
StorageID: storageID,
|
||||||
|
Status: backups_core.BackupStatusInProgress,
|
||||||
|
Encryption: encryption,
|
||||||
|
CreatedAt: now,
|
||||||
|
}
|
||||||
|
|
||||||
|
backup.GenerateFilename(dbName)
|
||||||
|
|
||||||
|
if uploadType == backups_core.PgWalUploadTypeBasebackup {
|
||||||
|
walBackupType := backups_core.PgWalBackupTypeFullBackup
|
||||||
|
backup.PgWalBackupType = &walBackupType
|
||||||
|
|
||||||
|
if fullBackupWalStartSegment != "" {
|
||||||
|
backup.PgFullBackupWalStartSegmentName = &fullBackupWalStartSegment
|
||||||
|
}
|
||||||
|
|
||||||
|
if fullBackupWalStopSegment != "" {
|
||||||
|
backup.PgFullBackupWalStopSegmentName = &fullBackupWalStopSegment
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
walBackupType := backups_core.PgWalBackupTypeWalSegment
|
||||||
|
backup.PgWalBackupType = &walBackupType
|
||||||
|
backup.PgWalSegmentName = &walSegmentName
|
||||||
|
}
|
||||||
|
|
||||||
|
return backup
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PostgreWalBackupService) streamToStorage(
|
||||||
|
ctx context.Context,
|
||||||
|
backup *backups_core.Backup,
|
||||||
|
backupConfig *backups_config.BackupConfig,
|
||||||
|
body io.Reader,
|
||||||
|
) (int64, error) {
|
||||||
|
if backupConfig.Encryption == backups_config.BackupEncryptionEncrypted {
|
||||||
|
return s.streamEncrypted(ctx, backup, backupConfig, body, backup.FileName)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.streamDirect(ctx, backupConfig, body, backup.FileName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PostgreWalBackupService) streamDirect(
|
||||||
|
ctx context.Context,
|
||||||
|
backupConfig *backups_config.BackupConfig,
|
||||||
|
body io.Reader,
|
||||||
|
fileName string,
|
||||||
|
) (int64, error) {
|
||||||
|
cr := &countingReader{r: body}
|
||||||
|
|
||||||
|
if err := backupConfig.Storage.SaveFile(ctx, s.fieldEncryptor, s.logger, fileName, cr); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return cr.n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// streamEncrypted uploads body to storage through an encrypting pipe:
// a goroutine copies body into the EncryptionWriter (which feeds pipeWriter),
// while SaveFile consumes pipeReader on this goroutine. On success the
// generated salt/IV are recorded on the backup record (persisted by the
// caller) and the count of encrypted bytes written to storage is returned.
func (s *PostgreWalBackupService) streamEncrypted(
	ctx context.Context,
	backup *backups_core.Backup,
	backupConfig *backups_config.BackupConfig,
	body io.Reader,
	fileName string,
) (int64, error) {
	masterKey, err := s.secretKeyService.GetSecretKey()
	if err != nil {
		return 0, fmt.Errorf("failed to get master encryption key: %w", err)
	}

	pipeReader, pipeWriter := io.Pipe()

	encryptionSetup, err := backup_encryption.SetupEncryptionWriter(
		pipeWriter,
		masterKey,
		backup.ID,
	)
	if err != nil {
		_ = pipeWriter.Close()
		return 0, err
	}

	// Buffered so the producer goroutine can always report and exit even if
	// this side has already returned.
	copyErrCh := make(chan error, 1)
	go func() {
		_, copyErr := io.Copy(encryptionSetup.Writer, body)
		if copyErr != nil {
			// Propagate the failure to the reader side so SaveFile unblocks.
			_ = encryptionSetup.Writer.Close()
			_ = pipeWriter.CloseWithError(copyErr)
			copyErrCh <- copyErr
			return
		}

		// Closing the encryption writer flushes the final block before the
		// pipe is closed.
		if closeErr := encryptionSetup.Writer.Close(); closeErr != nil {
			_ = pipeWriter.CloseWithError(closeErr)
			copyErrCh <- closeErr
			return
		}

		copyErrCh <- pipeWriter.Close()
	}()

	cr := &countingReader{r: pipeReader}
	saveErr := backupConfig.Storage.SaveFile(ctx, s.fieldEncryptor, s.logger, fileName, cr)
	// Always drain the channel so the goroutine never leaks.
	copyErr := <-copyErrCh

	// The producer-side error is more specific, so it takes precedence.
	if copyErr != nil {
		return 0, copyErr
	}

	if saveErr != nil {
		return 0, saveErr
	}

	backup.EncryptionSalt = &encryptionSetup.SaltBase64
	backup.EncryptionIV = &encryptionSetup.NonceBase64

	return cr.n, nil
}
|
||||||
|
|
||||||
|
func (s *PostgreWalBackupService) markCompleted(backup *backups_core.Backup, sizeBytes int64) {
|
||||||
|
backup.Status = backups_core.BackupStatusCompleted
|
||||||
|
backup.BackupSizeMb = float64(sizeBytes) / (1024 * 1024)
|
||||||
|
|
||||||
|
if err := s.backupRepository.Save(backup); err != nil {
|
||||||
|
s.logger.Error(
|
||||||
|
"failed to mark WAL backup as completed",
|
||||||
|
"backupId",
|
||||||
|
backup.ID,
|
||||||
|
"error",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PostgreWalBackupService) markFailed(backup *backups_core.Backup, errMsg string) {
|
||||||
|
backup.Status = backups_core.BackupStatusFailed
|
||||||
|
backup.FailMessage = &errMsg
|
||||||
|
|
||||||
|
if err := s.backupRepository.Save(backup); err != nil {
|
||||||
|
s.logger.Error("failed to mark WAL backup as failed", "backupId", backup.ID, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PostgreWalBackupService) GetNextFullBackupTime(
|
||||||
|
database *databases.Database,
|
||||||
|
) (*backups_dto.GetNextFullBackupTimeResponse, error) {
|
||||||
|
if err := s.validateWalBackupType(database); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get backup config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if backupConfig.BackupInterval == nil {
|
||||||
|
return nil, fmt.Errorf("no backup interval configured for database %s", database.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
lastFullBackup, err := s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(
|
||||||
|
database.ID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query last full backup: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var lastBackupTime *time.Time
|
||||||
|
if lastFullBackup != nil {
|
||||||
|
lastBackupTime = &lastFullBackup.CreatedAt
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
nextTime := backupConfig.BackupInterval.NextTriggerTime(now, lastBackupTime)
|
||||||
|
|
||||||
|
return &backups_dto.GetNextFullBackupTimeResponse{
|
||||||
|
NextFullBackupTime: nextTime,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReportError creates a FAILED backup record with the agent's error message.
|
||||||
|
func (s *PostgreWalBackupService) ReportError(
|
||||||
|
database *databases.Database,
|
||||||
|
errorMsg string,
|
||||||
|
) error {
|
||||||
|
if err := s.validateWalBackupType(database); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get backup config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if backupConfig.Storage == nil {
|
||||||
|
return fmt.Errorf("no storage configured for database %s", database.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
backup := &backups_core.Backup{
|
||||||
|
ID: uuid.New(),
|
||||||
|
DatabaseID: database.ID,
|
||||||
|
StorageID: backupConfig.Storage.ID,
|
||||||
|
Status: backups_core.BackupStatusFailed,
|
||||||
|
FailMessage: &errorMsg,
|
||||||
|
Encryption: backupConfig.Encryption,
|
||||||
|
CreatedAt: now,
|
||||||
|
}
|
||||||
|
|
||||||
|
backup.GenerateFilename(database.Name)
|
||||||
|
|
||||||
|
if err := s.backupRepository.Save(backup); err != nil {
|
||||||
|
return fmt.Errorf("failed to save error backup record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PostgreWalBackupService) resolveFullBackup(
|
||||||
|
databaseID uuid.UUID,
|
||||||
|
backupID *uuid.UUID,
|
||||||
|
) (*backups_core.Backup, error) {
|
||||||
|
if backupID != nil {
|
||||||
|
return s.backupRepository.FindCompletedFullWalBackupByID(databaseID, *backupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(databaseID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateRestoreWalChain checks that the WAL segments following the full
// (base) backup form an unbroken, consecutive chain. It returns a
// GetRestorePlanErrorResponse describing the first detected gap, or nil when
// the chain is intact, empty, or cannot be evaluated.
//
// Note: calculator errors (NextSegment/Compare failing on an unparsable
// segment name) make the function return nil — i.e. validation is
// best-effort and an unparsable name is treated as "no gap detected".
func (s *PostgreWalBackupService) validateRestoreWalChain(
	fullBackup *backups_core.Backup,
	walSegments []*backups_core.Backup,
) *backups_dto.GetRestorePlanErrorResponse {
	// Nothing to validate without WAL segments.
	if len(walSegments) == 0 {
		return nil
	}

	// The base backup's stop segment anchors the chain: recovery WAL must
	// start immediately after it. A nil pointer degrades to "".
	stopSegment := ""
	if fullBackup.PgFullBackupWalStopSegmentName != nil {
		stopSegment = *fullBackup.PgFullBackupWalStopSegmentName
	}

	walCalculator := util_wal.NewWalCalculator(0)
	// expectedNext tracks the segment name the chain must produce next.
	expectedNext, err := walCalculator.NextSegment(stopSegment)
	if err != nil {
		return nil
	}

	// walSegments are assumed ordered; each segment either belongs to the
	// basebackup range (<= stopSegment) or must equal expectedNext.
	for _, seg := range walSegments {
		segName := ""
		if seg.PgWalSegmentName != nil {
			segName = *seg.PgWalSegmentName
		}

		cmp, cmpErr := walCalculator.Compare(segName, stopSegment)
		if cmpErr != nil {
			return nil
		}

		// Skip segments that are <= stopSegment (they are part of the basebackup range)
		if cmp <= 0 {
			continue
		}

		if segName != expectedNext {
			// Gap found: determine the last segment reachable without
			// crossing the gap, so the caller can report how far
			// point-in-time recovery can go.
			lastContiguous := stopSegment
			// Walk back to find the segment before the gap
			for _, prev := range walSegments {
				prevName := ""
				if prev.PgWalSegmentName != nil {
					prevName = *prev.PgWalSegmentName
				}

				// Ignore segments inside the basebackup range; the
				// Compare error is deliberately discarded (best-effort).
				prevCmp, _ := walCalculator.Compare(prevName, stopSegment)
				if prevCmp <= 0 {
					continue
				}

				// Stop once we reach the segment that broke the chain.
				if prevName == segName {
					break
				}

				lastContiguous = prevName
			}

			return &backups_dto.GetRestorePlanErrorResponse{
				Error: "wal_chain_broken",
				Message: fmt.Sprintf(
					"WAL chain has a gap after segment %s. Recovery is only possible up to this segment.",
					lastContiguous,
				),
				LastContiguousSegment: lastContiguous,
			}
		}

		// Advance the expectation to the segment after the one just seen.
		expectedNext, err = walCalculator.NextSegment(segName)
		if err != nil {
			return nil
		}
	}

	return nil
}
|
||||||
|
|
||||||
|
func (s *PostgreWalBackupService) validateWalBackupType(database *databases.Database) error {
|
||||||
|
if database.Postgresql == nil ||
|
||||||
|
database.Postgresql.BackupType != postgresql.PostgresBackupTypeWalV1 {
|
||||||
|
return fmt.Errorf("database %s is not configured for WAL backups", database.ID)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type countingReader struct {
|
||||||
|
r io.Reader
|
||||||
|
n int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cr *countingReader) Read(p []byte) (n int, err error) {
|
||||||
|
n, err = cr.r.Read(p)
|
||||||
|
cr.n += int64(n)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package backups
|
package backups_services
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"databasus-backend/internal/features/backups/backups/backuping"
|
"databasus-backend/internal/features/backups/backups/backuping"
|
||||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||||
|
backups_dto "databasus-backend/internal/features/backups/backups/dto"
|
||||||
"databasus-backend/internal/features/backups/backups/encryption"
|
"databasus-backend/internal/features/backups/backups/encryption"
|
||||||
backups_config "databasus-backend/internal/features/backups/config"
|
backups_config "databasus-backend/internal/features/backups/config"
|
||||||
"databasus-backend/internal/features/databases"
|
"databasus-backend/internal/features/databases"
|
||||||
@@ -108,7 +109,7 @@ func (s *BackupService) GetBackups(
|
|||||||
user *users_models.User,
|
user *users_models.User,
|
||||||
databaseID uuid.UUID,
|
databaseID uuid.UUID,
|
||||||
limit, offset int,
|
limit, offset int,
|
||||||
) (*GetBackupsResponse, error) {
|
) (*backups_dto.GetBackupsResponse, error) {
|
||||||
database, err := s.databaseService.GetDatabaseByID(databaseID)
|
database, err := s.databaseService.GetDatabaseByID(databaseID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -143,7 +144,7 @@ func (s *BackupService) GetBackups(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &GetBackupsResponse{
|
return &backups_dto.GetBackupsResponse{
|
||||||
Backups: backups,
|
Backups: backups,
|
||||||
Total: total,
|
Total: total,
|
||||||
Limit: limit,
|
Limit: limit,
|
||||||
@@ -274,7 +275,7 @@ func (s *BackupService) GetBackupFile(
|
|||||||
database.WorkspaceID,
|
database.WorkspaceID,
|
||||||
)
|
)
|
||||||
|
|
||||||
reader, err := s.getBackupReader(backupID)
|
reader, err := s.GetBackupReader(backupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
@@ -282,39 +283,9 @@ func (s *BackupService) GetBackupFile(
|
|||||||
return reader, backup, database, nil
|
return reader, backup, database, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
|
// GetBackupReader returns a reader for the backup file.
|
||||||
dbBackupsInProgress, err := s.backupRepository.FindByDatabaseIdAndStatus(
|
// If encrypted, wraps with DecryptionReader.
|
||||||
databaseID,
|
func (s *BackupService) GetBackupReader(backupID uuid.UUID) (io.ReadCloser, error) {
|
||||||
backups_core.BackupStatusInProgress,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(dbBackupsInProgress) > 0 {
|
|
||||||
return errors.New("backup is in progress, storage cannot be removed")
|
|
||||||
}
|
|
||||||
|
|
||||||
dbBackups, err := s.backupRepository.FindByDatabaseID(
|
|
||||||
databaseID,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, dbBackup := range dbBackups {
|
|
||||||
err := s.backupCleaner.DeleteBackup(dbBackup)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetBackupReader returns a reader for the backup file
|
|
||||||
// If encrypted, wraps with DecryptionReader
|
|
||||||
func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, error) {
|
|
||||||
backup, err := s.backupRepository.FindByID(backupID)
|
backup, err := s.backupRepository.FindByID(backupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to find backup: %w", err)
|
return nil, fmt.Errorf("failed to find backup: %w", err)
|
||||||
@@ -394,7 +365,7 @@ func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, erro
|
|||||||
|
|
||||||
s.logger.Info("Returning encrypted backup with decryption", "backupId", backupID)
|
s.logger.Info("Returning encrypted backup with decryption", "backupId", backupID)
|
||||||
|
|
||||||
return &DecryptionReaderCloser{
|
return &backups_dto.DecryptionReaderCloser{
|
||||||
DecryptionReader: decryptionReader,
|
DecryptionReader: decryptionReader,
|
||||||
BaseReader: fileReader,
|
BaseReader: fileReader,
|
||||||
}, nil
|
}, nil
|
||||||
@@ -465,7 +436,7 @@ func (s *BackupService) GetBackupFileWithoutAuth(
|
|||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
reader, err := s.getBackupReader(backupID)
|
reader, err := s.GetBackupReader(backupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
@@ -501,6 +472,36 @@ func (s *BackupService) UnregisterDownload(userID uuid.UUID) {
|
|||||||
s.downloadTokenService.UnregisterDownload(userID)
|
s.downloadTokenService.UnregisterDownload(userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
|
||||||
|
dbBackupsInProgress, err := s.backupRepository.FindByDatabaseIdAndStatus(
|
||||||
|
databaseID,
|
||||||
|
backups_core.BackupStatusInProgress,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(dbBackupsInProgress) > 0 {
|
||||||
|
return errors.New("backup is in progress, storage cannot be removed")
|
||||||
|
}
|
||||||
|
|
||||||
|
dbBackups, err := s.backupRepository.FindByDatabaseID(
|
||||||
|
databaseID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dbBackup := range dbBackups {
|
||||||
|
err := s.backupCleaner.DeleteBackup(dbBackup)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *BackupService) generateBackupFilename(
|
func (s *BackupService) generateBackupFilename(
|
||||||
backup *backups_core.Backup,
|
backup *backups_core.Backup,
|
||||||
database *databases.Database,
|
database *databases.Database,
|
||||||
@@ -2,7 +2,6 @@ package usecases_mariadb
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -437,40 +436,22 @@ func (uc *CreateMariadbBackupUsecase) setupBackupEncryption(
|
|||||||
return storageWriter, nil, metadata, nil
|
return storageWriter, nil, metadata, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
salt, err := backup_encryption.GenerateSalt()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, metadata, fmt.Errorf("failed to generate salt: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
nonce, err := backup_encryption.GenerateNonce()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, metadata, fmt.Errorf("failed to generate nonce: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
masterKey, err := uc.secretKeyService.GetSecretKey()
|
masterKey, err := uc.secretKeyService.GetSecretKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, metadata, fmt.Errorf("failed to get master key: %w", err)
|
return nil, nil, metadata, fmt.Errorf("failed to get master key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
encWriter, err := backup_encryption.NewEncryptionWriter(
|
encSetup, err := backup_encryption.SetupEncryptionWriter(storageWriter, masterKey, backupID)
|
||||||
storageWriter,
|
|
||||||
masterKey,
|
|
||||||
backupID,
|
|
||||||
salt,
|
|
||||||
nonce,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, metadata, fmt.Errorf("failed to create encrypting writer: %w", err)
|
return nil, nil, metadata, err
|
||||||
}
|
}
|
||||||
|
|
||||||
saltBase64 := base64.StdEncoding.EncodeToString(salt)
|
metadata.EncryptionSalt = &encSetup.SaltBase64
|
||||||
nonceBase64 := base64.StdEncoding.EncodeToString(nonce)
|
metadata.EncryptionIV = &encSetup.NonceBase64
|
||||||
metadata.EncryptionSalt = &saltBase64
|
|
||||||
metadata.EncryptionIV = &nonceBase64
|
|
||||||
metadata.Encryption = backups_config.BackupEncryptionEncrypted
|
metadata.Encryption = backups_config.BackupEncryptionEncrypted
|
||||||
|
|
||||||
uc.logger.Info("Encryption enabled for backup", "backupId", backupID)
|
uc.logger.Info("Encryption enabled for backup", "backupId", backupID)
|
||||||
return encWriter, encWriter, metadata, nil
|
return encSetup.Writer, encSetup.Writer, metadata, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (uc *CreateMariadbBackupUsecase) cleanupOnCancellation(
|
func (uc *CreateMariadbBackupUsecase) cleanupOnCancellation(
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package usecases_mongodb
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -277,41 +276,21 @@ func (uc *CreateMongodbBackupUsecase) setupBackupEncryption(
|
|||||||
return storageWriter, nil, backupMetadata, nil
|
return storageWriter, nil, backupMetadata, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
salt, err := backup_encryption.GenerateSalt()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, backupMetadata, fmt.Errorf("failed to generate salt: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
nonce, err := backup_encryption.GenerateNonce()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, backupMetadata, fmt.Errorf("failed to generate nonce: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
masterKey, err := uc.secretKeyService.GetSecretKey()
|
masterKey, err := uc.secretKeyService.GetSecretKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, backupMetadata, fmt.Errorf("failed to get master key: %w", err)
|
return nil, nil, backupMetadata, fmt.Errorf("failed to get master key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
encryptionWriter, err := backup_encryption.NewEncryptionWriter(
|
encSetup, err := backup_encryption.SetupEncryptionWriter(storageWriter, masterKey, backupID)
|
||||||
storageWriter,
|
|
||||||
masterKey,
|
|
||||||
backupID,
|
|
||||||
salt,
|
|
||||||
nonce,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, backupMetadata, fmt.Errorf("failed to create encryption writer: %w", err)
|
return nil, nil, backupMetadata, err
|
||||||
}
|
}
|
||||||
|
|
||||||
saltBase64 := base64.StdEncoding.EncodeToString(salt)
|
|
||||||
nonceBase64 := base64.StdEncoding.EncodeToString(nonce)
|
|
||||||
|
|
||||||
backupMetadata.BackupID = backupID
|
|
||||||
backupMetadata.Encryption = backups_config.BackupEncryptionEncrypted
|
backupMetadata.Encryption = backups_config.BackupEncryptionEncrypted
|
||||||
backupMetadata.EncryptionSalt = &saltBase64
|
backupMetadata.EncryptionSalt = &encSetup.SaltBase64
|
||||||
backupMetadata.EncryptionIV = &nonceBase64
|
backupMetadata.EncryptionIV = &encSetup.NonceBase64
|
||||||
|
|
||||||
return encryptionWriter, encryptionWriter, backupMetadata, nil
|
return encSetup.Writer, encSetup.Writer, backupMetadata, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (uc *CreateMongodbBackupUsecase) copyWithShutdownCheck(
|
func (uc *CreateMongodbBackupUsecase) copyWithShutdownCheck(
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package usecases_mysql
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -448,40 +447,22 @@ func (uc *CreateMysqlBackupUsecase) setupBackupEncryption(
|
|||||||
return storageWriter, nil, metadata, nil
|
return storageWriter, nil, metadata, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
salt, err := backup_encryption.GenerateSalt()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, metadata, fmt.Errorf("failed to generate salt: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
nonce, err := backup_encryption.GenerateNonce()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, metadata, fmt.Errorf("failed to generate nonce: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
masterKey, err := uc.secretKeyService.GetSecretKey()
|
masterKey, err := uc.secretKeyService.GetSecretKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, metadata, fmt.Errorf("failed to get master key: %w", err)
|
return nil, nil, metadata, fmt.Errorf("failed to get master key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
encWriter, err := backup_encryption.NewEncryptionWriter(
|
encSetup, err := backup_encryption.SetupEncryptionWriter(storageWriter, masterKey, backupID)
|
||||||
storageWriter,
|
|
||||||
masterKey,
|
|
||||||
backupID,
|
|
||||||
salt,
|
|
||||||
nonce,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, metadata, fmt.Errorf("failed to create encrypting writer: %w", err)
|
return nil, nil, metadata, err
|
||||||
}
|
}
|
||||||
|
|
||||||
saltBase64 := base64.StdEncoding.EncodeToString(salt)
|
metadata.EncryptionSalt = &encSetup.SaltBase64
|
||||||
nonceBase64 := base64.StdEncoding.EncodeToString(nonce)
|
metadata.EncryptionIV = &encSetup.NonceBase64
|
||||||
metadata.EncryptionSalt = &saltBase64
|
|
||||||
metadata.EncryptionIV = &nonceBase64
|
|
||||||
metadata.Encryption = backups_config.BackupEncryptionEncrypted
|
metadata.Encryption = backups_config.BackupEncryptionEncrypted
|
||||||
|
|
||||||
uc.logger.Info("Encryption enabled for backup", "backupId", backupID)
|
uc.logger.Info("Encryption enabled for backup", "backupId", backupID)
|
||||||
return encWriter, encWriter, metadata, nil
|
return encSetup.Writer, encSetup.Writer, metadata, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (uc *CreateMysqlBackupUsecase) cleanupOnCancellation(
|
func (uc *CreateMysqlBackupUsecase) cleanupOnCancellation(
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package usecases_postgresql
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -492,40 +491,22 @@ func (uc *CreatePostgresqlBackupUsecase) setupBackupEncryption(
|
|||||||
return storageWriter, nil, metadata, nil
|
return storageWriter, nil, metadata, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
salt, err := backup_encryption.GenerateSalt()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, metadata, fmt.Errorf("failed to generate salt: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
nonce, err := backup_encryption.GenerateNonce()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, metadata, fmt.Errorf("failed to generate nonce: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
masterKey, err := uc.secretKeyService.GetSecretKey()
|
masterKey, err := uc.secretKeyService.GetSecretKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, metadata, fmt.Errorf("failed to get master key: %w", err)
|
return nil, nil, metadata, fmt.Errorf("failed to get master key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
encWriter, err := backup_encryption.NewEncryptionWriter(
|
encSetup, err := backup_encryption.SetupEncryptionWriter(storageWriter, masterKey, backupID)
|
||||||
storageWriter,
|
|
||||||
masterKey,
|
|
||||||
backupID,
|
|
||||||
salt,
|
|
||||||
nonce,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, metadata, fmt.Errorf("failed to create encrypting writer: %w", err)
|
return nil, nil, metadata, err
|
||||||
}
|
}
|
||||||
|
|
||||||
saltBase64 := base64.StdEncoding.EncodeToString(salt)
|
metadata.EncryptionSalt = &encSetup.SaltBase64
|
||||||
nonceBase64 := base64.StdEncoding.EncodeToString(nonce)
|
metadata.EncryptionIV = &encSetup.NonceBase64
|
||||||
metadata.EncryptionSalt = &saltBase64
|
|
||||||
metadata.EncryptionIV = &nonceBase64
|
|
||||||
metadata.Encryption = backups_config.BackupEncryptionEncrypted
|
metadata.Encryption = backups_config.BackupEncryptionEncrypted
|
||||||
|
|
||||||
uc.logger.Info("Encryption enabled for backup", "backupId", backupID)
|
uc.logger.Info("Encryption enabled for backup", "backupId", backupID)
|
||||||
return encWriter, encWriter, metadata, nil
|
return encSetup.Writer, encSetup.Writer, metadata, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (uc *CreatePostgresqlBackupUsecase) cleanupOnCancellation(
|
func (uc *CreatePostgresqlBackupUsecase) cleanupOnCancellation(
|
||||||
|
|||||||
@@ -29,6 +29,11 @@ func (c *DatabaseController) RegisterRoutes(router *gin.RouterGroup) {
|
|||||||
router.GET("/databases/notifier/:id/databases-count", c.CountDatabasesByNotifier)
|
router.GET("/databases/notifier/:id/databases-count", c.CountDatabasesByNotifier)
|
||||||
router.POST("/databases/is-readonly", c.IsUserReadOnly)
|
router.POST("/databases/is-readonly", c.IsUserReadOnly)
|
||||||
router.POST("/databases/create-readonly-user", c.CreateReadOnlyUser)
|
router.POST("/databases/create-readonly-user", c.CreateReadOnlyUser)
|
||||||
|
router.POST("/databases/:id/regenerate-token", c.RegenerateAgentToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DatabaseController) RegisterPublicRoutes(router *gin.RouterGroup) {
|
||||||
|
router.POST("/databases/verify-token", c.VerifyAgentToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateDatabase
|
// CreateDatabase
|
||||||
@@ -438,3 +443,61 @@ func (c *DatabaseController) CreateReadOnlyUser(ctx *gin.Context) {
|
|||||||
Password: password,
|
Password: password,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegenerateAgentToken
|
||||||
|
// @Summary Regenerate agent token for a database
|
||||||
|
// @Description Generate a new agent token for the database. The token is returned once and stored as a hash.
|
||||||
|
// @Tags databases
|
||||||
|
// @Produce json
|
||||||
|
// @Security BearerAuth
|
||||||
|
// @Param id path string true "Database ID"
|
||||||
|
// @Success 200 {object} map[string]string
|
||||||
|
// @Failure 400 {object} map[string]string
|
||||||
|
// @Failure 401 {object} map[string]string
|
||||||
|
// @Router /databases/{id}/regenerate-token [post]
|
||||||
|
func (c *DatabaseController) RegenerateAgentToken(ctx *gin.Context) {
|
||||||
|
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := uuid.Parse(ctx.Param("id"))
|
||||||
|
if err != nil {
|
||||||
|
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := c.databaseService.RegenerateAgentToken(user, id)
|
||||||
|
if err != nil {
|
||||||
|
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.JSON(http.StatusOK, gin.H{"token": token})
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyAgentToken
|
||||||
|
// @Summary Verify agent token
|
||||||
|
// @Description Verify that a given agent token is valid for any database
|
||||||
|
// @Tags databases
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Param request body VerifyAgentTokenRequest true "Token to verify"
|
||||||
|
// @Success 200 {object} map[string]string
|
||||||
|
// @Failure 401 {object} map[string]string
|
||||||
|
// @Router /databases/verify-token [post]
|
||||||
|
func (c *DatabaseController) VerifyAgentToken(ctx *gin.Context) {
|
||||||
|
var request VerifyAgentTokenRequest
|
||||||
|
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||||
|
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.databaseService.VerifyAgentToken(request.Token); err != nil {
|
||||||
|
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.JSON(http.StatusOK, gin.H{"message": "token is valid"})
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,10 +13,13 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"databasus-backend/internal/config"
|
"databasus-backend/internal/config"
|
||||||
|
"databasus-backend/internal/features/audit_logs"
|
||||||
"databasus-backend/internal/features/databases/databases/mariadb"
|
"databasus-backend/internal/features/databases/databases/mariadb"
|
||||||
"databasus-backend/internal/features/databases/databases/mongodb"
|
"databasus-backend/internal/features/databases/databases/mongodb"
|
||||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||||
users_enums "databasus-backend/internal/features/users/enums"
|
users_enums "databasus-backend/internal/features/users/enums"
|
||||||
|
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||||
|
users_services "databasus-backend/internal/features/users/services"
|
||||||
users_testing "databasus-backend/internal/features/users/testing"
|
users_testing "databasus-backend/internal/features/users/testing"
|
||||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||||
@@ -144,6 +147,66 @@ func Test_CreateDatabase_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testin
|
|||||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_CreateDatabase_WalV1Type_NoConnectionFieldsRequired(t *testing.T) {
|
||||||
|
router := createTestRouter()
|
||||||
|
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||||
|
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||||
|
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||||
|
|
||||||
|
request := Database{
|
||||||
|
Name: "Test WAL Database",
|
||||||
|
WorkspaceID: &workspace.ID,
|
||||||
|
Type: DatabaseTypePostgres,
|
||||||
|
Postgresql: &postgresql.PostgresqlDatabase{
|
||||||
|
BackupType: postgresql.PostgresBackupTypeWalV1,
|
||||||
|
CpuCount: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var response Database
|
||||||
|
test_utils.MakePostRequestAndUnmarshal(
|
||||||
|
t,
|
||||||
|
router,
|
||||||
|
"/api/v1/databases/create",
|
||||||
|
"Bearer "+owner.Token,
|
||||||
|
request,
|
||||||
|
http.StatusCreated,
|
||||||
|
&response,
|
||||||
|
)
|
||||||
|
defer RemoveTestDatabase(&response)
|
||||||
|
|
||||||
|
assert.Equal(t, "Test WAL Database", response.Name)
|
||||||
|
assert.NotEqual(t, uuid.Nil, response.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_CreateDatabase_PgDumpType_ConnectionFieldsRequired(t *testing.T) {
|
||||||
|
router := createTestRouter()
|
||||||
|
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||||
|
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||||
|
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||||
|
|
||||||
|
request := Database{
|
||||||
|
Name: "Test PG_DUMP Database",
|
||||||
|
WorkspaceID: &workspace.ID,
|
||||||
|
Type: DatabaseTypePostgres,
|
||||||
|
Postgresql: &postgresql.PostgresqlDatabase{
|
||||||
|
BackupType: postgresql.PostgresBackupTypePgDump,
|
||||||
|
CpuCount: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
testResp := test_utils.MakePostRequest(
|
||||||
|
t,
|
||||||
|
router,
|
||||||
|
"/api/v1/databases/create",
|
||||||
|
"Bearer "+owner.Token,
|
||||||
|
request,
|
||||||
|
http.StatusBadRequest,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.Contains(t, string(testResp.Body), "host is required")
|
||||||
|
}
|
||||||
|
|
||||||
func Test_UpdateDatabase_PermissionsEnforced(t *testing.T) {
|
func Test_UpdateDatabase_PermissionsEnforced(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -256,6 +319,52 @@ func Test_UpdateDatabase_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testin
|
|||||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_UpdateDatabase_WhenDatabaseTypeChanged_ReturnsBadRequest(t *testing.T) {
|
||||||
|
router := createTestRouter()
|
||||||
|
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||||
|
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||||
|
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||||
|
|
||||||
|
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||||
|
defer RemoveTestDatabase(database)
|
||||||
|
|
||||||
|
database.Type = DatabaseTypeMysql
|
||||||
|
|
||||||
|
testResp := test_utils.MakePostRequest(
|
||||||
|
t,
|
||||||
|
router,
|
||||||
|
"/api/v1/databases/update",
|
||||||
|
"Bearer "+owner.Token,
|
||||||
|
database,
|
||||||
|
http.StatusBadRequest,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.Contains(t, string(testResp.Body), "database type cannot be changed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_UpdateDatabase_WhenBackupTypeChanged_ReturnsBadRequest(t *testing.T) {
|
||||||
|
router := createTestRouter()
|
||||||
|
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||||
|
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||||
|
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||||
|
|
||||||
|
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||||
|
defer RemoveTestDatabase(database)
|
||||||
|
|
||||||
|
database.Postgresql.BackupType = postgresql.PostgresBackupTypeWalV1
|
||||||
|
|
||||||
|
testResp := test_utils.MakePostRequest(
|
||||||
|
t,
|
||||||
|
router,
|
||||||
|
"/api/v1/databases/update",
|
||||||
|
"Bearer "+owner.Token,
|
||||||
|
database,
|
||||||
|
http.StatusBadRequest,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.Contains(t, string(testResp.Body), "backup type cannot be changed")
|
||||||
|
}
|
||||||
|
|
||||||
func Test_DeleteDatabase_PermissionsEnforced(t *testing.T) {
|
func Test_DeleteDatabase_PermissionsEnforced(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -1050,6 +1159,87 @@ func Test_TestConnection_PermissionsEnforced(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_RegenerateAgentToken_ReturnsToken(t *testing.T) {
|
||||||
|
router := createTestRouter()
|
||||||
|
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||||
|
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||||
|
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||||
|
|
||||||
|
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||||
|
defer RemoveTestDatabase(database)
|
||||||
|
|
||||||
|
var response map[string]string
|
||||||
|
test_utils.MakePostRequestAndUnmarshal(
|
||||||
|
t,
|
||||||
|
router,
|
||||||
|
"/api/v1/databases/"+database.ID.String()+"/regenerate-token",
|
||||||
|
"Bearer "+owner.Token,
|
||||||
|
nil,
|
||||||
|
http.StatusOK,
|
||||||
|
&response,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.NotEmpty(t, response["token"])
|
||||||
|
assert.Len(t, response["token"], 32)
|
||||||
|
|
||||||
|
var updatedDatabase Database
|
||||||
|
test_utils.MakeGetRequestAndUnmarshal(
|
||||||
|
t,
|
||||||
|
router,
|
||||||
|
"/api/v1/databases/"+database.ID.String(),
|
||||||
|
"Bearer "+owner.Token,
|
||||||
|
http.StatusOK,
|
||||||
|
&updatedDatabase,
|
||||||
|
)
|
||||||
|
assert.True(t, updatedDatabase.IsAgentTokenGenerated)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_VerifyAgentToken_WithValidToken_Succeeds(t *testing.T) {
|
||||||
|
router := createTestRouter()
|
||||||
|
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||||
|
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||||
|
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||||
|
|
||||||
|
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||||
|
defer RemoveTestDatabase(database)
|
||||||
|
|
||||||
|
var regenerateResponse map[string]string
|
||||||
|
test_utils.MakePostRequestAndUnmarshal(
|
||||||
|
t,
|
||||||
|
router,
|
||||||
|
"/api/v1/databases/"+database.ID.String()+"/regenerate-token",
|
||||||
|
"Bearer "+owner.Token,
|
||||||
|
nil,
|
||||||
|
http.StatusOK,
|
||||||
|
®enerateResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
token := regenerateResponse["token"]
|
||||||
|
assert.NotEmpty(t, token)
|
||||||
|
|
||||||
|
w := workspaces_testing.MakeAPIRequest(
|
||||||
|
router,
|
||||||
|
"POST",
|
||||||
|
"/api/v1/databases/verify-token",
|
||||||
|
"",
|
||||||
|
VerifyAgentTokenRequest{Token: token},
|
||||||
|
)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_VerifyAgentToken_WithInvalidToken_Returns401(t *testing.T) {
|
||||||
|
router := createTestRouter()
|
||||||
|
|
||||||
|
w := workspaces_testing.MakeAPIRequest(
|
||||||
|
router,
|
||||||
|
"POST",
|
||||||
|
"/api/v1/databases/verify-token",
|
||||||
|
"",
|
||||||
|
VerifyAgentTokenRequest{Token: "invalidtoken00000000000000000000"},
|
||||||
|
)
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
func createTestDatabaseViaAPI(
|
func createTestDatabaseViaAPI(
|
||||||
name string,
|
name string,
|
||||||
workspaceID uuid.UUID,
|
workspaceID uuid.UUID,
|
||||||
@@ -1101,11 +1291,20 @@ func createTestDatabaseViaAPI(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func createTestRouter() *gin.Engine {
|
func createTestRouter() *gin.Engine {
|
||||||
router := workspaces_testing.CreateTestRouter(
|
gin.SetMode(gin.TestMode)
|
||||||
workspaces_controllers.GetWorkspaceController(),
|
router := gin.New()
|
||||||
workspaces_controllers.GetMembershipController(),
|
|
||||||
GetDatabaseController(),
|
v1 := router.Group("/api/v1")
|
||||||
)
|
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
|
||||||
|
|
||||||
|
workspaces_controllers.GetWorkspaceController().RegisterRoutes(protected.(*gin.RouterGroup))
|
||||||
|
workspaces_controllers.GetMembershipController().RegisterRoutes(protected.(*gin.RouterGroup))
|
||||||
|
GetDatabaseController().RegisterRoutes(protected.(*gin.RouterGroup))
|
||||||
|
|
||||||
|
GetDatabaseController().RegisterPublicRoutes(v1)
|
||||||
|
|
||||||
|
audit_logs.SetupDependencies()
|
||||||
|
|
||||||
return router
|
return router
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1118,13 +1317,14 @@ func getTestPostgresConfig() *postgresql.PostgresqlDatabase {
|
|||||||
|
|
||||||
testDbName := "testdb"
|
testDbName := "testdb"
|
||||||
return &postgresql.PostgresqlDatabase{
|
return &postgresql.PostgresqlDatabase{
|
||||||
Version: tools.PostgresqlVersion16,
|
BackupType: postgresql.PostgresBackupTypePgDump,
|
||||||
Host: config.GetEnv().TestLocalhost,
|
Version: tools.PostgresqlVersion16,
|
||||||
Port: port,
|
Host: config.GetEnv().TestLocalhost,
|
||||||
Username: "testuser",
|
Port: port,
|
||||||
Password: "testpassword",
|
Username: "testuser",
|
||||||
Database: &testDbName,
|
Password: "testpassword",
|
||||||
CpuCount: 1,
|
Database: &testDbName,
|
||||||
|
CpuCount: 1,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package postgresql
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"databasus-backend/internal/config"
|
||||||
"databasus-backend/internal/util/encryption"
|
"databasus-backend/internal/util/encryption"
|
||||||
"databasus-backend/internal/util/tools"
|
"databasus-backend/internal/util/tools"
|
||||||
"errors"
|
"errors"
|
||||||
@@ -17,6 +18,13 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type PostgresBackupType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
PostgresBackupTypePgDump PostgresBackupType = "PG_DUMP"
|
||||||
|
PostgresBackupTypeWalV1 PostgresBackupType = "WAL_V1"
|
||||||
|
)
|
||||||
|
|
||||||
type PostgresqlDatabase struct {
|
type PostgresqlDatabase struct {
|
||||||
ID uuid.UUID `json:"id" gorm:"primaryKey;type:uuid;default:gen_random_uuid()"`
|
ID uuid.UUID `json:"id" gorm:"primaryKey;type:uuid;default:gen_random_uuid()"`
|
||||||
|
|
||||||
@@ -24,11 +32,13 @@ type PostgresqlDatabase struct {
|
|||||||
|
|
||||||
Version tools.PostgresqlVersion `json:"version" gorm:"type:text;not null"`
|
Version tools.PostgresqlVersion `json:"version" gorm:"type:text;not null"`
|
||||||
|
|
||||||
// connection data
|
BackupType PostgresBackupType `json:"backupType" gorm:"column:backup_type;type:text;not null;default:'PG_DUMP'"`
|
||||||
Host string `json:"host" gorm:"type:text;not null"`
|
|
||||||
Port int `json:"port" gorm:"type:int;not null"`
|
// connection data — required for PG_DUMP, optional for WAL_V1
|
||||||
Username string `json:"username" gorm:"type:text;not null"`
|
Host string `json:"host" gorm:"type:text"`
|
||||||
Password string `json:"password" gorm:"type:text;not null"`
|
Port int `json:"port" gorm:"type:int"`
|
||||||
|
Username string `json:"username" gorm:"type:text"`
|
||||||
|
Password string `json:"password" gorm:"type:text"`
|
||||||
Database *string `json:"database" gorm:"type:text"`
|
Database *string `json:"database" gorm:"type:text"`
|
||||||
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
|
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
|
||||||
|
|
||||||
@@ -66,20 +76,30 @@ func (p *PostgresqlDatabase) AfterFind(_ *gorm.DB) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *PostgresqlDatabase) Validate() error {
|
func (p *PostgresqlDatabase) Validate() error {
|
||||||
if p.Host == "" {
|
if p.BackupType == "" {
|
||||||
return errors.New("host is required")
|
p.BackupType = PostgresBackupTypePgDump
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.Port == 0 {
|
if p.BackupType == PostgresBackupTypePgDump && config.GetEnv().IsCloud {
|
||||||
return errors.New("port is required")
|
return errors.New("PG_DUMP backup type is not supported in cloud mode")
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.Username == "" {
|
if p.BackupType == PostgresBackupTypePgDump {
|
||||||
return errors.New("username is required")
|
if p.Host == "" {
|
||||||
}
|
return errors.New("host is required")
|
||||||
|
}
|
||||||
|
|
||||||
if p.Password == "" {
|
if p.Port == 0 {
|
||||||
return errors.New("password is required")
|
return errors.New("port is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.Username == "" {
|
||||||
|
return errors.New("username is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.Password == "" {
|
||||||
|
return errors.New("password is required")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.CpuCount <= 0 {
|
if p.CpuCount <= 0 {
|
||||||
@@ -90,7 +110,7 @@ func (p *PostgresqlDatabase) Validate() error {
|
|||||||
// Databasus runs an internal PostgreSQL instance that should not be backed up through the UI
|
// Databasus runs an internal PostgreSQL instance that should not be backed up through the UI
|
||||||
// because it would expose internal metadata to non-system administrators.
|
// because it would expose internal metadata to non-system administrators.
|
||||||
// To properly backup Databasus, see: https://databasus.com/faq#backup-databasus
|
// To properly backup Databasus, see: https://databasus.com/faq#backup-databasus
|
||||||
if p.Database != nil && *p.Database != "" {
|
if p.BackupType == PostgresBackupTypePgDump && p.Database != nil && *p.Database != "" {
|
||||||
localhostHosts := []string{
|
localhostHosts := []string{
|
||||||
"localhost",
|
"localhost",
|
||||||
"127.0.0.1",
|
"127.0.0.1",
|
||||||
@@ -130,6 +150,10 @@ func (p *PostgresqlDatabase) TestConnection(
|
|||||||
encryptor encryption.FieldEncryptor,
|
encryptor encryption.FieldEncryptor,
|
||||||
databaseID uuid.UUID,
|
databaseID uuid.UUID,
|
||||||
) error {
|
) error {
|
||||||
|
if p.BackupType == PostgresBackupTypeWalV1 {
|
||||||
|
return errors.New("test connection is not supported for WAL backup type")
|
||||||
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@@ -144,7 +168,21 @@ func (p *PostgresqlDatabase) HideSensitiveData() {
|
|||||||
p.Password = ""
|
p.Password = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *PostgresqlDatabase) ValidateUpdate(old *PostgresqlDatabase) error {
|
||||||
|
// BackupType cannot be changed after creation — the full backup structure
|
||||||
|
// (WAL hierarchy, storage files, cleanup logic) is built around
|
||||||
|
// the type chosen at creation time. Automatically migrating this state is
|
||||||
|
// error-prone; it is safer for the user to create a new database and
|
||||||
|
// remove the old one.
|
||||||
|
if old.BackupType != p.BackupType {
|
||||||
|
return errors.New("backup type cannot be changed; create a new database instead")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (p *PostgresqlDatabase) Update(incoming *PostgresqlDatabase) {
|
func (p *PostgresqlDatabase) Update(incoming *PostgresqlDatabase) {
|
||||||
|
p.BackupType = incoming.BackupType
|
||||||
p.Version = incoming.Version
|
p.Version = incoming.Version
|
||||||
p.Host = incoming.Host
|
p.Host = incoming.Host
|
||||||
p.Port = incoming.Port
|
p.Port = incoming.Port
|
||||||
@@ -181,6 +219,10 @@ func (p *PostgresqlDatabase) PopulateDbData(
|
|||||||
encryptor encryption.FieldEncryptor,
|
encryptor encryption.FieldEncryptor,
|
||||||
databaseID uuid.UUID,
|
databaseID uuid.UUID,
|
||||||
) error {
|
) error {
|
||||||
|
if p.BackupType == PostgresBackupTypeWalV1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
return p.PopulateVersion(logger, encryptor, databaseID)
|
return p.PopulateVersion(logger, encryptor, databaseID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -243,6 +285,10 @@ func (p *PostgresqlDatabase) IsUserReadOnly(
|
|||||||
encryptor encryption.FieldEncryptor,
|
encryptor encryption.FieldEncryptor,
|
||||||
databaseID uuid.UUID,
|
databaseID uuid.UUID,
|
||||||
) (bool, []string, error) {
|
) (bool, []string, error) {
|
||||||
|
if p.BackupType == PostgresBackupTypeWalV1 {
|
||||||
|
return false, nil, errors.New("read-only check is not supported for WAL backup type")
|
||||||
|
}
|
||||||
|
|
||||||
password, err := decryptPasswordIfNeeded(p.Password, encryptor, databaseID)
|
password, err := decryptPasswordIfNeeded(p.Password, encryptor, databaseID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, nil, fmt.Errorf("failed to decrypt password: %w", err)
|
return false, nil, fmt.Errorf("failed to decrypt password: %w", err)
|
||||||
@@ -415,6 +461,10 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
|||||||
encryptor encryption.FieldEncryptor,
|
encryptor encryption.FieldEncryptor,
|
||||||
databaseID uuid.UUID,
|
databaseID uuid.UUID,
|
||||||
) (string, string, error) {
|
) (string, string, error) {
|
||||||
|
if p.BackupType == PostgresBackupTypeWalV1 {
|
||||||
|
return "", "", errors.New("read-only user creation is not supported for WAL backup type")
|
||||||
|
}
|
||||||
|
|
||||||
password, err := decryptPasswordIfNeeded(p.Password, encryptor, databaseID)
|
password, err := decryptPasswordIfNeeded(p.Password, encryptor, databaseID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", fmt.Errorf("failed to decrypt password: %w", err)
|
return "", "", fmt.Errorf("failed to decrypt password: %w", err)
|
||||||
|
|||||||
@@ -9,3 +9,7 @@ type IsReadOnlyResponse struct {
|
|||||||
IsReadOnly bool `json:"isReadOnly"`
|
IsReadOnly bool `json:"isReadOnly"`
|
||||||
Privileges []string `json:"privileges"`
|
Privileges []string `json:"privileges"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type VerifyAgentTokenRequest struct {
|
||||||
|
Token string `json:"token" binding:"required"`
|
||||||
|
}
|
||||||
|
|||||||
@@ -37,6 +37,9 @@ type Database struct {
|
|||||||
LastBackupErrorMessage *string `json:"lastBackupErrorMessage,omitempty" gorm:"column:last_backup_error_message;type:text"`
|
LastBackupErrorMessage *string `json:"lastBackupErrorMessage,omitempty" gorm:"column:last_backup_error_message;type:text"`
|
||||||
|
|
||||||
HealthStatus *HealthStatus `json:"healthStatus" gorm:"column:health_status;type:text;not null"`
|
HealthStatus *HealthStatus `json:"healthStatus" gorm:"column:health_status;type:text;not null"`
|
||||||
|
|
||||||
|
AgentToken *string `json:"-" gorm:"column:agent_token;type:text"`
|
||||||
|
IsAgentTokenGenerated bool `json:"isAgentTokenGenerated" gorm:"column:is_agent_token_generated;not null;default:false"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) Validate() error {
|
func (d *Database) Validate() error {
|
||||||
@@ -71,8 +74,19 @@ func (d *Database) Validate() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) ValidateUpdate(old, new Database) error {
|
func (d *Database) ValidateUpdate(old, new Database) error {
|
||||||
|
// Database type cannot be changed after creation — the entire backup
|
||||||
|
// structure (storage files, schedulers, WAL hierarchy, etc.) is tied to
|
||||||
|
// the type at creation time. Recreating that state automatically is
|
||||||
|
// error-prone; it is safer for the user to create a new database and
|
||||||
|
// remove the old one.
|
||||||
if old.Type != new.Type {
|
if old.Type != new.Type {
|
||||||
return errors.New("database type is not allowed to change")
|
return errors.New("database type cannot be changed; create a new database instead")
|
||||||
|
}
|
||||||
|
|
||||||
|
if old.Type == DatabaseTypePostgres && old.Postgresql != nil && new.Postgresql != nil {
|
||||||
|
if err := new.Postgresql.ValidateUpdate(old.Postgresql); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -244,6 +244,18 @@ func (r *DatabaseRepository) GetAllDatabases() ([]*Database, error) {
|
|||||||
return databases, nil
|
return databases, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *DatabaseRepository) FindByAgentTokenHash(hash string) (*Database, error) {
|
||||||
|
var database Database
|
||||||
|
|
||||||
|
if err := storage.GetDb().
|
||||||
|
Where("agent_token = ?", hash).
|
||||||
|
First(&database).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &database, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *DatabaseRepository) GetDatabasesIDsByNotifierID(
|
func (r *DatabaseRepository) GetDatabasesIDsByNotifierID(
|
||||||
notifierID uuid.UUID,
|
notifierID uuid.UUID,
|
||||||
) ([]uuid.UUID, error) {
|
) ([]uuid.UUID, error) {
|
||||||
|
|||||||
@@ -2,9 +2,11 @@ package databases
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"databasus-backend/internal/config"
|
"databasus-backend/internal/config"
|
||||||
@@ -87,21 +89,8 @@ func (s *DatabaseService) CreateDatabase(
|
|||||||
return nil, fmt.Errorf("failed to auto-detect database data: %w", err)
|
return nil, fmt.Errorf("failed to auto-detect database data: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.GetEnv().IsCloud {
|
if err := s.verifyReadOnlyUserIfNeeded(database); err != nil {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
return nil, err
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
isReadOnly, permissions, err := database.IsUserReadOnly(ctx, s.logger, s.fieldEncryptor)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to verify user permissions: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isReadOnly {
|
|
||||||
return nil, fmt.Errorf(
|
|
||||||
"in cloud mode, only read-only database users are allowed (user has permissions: %v)",
|
|
||||||
permissions,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := database.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
|
if err := database.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
|
||||||
@@ -171,25 +160,8 @@ func (s *DatabaseService) UpdateDatabase(
|
|||||||
return fmt.Errorf("failed to auto-detect database data: %w", err)
|
return fmt.Errorf("failed to auto-detect database data: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.GetEnv().IsCloud {
|
if err := s.verifyReadOnlyUserIfNeeded(existingDatabase); err != nil {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
return err
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
isReadOnly, permissions, err := existingDatabase.IsUserReadOnly(
|
|
||||||
ctx,
|
|
||||||
s.logger,
|
|
||||||
s.fieldEncryptor,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to verify user permissions: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isReadOnly {
|
|
||||||
return fmt.Errorf(
|
|
||||||
"in cloud mode, only read-only database users are allowed (user has permissions: %v)",
|
|
||||||
permissions,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
oldName := existingDatabase.Name
|
oldName := existingDatabase.Name
|
||||||
@@ -485,6 +457,7 @@ func (s *DatabaseService) CopyDatabase(
|
|||||||
newDatabase.Postgresql = &postgresql.PostgresqlDatabase{
|
newDatabase.Postgresql = &postgresql.PostgresqlDatabase{
|
||||||
ID: uuid.Nil,
|
ID: uuid.Nil,
|
||||||
DatabaseID: nil,
|
DatabaseID: nil,
|
||||||
|
BackupType: existingDatabase.Postgresql.BackupType,
|
||||||
Version: existingDatabase.Postgresql.Version,
|
Version: existingDatabase.Postgresql.Version,
|
||||||
Host: existingDatabase.Postgresql.Host,
|
Host: existingDatabase.Postgresql.Host,
|
||||||
Port: existingDatabase.Postgresql.Port,
|
Port: existingDatabase.Postgresql.Port,
|
||||||
@@ -638,6 +611,71 @@ func (s *DatabaseService) SetHealthStatus(
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DatabaseService) RegenerateAgentToken(
|
||||||
|
user *users_models.User,
|
||||||
|
databaseID uuid.UUID,
|
||||||
|
) (string, error) {
|
||||||
|
database, err := s.dbRepository.FindByID(databaseID)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if database.WorkspaceID == nil {
|
||||||
|
return "", errors.New("cannot regenerate token for database without workspace")
|
||||||
|
}
|
||||||
|
|
||||||
|
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if !canManage {
|
||||||
|
return "", errors.New(
|
||||||
|
"insufficient permissions to regenerate agent token for this database",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
plainToken := strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||||
|
tokenHash := hashAgentToken(plainToken)
|
||||||
|
|
||||||
|
database.AgentToken = &tokenHash
|
||||||
|
database.IsAgentTokenGenerated = true
|
||||||
|
|
||||||
|
_, err = s.dbRepository.Save(database)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.auditLogService.WriteAuditLog(
|
||||||
|
fmt.Sprintf("Agent token regenerated for database: %s", database.Name),
|
||||||
|
&user.ID,
|
||||||
|
database.WorkspaceID,
|
||||||
|
)
|
||||||
|
|
||||||
|
return plainToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DatabaseService) VerifyAgentToken(token string) error {
|
||||||
|
hash := hashAgentToken(token)
|
||||||
|
|
||||||
|
_, err := s.dbRepository.FindByAgentTokenHash(hash)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("invalid token")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DatabaseService) GetDatabaseByAgentToken(token string) (*Database, error) {
|
||||||
|
hash := hashAgentToken(token)
|
||||||
|
|
||||||
|
partial, err := s.dbRepository.FindByAgentTokenHash(hash)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("invalid agent token")
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.dbRepository.FindByID(partial.ID)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *DatabaseService) OnBeforeWorkspaceDeletion(workspaceID uuid.UUID) error {
|
func (s *DatabaseService) OnBeforeWorkspaceDeletion(workspaceID uuid.UUID) error {
|
||||||
databases, err := s.dbRepository.FindByWorkspaceID(workspaceID)
|
databases, err := s.dbRepository.FindByWorkspaceID(workspaceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -809,3 +847,36 @@ func (s *DatabaseService) CreateReadOnlyUser(
|
|||||||
|
|
||||||
return username, password, nil
|
return username, password, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DatabaseService) verifyReadOnlyUserIfNeeded(database *Database) error {
|
||||||
|
if !config.GetEnv().IsCloud {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if database.Postgresql != nil &&
|
||||||
|
database.Postgresql.BackupType == postgresql.PostgresBackupTypeWalV1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
isReadOnly, permissions, err := database.IsUserReadOnly(ctx, s.logger, s.fieldEncryptor)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to verify user permissions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isReadOnly {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"in cloud mode, only read-only database users are allowed (user has permissions: %v)",
|
||||||
|
permissions,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func hashAgentToken(token string) string {
|
||||||
|
hash := sha256.Sum256([]byte(token))
|
||||||
|
return fmt.Sprintf("%x", hash)
|
||||||
|
}
|
||||||
|
|||||||
@@ -79,6 +79,38 @@ func (i *Interval) ShouldTriggerBackup(now time.Time, lastBackupTime *time.Time)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NextTriggerTime computes the next time a backup should trigger based on the interval and last backup time.
|
||||||
|
// Returns nil when a backup is due immediately (no previous backup exists).
|
||||||
|
func (i *Interval) NextTriggerTime(now time.Time, lastBackupTime *time.Time) *time.Time {
|
||||||
|
if lastBackupTime == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch i.Interval {
|
||||||
|
case IntervalHourly:
|
||||||
|
next := lastBackupTime.Add(time.Hour)
|
||||||
|
return &next
|
||||||
|
|
||||||
|
case IntervalDaily:
|
||||||
|
next := i.nextDailyTrigger(now)
|
||||||
|
return &next
|
||||||
|
|
||||||
|
case IntervalWeekly:
|
||||||
|
next := i.nextWeeklyTrigger(now)
|
||||||
|
return &next
|
||||||
|
|
||||||
|
case IntervalMonthly:
|
||||||
|
next := i.nextMonthlyTrigger(now)
|
||||||
|
return &next
|
||||||
|
|
||||||
|
case IntervalCron:
|
||||||
|
return i.nextCronTrigger(*lastBackupTime)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (i *Interval) Copy() *Interval {
|
func (i *Interval) Copy() *Interval {
|
||||||
return &Interval{
|
return &Interval{
|
||||||
ID: uuid.Nil,
|
ID: uuid.Nil,
|
||||||
@@ -240,6 +272,99 @@ func (i *Interval) shouldTriggerCron(now, lastBackup time.Time) bool {
|
|||||||
return now.After(nextAfterLastBackup) || now.Equal(nextAfterLastBackup)
|
return now.After(nextAfterLastBackup) || now.Equal(nextAfterLastBackup)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *Interval) nextDailyTrigger(now time.Time) time.Time {
|
||||||
|
t, err := time.Parse("15:04", *i.TimeOfDay)
|
||||||
|
if err != nil {
|
||||||
|
return now
|
||||||
|
}
|
||||||
|
|
||||||
|
todaySlot := time.Date(
|
||||||
|
now.Year(), now.Month(), now.Day(),
|
||||||
|
t.Hour(), t.Minute(), 0, 0, now.Location(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if now.Before(todaySlot) {
|
||||||
|
return todaySlot
|
||||||
|
}
|
||||||
|
|
||||||
|
return todaySlot.AddDate(0, 0, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Interval) nextWeeklyTrigger(now time.Time) time.Time {
|
||||||
|
targetWd := time.Weekday(0)
|
||||||
|
if i.Weekday != nil {
|
||||||
|
targetWd = time.Weekday(*i.Weekday)
|
||||||
|
}
|
||||||
|
|
||||||
|
startOfWeek := getStartOfWeek(now)
|
||||||
|
|
||||||
|
var daysFromMonday int
|
||||||
|
if targetWd == time.Sunday {
|
||||||
|
daysFromMonday = 6
|
||||||
|
} else {
|
||||||
|
daysFromMonday = int(targetWd) - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
targetThisWeek := startOfWeek.AddDate(0, 0, daysFromMonday)
|
||||||
|
|
||||||
|
if i.TimeOfDay != nil {
|
||||||
|
t, err := time.Parse("15:04", *i.TimeOfDay)
|
||||||
|
if err == nil {
|
||||||
|
targetThisWeek = time.Date(
|
||||||
|
targetThisWeek.Year(), targetThisWeek.Month(), targetThisWeek.Day(),
|
||||||
|
t.Hour(), t.Minute(), 0, 0, targetThisWeek.Location(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if now.Before(targetThisWeek) {
|
||||||
|
return targetThisWeek
|
||||||
|
}
|
||||||
|
|
||||||
|
return targetThisWeek.AddDate(0, 0, 7)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Interval) nextMonthlyTrigger(now time.Time) time.Time {
|
||||||
|
day := 1
|
||||||
|
if i.DayOfMonth != nil {
|
||||||
|
day = *i.DayOfMonth
|
||||||
|
}
|
||||||
|
|
||||||
|
targetThisMonth := time.Date(now.Year(), now.Month(), day, 0, 0, 0, 0, now.Location())
|
||||||
|
|
||||||
|
if i.TimeOfDay != nil {
|
||||||
|
t, err := time.Parse("15:04", *i.TimeOfDay)
|
||||||
|
if err == nil {
|
||||||
|
targetThisMonth = time.Date(
|
||||||
|
targetThisMonth.Year(), targetThisMonth.Month(), targetThisMonth.Day(),
|
||||||
|
t.Hour(), t.Minute(), 0, 0, targetThisMonth.Location(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if now.Before(targetThisMonth) {
|
||||||
|
return targetThisMonth
|
||||||
|
}
|
||||||
|
|
||||||
|
return targetThisMonth.AddDate(0, 1, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Interval) nextCronTrigger(lastBackup time.Time) *time.Time {
|
||||||
|
if i.CronExpression == nil || *i.CronExpression == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
|
||||||
|
schedule, err := parser.Parse(*i.CronExpression)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
next := schedule.Next(lastBackup)
|
||||||
|
|
||||||
|
return &next
|
||||||
|
}
|
||||||
|
|
||||||
func (i *Interval) validateCronExpression(expr string) error {
|
func (i *Interval) validateCronExpression(expr string) error {
|
||||||
parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
|
parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
|
||||||
_, err := parser.Parse(expr)
|
_, err := parser.Parse(expr)
|
||||||
|
|||||||
@@ -721,3 +721,265 @@ func TestInterval_Validate(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestInterval_NextTriggerTime_NilLastBackup(t *testing.T) {
|
||||||
|
now := time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
t.Run("Hourly with nil lastBackup returns nil", func(t *testing.T) {
|
||||||
|
interval := &Interval{ID: uuid.New(), Interval: IntervalHourly}
|
||||||
|
result := interval.NextTriggerTime(now, nil)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Daily with nil lastBackup returns nil", func(t *testing.T) {
|
||||||
|
timeOfDay := "09:00"
|
||||||
|
interval := &Interval{ID: uuid.New(), Interval: IntervalDaily, TimeOfDay: &timeOfDay}
|
||||||
|
result := interval.NextTriggerTime(now, nil)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Weekly with nil lastBackup returns nil", func(t *testing.T) {
|
||||||
|
timeOfDay := "15:00"
|
||||||
|
weekday := 3
|
||||||
|
interval := &Interval{
|
||||||
|
ID: uuid.New(),
|
||||||
|
Interval: IntervalWeekly,
|
||||||
|
TimeOfDay: &timeOfDay,
|
||||||
|
Weekday: &weekday,
|
||||||
|
}
|
||||||
|
result := interval.NextTriggerTime(now, nil)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Monthly with nil lastBackup returns nil", func(t *testing.T) {
|
||||||
|
timeOfDay := "08:00"
|
||||||
|
dayOfMonth := 10
|
||||||
|
interval := &Interval{
|
||||||
|
ID: uuid.New(),
|
||||||
|
Interval: IntervalMonthly,
|
||||||
|
TimeOfDay: &timeOfDay,
|
||||||
|
DayOfMonth: &dayOfMonth,
|
||||||
|
}
|
||||||
|
result := interval.NextTriggerTime(now, nil)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Cron with nil lastBackup returns nil", func(t *testing.T) {
|
||||||
|
cronExpr := "0 2 * * *"
|
||||||
|
interval := &Interval{ID: uuid.New(), Interval: IntervalCron, CronExpression: &cronExpr}
|
||||||
|
result := interval.NextTriggerTime(now, nil)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInterval_NextTriggerTime_Hourly(t *testing.T) {
|
||||||
|
interval := &Interval{ID: uuid.New(), Interval: IntervalHourly}
|
||||||
|
now := time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
t.Run("Returns lastBackup + 1 hour", func(t *testing.T) {
|
||||||
|
lastBackup := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC)
|
||||||
|
result := interval.NextTriggerTime(now, &lastBackup)
|
||||||
|
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.Equal(t, time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC), *result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Returns future time when last backup was recent", func(t *testing.T) {
|
||||||
|
lastBackup := time.Date(2024, 1, 15, 11, 30, 0, 0, time.UTC)
|
||||||
|
result := interval.NextTriggerTime(now, &lastBackup)
|
||||||
|
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.Equal(t, time.Date(2024, 1, 15, 12, 30, 0, 0, time.UTC), *result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInterval_NextTriggerTime_Daily(t *testing.T) {
|
||||||
|
timeOfDay := "09:00"
|
||||||
|
interval := &Interval{ID: uuid.New(), Interval: IntervalDaily, TimeOfDay: &timeOfDay}
|
||||||
|
|
||||||
|
t.Run("Before today's slot: returns today's slot", func(t *testing.T) {
|
||||||
|
now := time.Date(2024, 1, 15, 8, 0, 0, 0, time.UTC)
|
||||||
|
lastBackup := time.Date(2024, 1, 14, 9, 0, 0, 0, time.UTC)
|
||||||
|
result := interval.NextTriggerTime(now, &lastBackup)
|
||||||
|
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.Equal(t, time.Date(2024, 1, 15, 9, 0, 0, 0, time.UTC), *result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("After today's slot: returns tomorrow's slot", func(t *testing.T) {
|
||||||
|
now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC)
|
||||||
|
lastBackup := time.Date(2024, 1, 15, 9, 0, 0, 0, time.UTC)
|
||||||
|
result := interval.NextTriggerTime(now, &lastBackup)
|
||||||
|
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.Equal(t, time.Date(2024, 1, 16, 9, 0, 0, 0, time.UTC), *result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Exactly at today's slot: returns tomorrow's slot", func(t *testing.T) {
|
||||||
|
now := time.Date(2024, 1, 15, 9, 0, 0, 0, time.UTC)
|
||||||
|
lastBackup := time.Date(2024, 1, 14, 9, 0, 0, 0, time.UTC)
|
||||||
|
result := interval.NextTriggerTime(now, &lastBackup)
|
||||||
|
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.Equal(t, time.Date(2024, 1, 16, 9, 0, 0, 0, time.UTC), *result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInterval_NextTriggerTime_Weekly(t *testing.T) {
|
||||||
|
timeOfDay := "15:00"
|
||||||
|
weekday := 3 // Wednesday
|
||||||
|
interval := &Interval{
|
||||||
|
ID: uuid.New(),
|
||||||
|
Interval: IntervalWeekly,
|
||||||
|
TimeOfDay: &timeOfDay,
|
||||||
|
Weekday: &weekday,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("Before this week's target: returns this week's target", func(t *testing.T) {
|
||||||
|
// Tuesday Jan 16, 2024
|
||||||
|
now := time.Date(2024, 1, 16, 10, 0, 0, 0, time.UTC)
|
||||||
|
lastBackup := time.Date(2024, 1, 10, 15, 0, 0, 0, time.UTC)
|
||||||
|
result := interval.NextTriggerTime(now, &lastBackup)
|
||||||
|
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
// Wednesday Jan 17 at 15:00
|
||||||
|
assert.Equal(t, time.Date(2024, 1, 17, 15, 0, 0, 0, time.UTC), *result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("After this week's target: returns next week's target", func(t *testing.T) {
|
||||||
|
// Thursday Jan 18, 2024
|
||||||
|
now := time.Date(2024, 1, 18, 10, 0, 0, 0, time.UTC)
|
||||||
|
lastBackup := time.Date(2024, 1, 17, 15, 0, 0, 0, time.UTC)
|
||||||
|
result := interval.NextTriggerTime(now, &lastBackup)
|
||||||
|
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
// Next Wednesday Jan 24 at 15:00
|
||||||
|
assert.Equal(t, time.Date(2024, 1, 24, 15, 0, 0, 0, time.UTC), *result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Friday interval: returns correct target", func(t *testing.T) {
|
||||||
|
fridayTimeOfDay := "00:00"
|
||||||
|
fridayWeekday := 5 // Friday
|
||||||
|
fridayInterval := &Interval{
|
||||||
|
ID: uuid.New(),
|
||||||
|
Interval: IntervalWeekly,
|
||||||
|
TimeOfDay: &fridayTimeOfDay,
|
||||||
|
Weekday: &fridayWeekday,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wednesday Jan 17, 2024
|
||||||
|
now := time.Date(2024, 1, 17, 10, 0, 0, 0, time.UTC)
|
||||||
|
lastBackup := time.Date(2024, 1, 12, 0, 0, 0, 0, time.UTC)
|
||||||
|
result := fridayInterval.NextTriggerTime(now, &lastBackup)
|
||||||
|
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
// Friday Jan 19 at 00:00
|
||||||
|
assert.Equal(t, time.Date(2024, 1, 19, 0, 0, 0, 0, time.UTC), *result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInterval_NextTriggerTime_Monthly(t *testing.T) {
|
||||||
|
timeOfDay := "08:00"
|
||||||
|
dayOfMonth := 10
|
||||||
|
interval := &Interval{
|
||||||
|
ID: uuid.New(),
|
||||||
|
Interval: IntervalMonthly,
|
||||||
|
TimeOfDay: &timeOfDay,
|
||||||
|
DayOfMonth: &dayOfMonth,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("Before this month's target: returns this month's target", func(t *testing.T) {
|
||||||
|
now := time.Date(2024, 1, 5, 10, 0, 0, 0, time.UTC)
|
||||||
|
lastBackup := time.Date(2023, 12, 10, 8, 0, 0, 0, time.UTC)
|
||||||
|
result := interval.NextTriggerTime(now, &lastBackup)
|
||||||
|
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.Equal(t, time.Date(2024, 1, 10, 8, 0, 0, 0, time.UTC), *result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("After this month's target: returns next month's target", func(t *testing.T) {
|
||||||
|
now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC)
|
||||||
|
lastBackup := time.Date(2024, 1, 10, 8, 0, 0, 0, time.UTC)
|
||||||
|
result := interval.NextTriggerTime(now, &lastBackup)
|
||||||
|
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.Equal(t, time.Date(2024, 2, 10, 8, 0, 0, 0, time.UTC), *result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Exactly at this month's target: returns next month's target", func(t *testing.T) {
|
||||||
|
now := time.Date(2024, 1, 10, 8, 0, 0, 0, time.UTC)
|
||||||
|
lastBackup := time.Date(2023, 12, 10, 8, 0, 0, 0, time.UTC)
|
||||||
|
result := interval.NextTriggerTime(now, &lastBackup)
|
||||||
|
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.Equal(t, time.Date(2024, 2, 10, 8, 0, 0, 0, time.UTC), *result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInterval_NextTriggerTime_Cron(t *testing.T) {
|
||||||
|
t.Run("Daily cron: returns next trigger after lastBackup", func(t *testing.T) {
|
||||||
|
cronExpr := "0 2 * * *" // Daily at 2:00 AM
|
||||||
|
interval := &Interval{ID: uuid.New(), Interval: IntervalCron, CronExpression: &cronExpr}
|
||||||
|
|
||||||
|
now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC)
|
||||||
|
lastBackup := time.Date(2024, 1, 14, 2, 0, 0, 0, time.UTC)
|
||||||
|
result := interval.NextTriggerTime(now, &lastBackup)
|
||||||
|
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.Equal(t, time.Date(2024, 1, 15, 2, 0, 0, 0, time.UTC), *result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Complex cron: 1st and 15th at 4:30", func(t *testing.T) {
|
||||||
|
cronExpr := "30 4 1,15 * *"
|
||||||
|
interval := &Interval{ID: uuid.New(), Interval: IntervalCron, CronExpression: &cronExpr}
|
||||||
|
|
||||||
|
now := time.Date(2024, 1, 10, 10, 0, 0, 0, time.UTC)
|
||||||
|
lastBackup := time.Date(2024, 1, 1, 4, 30, 0, 0, time.UTC)
|
||||||
|
result := interval.NextTriggerTime(now, &lastBackup)
|
||||||
|
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.Equal(t, time.Date(2024, 1, 15, 4, 30, 0, 0, time.UTC), *result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Invalid cron expression returns nil", func(t *testing.T) {
|
||||||
|
invalidCron := "invalid cron"
|
||||||
|
interval := &Interval{ID: uuid.New(), Interval: IntervalCron, CronExpression: &invalidCron}
|
||||||
|
|
||||||
|
now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC)
|
||||||
|
lastBackup := time.Date(2024, 1, 14, 10, 0, 0, 0, time.UTC)
|
||||||
|
result := interval.NextTriggerTime(now, &lastBackup)
|
||||||
|
|
||||||
|
assert.Nil(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Empty cron expression returns nil", func(t *testing.T) {
|
||||||
|
emptyCron := ""
|
||||||
|
interval := &Interval{ID: uuid.New(), Interval: IntervalCron, CronExpression: &emptyCron}
|
||||||
|
|
||||||
|
now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC)
|
||||||
|
lastBackup := time.Date(2024, 1, 14, 10, 0, 0, 0, time.UTC)
|
||||||
|
result := interval.NextTriggerTime(now, &lastBackup)
|
||||||
|
|
||||||
|
assert.Nil(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Nil cron expression returns nil", func(t *testing.T) {
|
||||||
|
interval := &Interval{ID: uuid.New(), Interval: IntervalCron, CronExpression: nil}
|
||||||
|
|
||||||
|
now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC)
|
||||||
|
lastBackup := time.Date(2024, 1, 14, 10, 0, 0, 0, time.UTC)
|
||||||
|
result := interval.NextTriggerTime(now, &lastBackup)
|
||||||
|
|
||||||
|
assert.Nil(t, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInterval_NextTriggerTime_UnknownInterval(t *testing.T) {
|
||||||
|
interval := &Interval{ID: uuid.New(), Interval: IntervalType("UNKNOWN")}
|
||||||
|
|
||||||
|
now := time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||||
|
lastBackup := time.Date(2024, 1, 14, 12, 0, 0, 0, time.UTC)
|
||||||
|
result := interval.NextTriggerTime(now, &lastBackup)
|
||||||
|
|
||||||
|
assert.Nil(t, result)
|
||||||
|
}
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import (
|
|||||||
|
|
||||||
env_config "databasus-backend/internal/config"
|
env_config "databasus-backend/internal/config"
|
||||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||||
"databasus-backend/internal/features/backups/backups"
|
backups_controllers "databasus-backend/internal/features/backups/backups/controllers"
|
||||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||||
backups_config "databasus-backend/internal/features/backups/config"
|
backups_config "databasus-backend/internal/features/backups/config"
|
||||||
"databasus-backend/internal/features/databases"
|
"databasus-backend/internal/features/databases"
|
||||||
@@ -440,7 +440,7 @@ func Test_CancelRestore_InProgressRestore_SuccessfullyCancelled(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
|
||||||
|
|
||||||
mockUsecase := &restoring.MockBlockingRestoreUsecase{
|
mockUsecase := &restoring.MockBlockingRestoreUsecase{
|
||||||
StartedChan: make(chan bool, 1),
|
StartedChan: make(chan bool, 1),
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||||
"databasus-backend/internal/features/backups/backups"
|
|
||||||
"databasus-backend/internal/features/backups/backups/backuping"
|
"databasus-backend/internal/features/backups/backups/backuping"
|
||||||
|
backups_services "databasus-backend/internal/features/backups/backups/services"
|
||||||
backups_config "databasus-backend/internal/features/backups/config"
|
backups_config "databasus-backend/internal/features/backups/config"
|
||||||
"databasus-backend/internal/features/databases"
|
"databasus-backend/internal/features/databases"
|
||||||
"databasus-backend/internal/features/disk"
|
"databasus-backend/internal/features/disk"
|
||||||
@@ -21,7 +21,7 @@ import (
|
|||||||
|
|
||||||
var restoreRepository = &restores_core.RestoreRepository{}
|
var restoreRepository = &restores_core.RestoreRepository{}
|
||||||
var restoreService = &RestoreService{
|
var restoreService = &RestoreService{
|
||||||
backups.GetBackupService(),
|
backups_services.GetBackupService(),
|
||||||
restoreRepository,
|
restoreRepository,
|
||||||
storages.GetStorageService(),
|
storages.GetStorageService(),
|
||||||
backups_config.GetBackupConfigService(),
|
backups_config.GetBackupConfigService(),
|
||||||
@@ -51,7 +51,7 @@ func SetupDependencies() {
|
|||||||
wasAlreadySetup := isSetup.Load()
|
wasAlreadySetup := isSetup.Load()
|
||||||
|
|
||||||
setupOnce.Do(func() {
|
setupOnce.Do(func() {
|
||||||
backups.GetBackupService().AddBackupRemoveListener(restoreService)
|
backups_services.GetBackupService().AddBackupRemoveListener(restoreService)
|
||||||
backuping.GetBackupCleaner().AddBackupRemoveListener(restoreService)
|
backuping.GetBackupCleaner().AddBackupRemoveListener(restoreService)
|
||||||
|
|
||||||
isSetup.Store(true)
|
isSetup.Store(true)
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
||||||
"databasus-backend/internal/features/backups/backups"
|
backups_services "databasus-backend/internal/features/backups/backups/services"
|
||||||
backups_config "databasus-backend/internal/features/backups/config"
|
backups_config "databasus-backend/internal/features/backups/config"
|
||||||
"databasus-backend/internal/features/databases"
|
"databasus-backend/internal/features/databases"
|
||||||
restores_core "databasus-backend/internal/features/restores/core"
|
restores_core "databasus-backend/internal/features/restores/core"
|
||||||
@@ -39,37 +39,37 @@ var restoreDatabaseCache = cache_utils.NewCacheUtil[RestoreDatabaseCache](
|
|||||||
var restoreCancelManager = tasks_cancellation.GetTaskCancelManager()
|
var restoreCancelManager = tasks_cancellation.GetTaskCancelManager()
|
||||||
|
|
||||||
var restorerNode = &RestorerNode{
|
var restorerNode = &RestorerNode{
|
||||||
nodeID: uuid.New(),
|
uuid.New(),
|
||||||
databaseService: databases.GetDatabaseService(),
|
databases.GetDatabaseService(),
|
||||||
backupService: backups.GetBackupService(),
|
backups_services.GetBackupService(),
|
||||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
encryption.GetFieldEncryptor(),
|
||||||
restoreRepository: restoreRepository,
|
restoreRepository,
|
||||||
backupConfigService: backups_config.GetBackupConfigService(),
|
backups_config.GetBackupConfigService(),
|
||||||
storageService: storages.GetStorageService(),
|
storages.GetStorageService(),
|
||||||
restoreNodesRegistry: restoreNodesRegistry,
|
restoreNodesRegistry,
|
||||||
logger: logger.GetLogger(),
|
logger.GetLogger(),
|
||||||
restoreBackupUsecase: usecases.GetRestoreBackupUsecase(),
|
usecases.GetRestoreBackupUsecase(),
|
||||||
cacheUtil: restoreDatabaseCache,
|
restoreDatabaseCache,
|
||||||
restoreCancelManager: restoreCancelManager,
|
restoreCancelManager,
|
||||||
lastHeartbeat: time.Time{},
|
time.Time{},
|
||||||
runOnce: sync.Once{},
|
sync.Once{},
|
||||||
hasRun: atomic.Bool{},
|
atomic.Bool{},
|
||||||
}
|
}
|
||||||
|
|
||||||
var restoresScheduler = &RestoresScheduler{
|
var restoresScheduler = &RestoresScheduler{
|
||||||
restoreRepository: restoreRepository,
|
restoreRepository,
|
||||||
backupService: backups.GetBackupService(),
|
backups_services.GetBackupService(),
|
||||||
storageService: storages.GetStorageService(),
|
storages.GetStorageService(),
|
||||||
backupConfigService: backups_config.GetBackupConfigService(),
|
backups_config.GetBackupConfigService(),
|
||||||
restoreNodesRegistry: restoreNodesRegistry,
|
restoreNodesRegistry,
|
||||||
lastCheckTime: time.Now().UTC(),
|
time.Now().UTC(),
|
||||||
logger: logger.GetLogger(),
|
logger.GetLogger(),
|
||||||
restoreToNodeRelations: make(map[uuid.UUID]RestoreToNodeRelation),
|
make(map[uuid.UUID]RestoreToNodeRelation),
|
||||||
restorerNode: restorerNode,
|
restorerNode,
|
||||||
cacheUtil: restoreDatabaseCache,
|
restoreDatabaseCache,
|
||||||
completionSubscriptionID: uuid.Nil,
|
uuid.Nil,
|
||||||
runOnce: sync.Once{},
|
sync.Once{},
|
||||||
hasRun: atomic.Bool{},
|
atomic.Bool{},
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetRestoresScheduler() *RestoresScheduler {
|
func GetRestoresScheduler() *RestoresScheduler {
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
||||||
"databasus-backend/internal/config"
|
"databasus-backend/internal/config"
|
||||||
"databasus-backend/internal/features/backups/backups"
|
backups_services "databasus-backend/internal/features/backups/backups/services"
|
||||||
backups_config "databasus-backend/internal/features/backups/config"
|
backups_config "databasus-backend/internal/features/backups/config"
|
||||||
"databasus-backend/internal/features/databases"
|
"databasus-backend/internal/features/databases"
|
||||||
restores_core "databasus-backend/internal/features/restores/core"
|
restores_core "databasus-backend/internal/features/restores/core"
|
||||||
@@ -32,7 +32,7 @@ type RestorerNode struct {
|
|||||||
nodeID uuid.UUID
|
nodeID uuid.UUID
|
||||||
|
|
||||||
databaseService *databases.DatabaseService
|
databaseService *databases.DatabaseService
|
||||||
backupService *backups.BackupService
|
backupService *backups_services.BackupService
|
||||||
fieldEncryptor util_encryption.FieldEncryptor
|
fieldEncryptor util_encryption.FieldEncryptor
|
||||||
restoreRepository *restores_core.RestoreRepository
|
restoreRepository *restores_core.RestoreRepository
|
||||||
backupConfigService *backups_config.BackupConfigService
|
backupConfigService *backups_config.BackupConfigService
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"databasus-backend/internal/config"
|
"databasus-backend/internal/config"
|
||||||
"databasus-backend/internal/features/backups/backups"
|
backups_controllers "databasus-backend/internal/features/backups/backups/controllers"
|
||||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||||
backups_config "databasus-backend/internal/features/backups/config"
|
backups_config "databasus-backend/internal/features/backups/config"
|
||||||
"databasus-backend/internal/features/databases"
|
"databasus-backend/internal/features/databases"
|
||||||
@@ -58,7 +58,7 @@ func Test_MakeRestore_WhenCacheMissed_RestoreFails(t *testing.T) {
|
|||||||
cache_utils.ClearAllCache()
|
cache_utils.ClearAllCache()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
|
||||||
|
|
||||||
// Create restore but DON'T cache DB credentials
|
// Create restore but DON'T cache DB credentials
|
||||||
// Also don't set embedded DB fields to avoid schema issues
|
// Also don't set embedded DB fields to avoid schema issues
|
||||||
@@ -126,7 +126,7 @@ func Test_MakeRestore_WhenTaskStarts_CacheDeletedImmediately(t *testing.T) {
|
|||||||
cache_utils.ClearAllCache()
|
cache_utils.ClearAllCache()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
|
||||||
|
|
||||||
// Create restore with cached DB credentials
|
// Create restore with cached DB credentials
|
||||||
// Don't set embedded DB fields in the restore model itself
|
// Don't set embedded DB fields in the restore model itself
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
||||||
"databasus-backend/internal/config"
|
"databasus-backend/internal/config"
|
||||||
"databasus-backend/internal/features/backups/backups"
|
backups_services "databasus-backend/internal/features/backups/backups/services"
|
||||||
backups_config "databasus-backend/internal/features/backups/config"
|
backups_config "databasus-backend/internal/features/backups/config"
|
||||||
restores_core "databasus-backend/internal/features/restores/core"
|
restores_core "databasus-backend/internal/features/restores/core"
|
||||||
"databasus-backend/internal/features/storages"
|
"databasus-backend/internal/features/storages"
|
||||||
@@ -26,7 +26,7 @@ const (
|
|||||||
|
|
||||||
type RestoresScheduler struct {
|
type RestoresScheduler struct {
|
||||||
restoreRepository *restores_core.RestoreRepository
|
restoreRepository *restores_core.RestoreRepository
|
||||||
backupService *backups.BackupService
|
backupService *backups_services.BackupService
|
||||||
storageService *storages.StorageService
|
storageService *storages.StorageService
|
||||||
backupConfigService *backups_config.BackupConfigService
|
backupConfigService *backups_config.BackupConfigService
|
||||||
restoreNodesRegistry *RestoreNodesRegistry
|
restoreNodesRegistry *RestoreNodesRegistry
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"databasus-backend/internal/config"
|
"databasus-backend/internal/config"
|
||||||
"databasus-backend/internal/features/backups/backups"
|
backups_controllers "databasus-backend/internal/features/backups/backups/controllers"
|
||||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||||
backups_config "databasus-backend/internal/features/backups/config"
|
backups_config "databasus-backend/internal/features/backups/config"
|
||||||
"databasus-backend/internal/features/databases"
|
"databasus-backend/internal/features/databases"
|
||||||
@@ -68,7 +68,7 @@ func Test_CheckDeadNodesAndFailRestores_NodeDies_FailsRestoreAndCleansUpRegistry
|
|||||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||||
|
|
||||||
// Create a test backup
|
// Create a test backup
|
||||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
// Register mock node without subscribing to restores (simulates node crash after registration)
|
// Register mock node without subscribing to restores (simulates node crash after registration)
|
||||||
@@ -171,7 +171,7 @@ func Test_OnRestoreCompleted_TaskIsNotRestore_SkipsProcessing(t *testing.T) {
|
|||||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||||
|
|
||||||
// Create a test backup
|
// Create a test backup
|
||||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
|
||||||
|
|
||||||
// Register mock node
|
// Register mock node
|
||||||
mockNodeID = uuid.New()
|
mockNodeID = uuid.New()
|
||||||
@@ -357,7 +357,7 @@ func Test_FailRestoresInProgress_SchedulerStarts_UpdatesStatus(t *testing.T) {
|
|||||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||||
|
|
||||||
// Create a test backup
|
// Create a test backup
|
||||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
|
||||||
|
|
||||||
// Create two in-progress restores that should be failed on scheduler restart
|
// Create two in-progress restores that should be failed on scheduler restart
|
||||||
restore1 := &restores_core.Restore{
|
restore1 := &restores_core.Restore{
|
||||||
@@ -465,7 +465,7 @@ func Test_StartRestore_RestoreCompletes_DecrementsActiveTaskCount(t *testing.T)
|
|||||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||||
|
|
||||||
// Create a test backup
|
// Create a test backup
|
||||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
|
||||||
|
|
||||||
// Get initial active task count
|
// Get initial active task count
|
||||||
stats, err := restoreNodesRegistry.GetRestoreNodesStats()
|
stats, err := restoreNodesRegistry.GetRestoreNodesStats()
|
||||||
@@ -566,7 +566,7 @@ func Test_StartRestore_RestoreFails_DecrementsActiveTaskCount(t *testing.T) {
|
|||||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||||
|
|
||||||
// Create a test backup
|
// Create a test backup
|
||||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
|
||||||
|
|
||||||
// Get initial active task count
|
// Get initial active task count
|
||||||
stats, err := restoreNodesRegistry.GetRestoreNodesStats()
|
stats, err := restoreNodesRegistry.GetRestoreNodesStats()
|
||||||
@@ -664,7 +664,7 @@ func Test_StartRestore_CredentialsStoredEncryptedInCache(t *testing.T) {
|
|||||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||||
|
|
||||||
// Create a test backup
|
// Create a test backup
|
||||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
|
||||||
|
|
||||||
// Register mock node so scheduler can assign restore to it
|
// Register mock node so scheduler can assign restore to it
|
||||||
mockNodeID = uuid.New()
|
mockNodeID = uuid.New()
|
||||||
@@ -779,7 +779,7 @@ func Test_StartRestore_CredentialsRemovedAfterRestoreStarts(t *testing.T) {
|
|||||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||||
|
|
||||||
// Create a test backup
|
// Create a test backup
|
||||||
backup := backups.CreateTestBackup(database.ID, storage.ID)
|
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
|
||||||
|
|
||||||
// Create restore with credentials
|
// Create restore with credentials
|
||||||
plaintextPassword := "test_password_456"
|
plaintextPassword := "test_password_456"
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
||||||
"databasus-backend/internal/config"
|
"databasus-backend/internal/config"
|
||||||
"databasus-backend/internal/features/backups/backups"
|
|
||||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||||
|
backups_services "databasus-backend/internal/features/backups/backups/services"
|
||||||
backups_config "databasus-backend/internal/features/backups/config"
|
backups_config "databasus-backend/internal/features/backups/config"
|
||||||
"databasus-backend/internal/features/databases"
|
"databasus-backend/internal/features/databases"
|
||||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||||
@@ -40,48 +40,48 @@ func CreateTestRouter() *gin.Engine {
|
|||||||
|
|
||||||
func CreateTestRestorerNode() *RestorerNode {
|
func CreateTestRestorerNode() *RestorerNode {
|
||||||
return &RestorerNode{
|
return &RestorerNode{
|
||||||
nodeID: uuid.New(),
|
uuid.New(),
|
||||||
databaseService: databases.GetDatabaseService(),
|
databases.GetDatabaseService(),
|
||||||
backupService: backups.GetBackupService(),
|
backups_services.GetBackupService(),
|
||||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
encryption.GetFieldEncryptor(),
|
||||||
restoreRepository: restoreRepository,
|
restoreRepository,
|
||||||
backupConfigService: backups_config.GetBackupConfigService(),
|
backups_config.GetBackupConfigService(),
|
||||||
storageService: storages.GetStorageService(),
|
storages.GetStorageService(),
|
||||||
restoreNodesRegistry: restoreNodesRegistry,
|
restoreNodesRegistry,
|
||||||
logger: logger.GetLogger(),
|
logger.GetLogger(),
|
||||||
restoreBackupUsecase: usecases.GetRestoreBackupUsecase(),
|
usecases.GetRestoreBackupUsecase(),
|
||||||
cacheUtil: restoreDatabaseCache,
|
restoreDatabaseCache,
|
||||||
restoreCancelManager: tasks_cancellation.GetTaskCancelManager(),
|
tasks_cancellation.GetTaskCancelManager(),
|
||||||
lastHeartbeat: time.Time{},
|
time.Time{},
|
||||||
runOnce: sync.Once{},
|
sync.Once{},
|
||||||
hasRun: atomic.Bool{},
|
atomic.Bool{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateTestRestorerNodeWithUsecase(usecase restores_core.RestoreBackupUsecase) *RestorerNode {
|
func CreateTestRestorerNodeWithUsecase(usecase restores_core.RestoreBackupUsecase) *RestorerNode {
|
||||||
return &RestorerNode{
|
return &RestorerNode{
|
||||||
nodeID: uuid.New(),
|
uuid.New(),
|
||||||
databaseService: databases.GetDatabaseService(),
|
databases.GetDatabaseService(),
|
||||||
backupService: backups.GetBackupService(),
|
backups_services.GetBackupService(),
|
||||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
encryption.GetFieldEncryptor(),
|
||||||
restoreRepository: restoreRepository,
|
restoreRepository,
|
||||||
backupConfigService: backups_config.GetBackupConfigService(),
|
backups_config.GetBackupConfigService(),
|
||||||
storageService: storages.GetStorageService(),
|
storages.GetStorageService(),
|
||||||
restoreNodesRegistry: restoreNodesRegistry,
|
restoreNodesRegistry,
|
||||||
logger: logger.GetLogger(),
|
logger.GetLogger(),
|
||||||
restoreBackupUsecase: usecase,
|
usecase,
|
||||||
cacheUtil: restoreDatabaseCache,
|
restoreDatabaseCache,
|
||||||
restoreCancelManager: tasks_cancellation.GetTaskCancelManager(),
|
tasks_cancellation.GetTaskCancelManager(),
|
||||||
lastHeartbeat: time.Time{},
|
time.Time{},
|
||||||
runOnce: sync.Once{},
|
sync.Once{},
|
||||||
hasRun: atomic.Bool{},
|
atomic.Bool{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateTestRestoresScheduler() *RestoresScheduler {
|
func CreateTestRestoresScheduler() *RestoresScheduler {
|
||||||
return &RestoresScheduler{
|
return &RestoresScheduler{
|
||||||
restoreRepository,
|
restoreRepository,
|
||||||
backups.GetBackupService(),
|
backups_services.GetBackupService(),
|
||||||
storages.GetStorageService(),
|
storages.GetStorageService(),
|
||||||
backups_config.GetBackupConfigService(),
|
backups_config.GetBackupConfigService(),
|
||||||
restoreNodesRegistry,
|
restoreNodesRegistry,
|
||||||
|
|||||||
@@ -3,8 +3,8 @@ package restores
|
|||||||
import (
|
import (
|
||||||
"databasus-backend/internal/config"
|
"databasus-backend/internal/config"
|
||||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||||
"databasus-backend/internal/features/backups/backups"
|
|
||||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||||
|
backups_services "databasus-backend/internal/features/backups/backups/services"
|
||||||
backups_config "databasus-backend/internal/features/backups/config"
|
backups_config "databasus-backend/internal/features/backups/config"
|
||||||
"databasus-backend/internal/features/databases"
|
"databasus-backend/internal/features/databases"
|
||||||
"databasus-backend/internal/features/disk"
|
"databasus-backend/internal/features/disk"
|
||||||
@@ -26,7 +26,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type RestoreService struct {
|
type RestoreService struct {
|
||||||
backupService *backups.BackupService
|
backupService *backups_services.BackupService
|
||||||
restoreRepository *restores_core.RestoreRepository
|
restoreRepository *restores_core.RestoreRepository
|
||||||
storageService *storages.StorageService
|
storageService *storages.StorageService
|
||||||
backupConfigService *backups_config.BackupConfigService
|
backupConfigService *backups_config.BackupConfigService
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
||||||
"databasus-backend/internal/features/backups/backups"
|
backups_controllers "databasus-backend/internal/features/backups/backups/controllers"
|
||||||
backups_config "databasus-backend/internal/features/backups/config"
|
backups_config "databasus-backend/internal/features/backups/config"
|
||||||
"databasus-backend/internal/features/databases"
|
"databasus-backend/internal/features/databases"
|
||||||
"databasus-backend/internal/features/restores/restoring"
|
"databasus-backend/internal/features/restores/restoring"
|
||||||
@@ -22,12 +22,12 @@ func CreateTestRouter() *gin.Engine {
|
|||||||
workspaces_controllers.GetMembershipController(),
|
workspaces_controllers.GetMembershipController(),
|
||||||
databases.GetDatabaseController(),
|
databases.GetDatabaseController(),
|
||||||
backups_config.GetBackupConfigController(),
|
backups_config.GetBackupConfigController(),
|
||||||
backups.GetBackupController(),
|
backups_controllers.GetBackupController(),
|
||||||
GetRestoreController(),
|
GetRestoreController(),
|
||||||
)
|
)
|
||||||
|
|
||||||
v1 := router.Group("/api/v1")
|
v1 := router.Group("/api/v1")
|
||||||
backups.GetBackupController().RegisterPublicRoutes(v1)
|
backups_controllers.GetBackupController().RegisterPublicRoutes(v1)
|
||||||
|
|
||||||
return router
|
return router
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -47,14 +47,15 @@ func (l *LocalStorage) SaveFile(
|
|||||||
|
|
||||||
logger.Info("Starting to save file to local storage", "fileName", fileName)
|
logger.Info("Starting to save file to local storage", "fileName", fileName)
|
||||||
|
|
||||||
|
tempFilePath := filepath.Join(config.GetEnv().TempFolder, fileName)
|
||||||
|
|
||||||
err := files_utils.EnsureDirectories([]string{
|
err := files_utils.EnsureDirectories([]string{
|
||||||
config.GetEnv().TempFolder,
|
config.GetEnv().TempFolder,
|
||||||
|
filepath.Dir(tempFilePath),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to ensure directories: %w", err)
|
return fmt.Errorf("failed to ensure directories: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tempFilePath := filepath.Join(config.GetEnv().TempFolder, fileName)
|
|
||||||
logger.Debug("Creating temp file", "fileName", fileName, "tempPath", tempFilePath)
|
logger.Debug("Creating temp file", "fileName", fileName, "tempPath", tempFilePath)
|
||||||
|
|
||||||
tempFile, err := os.Create(tempFilePath)
|
tempFile, err := os.Create(tempFilePath)
|
||||||
@@ -101,6 +102,10 @@ func (l *LocalStorage) SaveFile(
|
|||||||
finalPath,
|
finalPath,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if err = files_utils.EnsureDirectories([]string{filepath.Dir(finalPath)}); err != nil {
|
||||||
|
return fmt.Errorf("failed to ensure final directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Move the file from temp to backups directory
|
// Move the file from temp to backups directory
|
||||||
if err = os.Rename(tempFilePath, finalPath); err != nil {
|
if err = os.Rename(tempFilePath, finalPath); err != nil {
|
||||||
logger.Error(
|
logger.Error(
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"databasus-backend/internal/features/audit_logs"
|
"databasus-backend/internal/features/audit_logs"
|
||||||
"databasus-backend/internal/features/backups/backups"
|
|
||||||
"databasus-backend/internal/features/backups/backups/backuping"
|
"databasus-backend/internal/features/backups/backups/backuping"
|
||||||
|
backups_services "databasus-backend/internal/features/backups/backups/services"
|
||||||
backups_config "databasus-backend/internal/features/backups/config"
|
backups_config "databasus-backend/internal/features/backups/config"
|
||||||
"databasus-backend/internal/features/databases"
|
"databasus-backend/internal/features/databases"
|
||||||
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
|
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
|
||||||
@@ -26,8 +26,8 @@ func Test_SetupDependencies_CalledTwice_LogsWarning(t *testing.T) {
|
|||||||
audit_logs.SetupDependencies()
|
audit_logs.SetupDependencies()
|
||||||
audit_logs.SetupDependencies()
|
audit_logs.SetupDependencies()
|
||||||
|
|
||||||
backups.SetupDependencies()
|
backups_services.SetupDependencies()
|
||||||
backups.SetupDependencies()
|
backups_services.SetupDependencies()
|
||||||
|
|
||||||
backups_config.SetupDependencies()
|
backups_config.SetupDependencies()
|
||||||
backups_config.SetupDependencies()
|
backups_config.SetupDependencies()
|
||||||
|
|||||||
@@ -17,8 +17,9 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"databasus-backend/internal/config"
|
"databasus-backend/internal/config"
|
||||||
"databasus-backend/internal/features/backups/backups"
|
backups_controllers "databasus-backend/internal/features/backups/backups/controllers"
|
||||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||||
|
backups_dto "databasus-backend/internal/features/backups/backups/dto"
|
||||||
backups_config "databasus-backend/internal/features/backups/config"
|
backups_config "databasus-backend/internal/features/backups/config"
|
||||||
"databasus-backend/internal/features/databases"
|
"databasus-backend/internal/features/databases"
|
||||||
pgtypes "databasus-backend/internal/features/databases/databases/postgresql"
|
pgtypes "databasus-backend/internal/features/databases/databases/postgresql"
|
||||||
@@ -1234,7 +1235,7 @@ func createTestRouter() *gin.Engine {
|
|||||||
workspaces_controllers.GetMembershipController(),
|
workspaces_controllers.GetMembershipController(),
|
||||||
databases.GetDatabaseController(),
|
databases.GetDatabaseController(),
|
||||||
backups_config.GetBackupConfigController(),
|
backups_config.GetBackupConfigController(),
|
||||||
backups.GetBackupController(),
|
backups_controllers.GetBackupController(),
|
||||||
restores.GetRestoreController(),
|
restores.GetRestoreController(),
|
||||||
)
|
)
|
||||||
return router
|
return router
|
||||||
@@ -1255,7 +1256,7 @@ func waitForBackupCompletion(
|
|||||||
t.Fatalf("Timeout waiting for backup completion after %v", timeout)
|
t.Fatalf("Timeout waiting for backup completion after %v", timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
var response backups.GetBackupsResponse
|
var response backups_dto.GetBackupsResponse
|
||||||
test_utils.MakeGetRequestAndUnmarshal(
|
test_utils.MakeGetRequestAndUnmarshal(
|
||||||
t,
|
t,
|
||||||
router,
|
router,
|
||||||
@@ -1431,7 +1432,7 @@ func createBackupViaAPI(
|
|||||||
databaseID uuid.UUID,
|
databaseID uuid.UUID,
|
||||||
token string,
|
token string,
|
||||||
) {
|
) {
|
||||||
request := backups.MakeBackupRequest{DatabaseID: databaseID}
|
request := backups_dto.MakeBackupRequest{DatabaseID: databaseID}
|
||||||
test_utils.MakePostRequest(
|
test_utils.MakePostRequest(
|
||||||
t,
|
t,
|
||||||
router,
|
router,
|
||||||
|
|||||||
125
backend/internal/util/wal/calculator.go
Normal file
125
backend/internal/util/wal/calculator.go
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
package wal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
segmentNameLen = 24
|
||||||
|
timelineLen = 8
|
||||||
|
logLen = 8
|
||||||
|
segLen = 8
|
||||||
|
defaultSegmentSizeBytes = 16 * 1024 * 1024 // 16 MB
|
||||||
|
)
|
||||||
|
|
||||||
|
// WalCalculator performs WAL segment name arithmetic for a given WAL segment size.
|
||||||
|
//
|
||||||
|
// WAL segment name format: TTTTTTTTLLLLLLLLSSSSSSSS (24 hex chars)
|
||||||
|
// - TT: timeline (8 hex digits)
|
||||||
|
// - LL: log file / XLogId (8 hex digits)
|
||||||
|
// - SS: segment within log file (8 hex digits)
|
||||||
|
//
|
||||||
|
// segmentsPerXLogId = 0x100000000 / segmentSizeBytes
|
||||||
|
// Increment SS; if SS >= segmentsPerXLogId → SS = 0, LL++
|
||||||
|
type WalCalculator struct {
|
||||||
|
segmentSizeBytes int64
|
||||||
|
segmentsPerXLogId uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWalCalculator creates a WalCalculator for the given WAL segment size in bytes.
|
||||||
|
// Pass 0 or a negative value to use the PostgreSQL default of 16 MB.
|
||||||
|
func NewWalCalculator(segmentSizeBytes int64) *WalCalculator {
|
||||||
|
if segmentSizeBytes <= 0 {
|
||||||
|
segmentSizeBytes = defaultSegmentSizeBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
return &WalCalculator{
|
||||||
|
segmentSizeBytes: segmentSizeBytes,
|
||||||
|
segmentsPerXLogId: uint64(0x100000000) / uint64(segmentSizeBytes),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NextSegment computes the next WAL segment name after current.
|
||||||
|
// Returns an error if current is not a valid 24-character hex WAL segment name.
|
||||||
|
func (c *WalCalculator) NextSegment(current string) (string, error) {
|
||||||
|
if !c.IsValidSegmentName(current) {
|
||||||
|
return "", fmt.Errorf("invalid WAL segment name: %q", current)
|
||||||
|
}
|
||||||
|
|
||||||
|
timeline := current[:timelineLen]
|
||||||
|
logHex := current[timelineLen : timelineLen+logLen]
|
||||||
|
segHex := current[timelineLen+logLen:]
|
||||||
|
|
||||||
|
logVal, err := parseHex32(logHex)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("parse log part of %q: %w", current, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
segVal, err := parseHex32(segHex)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("parse seg part of %q: %w", current, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
segVal++
|
||||||
|
if uint64(segVal) >= c.segmentsPerXLogId {
|
||||||
|
segVal = 0
|
||||||
|
logVal++
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s%08X%08X", strings.ToUpper(timeline), logVal, segVal), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidSegmentName returns true if name is a 24-character uppercase (or lowercase) hex string
|
||||||
|
// representing a valid WAL segment name.
|
||||||
|
func (c *WalCalculator) IsValidSegmentName(name string) bool {
|
||||||
|
if len(name) != segmentNameLen {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := hex.DecodeString(name)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compare compares two WAL segment names a and b by their numeric value.
|
||||||
|
// Returns -1 if a < b, 0 if a == b, 1 if a > b.
|
||||||
|
// Both names must be valid; returns an error otherwise.
|
||||||
|
// Timeline is compared first, then log file, then segment number.
|
||||||
|
func (c *WalCalculator) Compare(a, b string) (int, error) {
|
||||||
|
if !c.IsValidSegmentName(a) {
|
||||||
|
return 0, fmt.Errorf("invalid WAL segment name: %q", a)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !c.IsValidSegmentName(b) {
|
||||||
|
return 0, fmt.Errorf("invalid WAL segment name: %q", b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fixed-width uppercase hex: lexicographic order equals numeric order.
|
||||||
|
aUpper := strings.ToUpper(a)
|
||||||
|
bUpper := strings.ToUpper(b)
|
||||||
|
|
||||||
|
if aUpper < bUpper {
|
||||||
|
return -1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if aUpper > bUpper {
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseHex32(s string) (uint32, error) {
|
||||||
|
if len(s) != 8 {
|
||||||
|
return 0, errors.New("expected 8 hex characters")
|
||||||
|
}
|
||||||
|
|
||||||
|
b, err := hex.DecodeString(s)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3]), nil
|
||||||
|
}
|
||||||
221
backend/internal/util/wal/calculator_test.go
Normal file
221
backend/internal/util/wal/calculator_test.go
Normal file
@@ -0,0 +1,221 @@
|
|||||||
|
package wal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
mb1 = 1 * 1024 * 1024
|
||||||
|
mb16 = 16 * 1024 * 1024
|
||||||
|
mb64 = 64 * 1024 * 1024
|
||||||
|
)
|
||||||
|
|
||||||
|
// NextSegment — no wrap
|
||||||
|
|
||||||
|
func Test_WalCalculator_NextSegment_DefaultSegmentSize_NoWrap(t *testing.T) {
|
||||||
|
calc := NewWalCalculator(mb16)
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"000000010000000100000000", "000000010000000100000001"},
|
||||||
|
{"000000010000000100000027", "000000010000000100000028"},
|
||||||
|
{"0000000100000001000000AB", "0000000100000001000000AC"},
|
||||||
|
{"0000000100000001000000FE", "0000000100000001000000FF"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
result, err := calc.NextSegment(tc.input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tc.expected, result, "input=%s", tc.input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NextSegment — wrap at 0x100 (256) for 16 MB segments
|
||||||
|
|
||||||
|
func Test_WalCalculator_NextSegment_DefaultSegmentSize_WrapsAt256(t *testing.T) {
|
||||||
|
calc := NewWalCalculator(mb16)
|
||||||
|
|
||||||
|
// SS=0xFF → SS=0x00, LL++
|
||||||
|
result, err := calc.NextSegment("0000000100000001000000FF")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "000000010000000200000000", result)
|
||||||
|
|
||||||
|
result, err = calc.NextSegment("0000000200000005000000FF")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "000000020000000600000000", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NextSegment — wrap at 0x1000 (4096) for 1 MB segments
|
||||||
|
|
||||||
|
func Test_WalCalculator_NextSegment_1MbSegmentSize_WrapsAt4096(t *testing.T) {
|
||||||
|
calc := NewWalCalculator(mb1)
|
||||||
|
|
||||||
|
// segmentsPerXLogId = 0x100000000 / (1*1024*1024) = 4096 = 0x1000
|
||||||
|
// Last valid SS = 0x00000FFF
|
||||||
|
|
||||||
|
result, err := calc.NextSegment("000000010000000100000FFE")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "000000010000000100000FFF", result)
|
||||||
|
|
||||||
|
// wrap: SS=0x0FFF → SS=0, LL++
|
||||||
|
result, err = calc.NextSegment("000000010000000100000FFF")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "000000010000000200000000", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NextSegment — wrap at 0x40 (64) for 64 MB segments
|
||||||
|
|
||||||
|
func Test_WalCalculator_NextSegment_64MbSegmentSize_WrapsAt64(t *testing.T) {
|
||||||
|
calc := NewWalCalculator(mb64)
|
||||||
|
|
||||||
|
// segmentsPerXLogId = 0x100000000 / (64*1024*1024) = 64 = 0x40
|
||||||
|
// Last valid SS = 0x0000003F
|
||||||
|
|
||||||
|
result, err := calc.NextSegment("00000001000000010000003E")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "00000001000000010000003F", result)
|
||||||
|
|
||||||
|
// wrap: SS=0x3F → SS=0, LL++
|
||||||
|
result, err = calc.NextSegment("00000001000000010000003F")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "000000010000000200000000", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NextSegment — log file increment on wrap
|
||||||
|
|
||||||
|
func Test_WalCalculator_NextSegment_IncrementsLog_OnSegmentWrap(t *testing.T) {
|
||||||
|
calc := NewWalCalculator(mb16)
|
||||||
|
|
||||||
|
// LL=0x00000001, wraps to LL=0x00000002
|
||||||
|
result, err := calc.NextSegment("0000000100000001000000FF")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "000000010000000200000000", result)
|
||||||
|
|
||||||
|
// LL=0x0000000F, wraps to LL=0x00000010
|
||||||
|
result, err = calc.NextSegment("00000001000000FF000000FF")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "000000010000010000000000", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NextSegment — timeline is preserved
|
||||||
|
|
||||||
|
func Test_WalCalculator_NextSegment_TimelinePreserved(t *testing.T) {
|
||||||
|
calc := NewWalCalculator(mb16)
|
||||||
|
|
||||||
|
result, err := calc.NextSegment("000000030000000100000005")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "000000030000000100000006", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NextSegment — invalid input
|
||||||
|
|
||||||
|
func Test_WalCalculator_NextSegment_InvalidName_ReturnsError(t *testing.T) {
|
||||||
|
calc := NewWalCalculator(mb16)
|
||||||
|
|
||||||
|
cases := []string{
|
||||||
|
"",
|
||||||
|
"00000001000000010000000", // 23 chars
|
||||||
|
"0000000100000001000000001", // 25 chars
|
||||||
|
"00000001000000010000000G", // non-hex char
|
||||||
|
"short",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, name := range cases {
|
||||||
|
_, err := calc.NextSegment(name)
|
||||||
|
assert.Error(t, err, "expected error for input %q", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidSegmentName
|
||||||
|
|
||||||
|
func Test_WalCalculator_IsValidSegmentName_ValidName_ReturnsTrue(t *testing.T) {
|
||||||
|
calc := NewWalCalculator(mb16)
|
||||||
|
|
||||||
|
valid := []string{
|
||||||
|
"000000010000000100000000",
|
||||||
|
"0000000100000001000000FF",
|
||||||
|
"FFFFFFFFFFFFFFFFFFFFFFFF",
|
||||||
|
"000000010000000100000027",
|
||||||
|
"0000000200000005000000AB",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, name := range valid {
|
||||||
|
assert.True(t, calc.IsValidSegmentName(name), "expected valid for %q", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_WalCalculator_IsValidSegmentName_TooShort_ReturnsFalse(t *testing.T) {
|
||||||
|
calc := NewWalCalculator(mb16)
|
||||||
|
|
||||||
|
assert.False(t, calc.IsValidSegmentName("00000001000000010000000")) // 23 chars
|
||||||
|
assert.False(t, calc.IsValidSegmentName(""))
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_WalCalculator_IsValidSegmentName_TooLong_ReturnsFalse(t *testing.T) {
|
||||||
|
calc := NewWalCalculator(mb16)
|
||||||
|
|
||||||
|
assert.False(t, calc.IsValidSegmentName("0000000100000001000000001")) // 25 chars
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_WalCalculator_IsValidSegmentName_NonHex_ReturnsFalse(t *testing.T) {
|
||||||
|
calc := NewWalCalculator(mb16)
|
||||||
|
|
||||||
|
assert.False(t, calc.IsValidSegmentName("00000001000000010000000G"))
|
||||||
|
assert.False(t, calc.IsValidSegmentName("00000001000000010000000Z"))
|
||||||
|
assert.False(t, calc.IsValidSegmentName("000000010000000100000 00"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compare
|
||||||
|
|
||||||
|
func Test_WalCalculator_Compare_ReturnsCorrectOrdering(t *testing.T) {
|
||||||
|
calc := NewWalCalculator(mb16)
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
a string
|
||||||
|
b string
|
||||||
|
expected int
|
||||||
|
}{
|
||||||
|
// equal
|
||||||
|
{"000000010000000100000001", "000000010000000100000001", 0},
|
||||||
|
// segment ordering
|
||||||
|
{"000000010000000100000001", "000000010000000100000002", -1},
|
||||||
|
{"000000010000000100000002", "000000010000000100000001", 1},
|
||||||
|
// log ordering
|
||||||
|
{"000000010000000100000000", "000000010000000200000000", -1},
|
||||||
|
{"000000010000000200000000", "000000010000000100000000", 1},
|
||||||
|
// timeline ordering
|
||||||
|
{"000000010000000100000000", "000000020000000100000000", -1},
|
||||||
|
{"000000020000000100000000", "000000010000000100000000", 1},
|
||||||
|
// across wrap boundary: log 1, seg 255 < log 2, seg 0
|
||||||
|
{"0000000100000001000000FF", "000000010000000200000000", -1},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
result, err := calc.Compare(tc.a, tc.b)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tc.expected, result, "Compare(%s, %s)", tc.a, tc.b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_WalCalculator_Compare_InvalidInput_ReturnsError(t *testing.T) {
|
||||||
|
calc := NewWalCalculator(mb16)
|
||||||
|
|
||||||
|
_, err := calc.Compare("invalid", "000000010000000100000001")
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
_, err = calc.Compare("000000010000000100000001", "invalid")
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default segment size via zero input
|
||||||
|
|
||||||
|
func Test_WalCalculator_NewWalCalculator_ZeroSize_UsesDefault16MB(t *testing.T) {
|
||||||
|
calc := NewWalCalculator(0)
|
||||||
|
assert.Equal(t, int64(mb16), calc.segmentSizeBytes)
|
||||||
|
assert.Equal(t, uint64(256), calc.segmentsPerXLogId)
|
||||||
|
}
|
||||||
72
backend/migrations/20260306045548_add_wal_properties.sql
Normal file
72
backend/migrations/20260306045548_add_wal_properties.sql
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
-- +goose Up
|
||||||
|
-- +goose StatementBegin
|
||||||
|
ALTER TABLE databases
|
||||||
|
ADD COLUMN agent_token TEXT,
|
||||||
|
ADD COLUMN is_agent_token_generated BOOLEAN NOT NULL DEFAULT FALSE;
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX idx_databases_agent_token ON databases (agent_token) WHERE agent_token IS NOT NULL;
|
||||||
|
|
||||||
|
ALTER TABLE postgresql_databases
|
||||||
|
ADD COLUMN backup_type TEXT NOT NULL DEFAULT 'PG_DUMP';
|
||||||
|
|
||||||
|
ALTER TABLE postgresql_databases
|
||||||
|
ALTER COLUMN host DROP NOT NULL,
|
||||||
|
ALTER COLUMN port DROP NOT NULL,
|
||||||
|
ALTER COLUMN username DROP NOT NULL,
|
||||||
|
ALTER COLUMN password DROP NOT NULL;
|
||||||
|
|
||||||
|
ALTER TABLE backups
|
||||||
|
ADD COLUMN pg_wal_backup_type TEXT,
|
||||||
|
ADD COLUMN pg_wal_start_segment TEXT,
|
||||||
|
ADD COLUMN pg_wal_stop_segment TEXT,
|
||||||
|
ADD COLUMN pg_version TEXT,
|
||||||
|
ADD COLUMN pg_wal_segment_name TEXT;
|
||||||
|
|
||||||
|
CREATE INDEX idx_backups_pg_wal_segment_name
|
||||||
|
ON backups (database_id, pg_wal_segment_name)
|
||||||
|
WHERE pg_wal_segment_name IS NOT NULL;
|
||||||
|
|
||||||
|
CREATE INDEX idx_backups_pg_wal_backup_type_created
|
||||||
|
ON backups (database_id, pg_wal_backup_type, created_at DESC)
|
||||||
|
WHERE pg_wal_backup_type IS NOT NULL;
|
||||||
|
-- +goose StatementEnd
|
||||||
|
|
||||||
|
-- +goose Down
|
||||||
|
-- +goose StatementBegin
|
||||||
|
DROP INDEX IF EXISTS idx_backups_pg_wal_segment_name;
|
||||||
|
DROP INDEX IF EXISTS idx_backups_pg_wal_backup_type_created;
|
||||||
|
|
||||||
|
ALTER TABLE backups
|
||||||
|
DROP COLUMN pg_wal_backup_type,
|
||||||
|
DROP COLUMN pg_wal_start_segment,
|
||||||
|
DROP COLUMN pg_wal_stop_segment,
|
||||||
|
DROP COLUMN pg_version,
|
||||||
|
DROP COLUMN pg_wal_segment_name;
|
||||||
|
|
||||||
|
UPDATE postgresql_databases
|
||||||
|
SET host = 'localhost' WHERE host IS NULL OR host = '';
|
||||||
|
|
||||||
|
UPDATE postgresql_databases
|
||||||
|
SET port = 5432 WHERE port IS NULL OR port = 0;
|
||||||
|
|
||||||
|
UPDATE postgresql_databases
|
||||||
|
SET username = 'postgres' WHERE username IS NULL OR username = '';
|
||||||
|
|
||||||
|
UPDATE postgresql_databases
|
||||||
|
SET password = 'stubpassword' WHERE password IS NULL OR password = '';
|
||||||
|
|
||||||
|
ALTER TABLE postgresql_databases
|
||||||
|
DROP COLUMN backup_type;
|
||||||
|
|
||||||
|
ALTER TABLE postgresql_databases
|
||||||
|
ALTER COLUMN host SET NOT NULL,
|
||||||
|
ALTER COLUMN port SET NOT NULL,
|
||||||
|
ALTER COLUMN username SET NOT NULL,
|
||||||
|
ALTER COLUMN password SET NOT NULL;
|
||||||
|
|
||||||
|
DROP INDEX IF EXISTS idx_databases_agent_token;
|
||||||
|
|
||||||
|
ALTER TABLE databases
|
||||||
|
DROP COLUMN agent_token,
|
||||||
|
DROP COLUMN is_agent_token_generated;
|
||||||
|
-- +goose StatementEnd
|
||||||
Reference in New Issue
Block a user