Compare commits

...

14 Commits

Author SHA1 Message Date
Rostislav Dugin
a0f02b253e Merge pull request #427 from databasus/develop
FIX (retention): Fix GFS retention while hourly backups prevent daily…
2026-03-11 15:36:27 +03:00
Rostislav Dugin
812f11bc2f FIX (retention): Fix GFS retention while hourly backups prevent daily from cleanup 2026-03-11 15:35:53 +03:00
Rostislav Dugin
e796e3ddf0 Merge pull request #426 from databasus/develop
FIX (mysql): Detect supported compression levels
2026-03-11 12:53:35 +03:00
Rostislav Dugin
c96d3db337 FIX (mysql): Detect supported compression levels 2026-03-11 12:52:41 +03:00
Rostislav Dugin
ed6c3a2034 Merge pull request #425 from databasus/develop
Develop
2026-03-11 12:31:19 +03:00
Rostislav Dugin
05115047c3 FEATURE (version): Reload frontend if faced version mismatch with backend 2026-03-11 12:28:07 +03:00
Rostislav Dugin
446b96c6c0 FEATURE (arch): Add architecture to Databasus version in the bottom left of UI 2026-03-11 11:39:53 +03:00
Rostislav Dugin
36a0448da1 Merge pull request #420 from databasus/develop
FEATURE (email): Add skipping TLS for email notifier
2026-03-08 22:53:45 +03:00
Rostislav Dugin
8e392cfeab FEATURE (email): Add skipping TLS for email notifier 2026-03-08 22:48:28 +03:00
Rostislav Dugin
6683db1e52 Merge pull request #419 from databasus/develop
FIX (issues): Add DB version to issues template
2026-03-08 22:22:52 +03:00
Rostislav Dugin
703b883936 FIX (issues): Add DB version to issues template 2026-03-08 22:22:26 +03:00
Rostislav Dugin
e818bcff82 Merge pull request #415 from databasus/develop
Develop
2026-03-06 09:45:11 +03:00
Rostislav Dugin
b2f98f1332 FIX (mysql\mariadb): Increase max allowed packet size over restore for MySQL\MariaDB 2026-03-06 09:44:17 +03:00
Rostislav Dugin
230cc27ea6 FEATURE (backups): Add WAL API 2026-03-06 08:10:29 +03:00
75 changed files with 5504 additions and 954 deletions

View File

@@ -4,7 +4,9 @@ about: Report a bug or unexpected behavior in Databasus
labels: bug
---
## Databasus version
## Databasus version (screenshot)
It is displayed in the bottom left corner of the Databasus UI. Please attach screenshot, not just version text
<!-- e.g. 1.4.2 -->
@@ -12,6 +14,10 @@ labels: bug
<!-- e.g. Ubuntu 22.04 x64, macOS 14 ARM, Windows 11 x64 -->
## Database type and version (optional, for DB-related bugs)
<!-- e.g. PostgreSQL 16 in Docker, MySQL 8.0 installed on server, MariaDB 11.4 in AWS Cloud -->
## Describe the bug (please write manually, do not ask AI to summarize)
**What happened:**

1
.gitignore vendored
View File

@@ -12,3 +12,4 @@ node_modules/
.DS_Store
/scripts
.vscode/settings.json
.claude

1
CLAUDE.md Normal file
View File

@@ -0,0 +1 @@
Look at @AGENTS.md

View File

@@ -71,8 +71,10 @@ FROM debian:bookworm-slim
# Add version metadata to runtime image
ARG APP_VERSION=dev
ARG TARGETARCH
LABEL org.opencontainers.image.version=$APP_VERSION
ENV APP_VERSION=$APP_VERSION
ENV CONTAINER_ARCH=$TARGETARCH
# Set production mode for Docker containers
ENV ENV_MODE=production
@@ -269,7 +271,8 @@ window.__RUNTIME_CONFIG__ = {
GITHUB_CLIENT_ID: '\${GITHUB_CLIENT_ID:-}',
GOOGLE_CLIENT_ID: '\${GOOGLE_CLIENT_ID:-}',
IS_EMAIL_CONFIGURED: '\$IS_EMAIL_CONFIGURED',
CLOUDFLARE_TURNSTILE_SITE_KEY: '\${CLOUDFLARE_TURNSTILE_SITE_KEY:-}'
CLOUDFLARE_TURNSTILE_SITE_KEY: '\${CLOUDFLARE_TURNSTILE_SITE_KEY:-}',
CONTAINER_ARCH: '\${CONTAINER_ARCH:-unknown}'
};
JSEOF

View File

@@ -14,9 +14,10 @@ import (
"databasus-backend/internal/config"
"databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/features/backups/backups/backuping"
backups_controllers "databasus-backend/internal/features/backups/backups/controllers"
backups_download "databasus-backend/internal/features/backups/backups/download"
backups_services "databasus-backend/internal/features/backups/backups/services"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/disk"
@@ -28,6 +29,7 @@ import (
"databasus-backend/internal/features/restores/restoring"
"databasus-backend/internal/features/storages"
system_healthcheck "databasus-backend/internal/features/system/healthcheck"
system_version "databasus-backend/internal/features/system/version"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
users_controllers "databasus-backend/internal/features/users/controllers"
users_middleware "databasus-backend/internal/features/users/middleware"
@@ -209,7 +211,10 @@ func setUpRoutes(r *gin.Engine) {
userController := users_controllers.GetUserController()
userController.RegisterRoutes(v1)
system_healthcheck.GetHealthcheckController().RegisterRoutes(v1)
backups.GetBackupController().RegisterPublicRoutes(v1)
system_version.GetVersionController().RegisterRoutes(v1)
backups_controllers.GetBackupController().RegisterPublicRoutes(v1)
backups_controllers.GetPostgresWalBackupController().RegisterRoutes(v1)
databases.GetDatabaseController().RegisterPublicRoutes(v1)
// Setup auth middleware
userService := users_services.GetUserService()
@@ -226,7 +231,7 @@ func setUpRoutes(r *gin.Engine) {
notifiers.GetNotifierController().RegisterRoutes(protected)
storages.GetStorageController().RegisterRoutes(protected)
databases.GetDatabaseController().RegisterRoutes(protected)
backups.GetBackupController().RegisterRoutes(protected)
backups_controllers.GetBackupController().RegisterRoutes(protected)
restores.GetRestoreController().RegisterRoutes(protected)
healthcheck_config.GetHealthcheckConfigController().RegisterRoutes(protected)
healthcheck_attempt.GetHealthcheckAttemptController().RegisterRoutes(protected)
@@ -238,7 +243,7 @@ func setUpRoutes(r *gin.Engine) {
func setUpDependencies() {
databases.SetupDependencies()
backups.SetupDependencies()
backups_services.SetupDependencies()
restores.SetupDependencies()
healthcheck_config.SetupDependencies()
audit_logs.SetupDependencies()

View File

@@ -80,8 +80,7 @@ func (c *BackupCleaner) DeleteBackup(backup *backups_core.Backup) error {
return err
}
err = storage.DeleteFile(c.fieldEncryptor, backup.FileName)
if err != nil {
if err := storage.DeleteFile(c.fieldEncryptor, backup.FileName); err != nil {
// we do not return error here, because sometimes clean up performed
// before unavailable storage removal or change - therefore we should
// proceed even in case of error. It's possible that some S3 or
@@ -408,6 +407,10 @@ func buildGFSKeepSet(
) map[uuid.UUID]bool {
keep := make(map[uuid.UUID]bool)
if len(backups) == 0 {
return keep
}
hoursSeen := make(map[string]bool)
daysSeen := make(map[string]bool)
weeksSeen := make(map[string]bool)
@@ -416,6 +419,52 @@ func buildGFSKeepSet(
hoursKept, daysKept, weeksKept, monthsKept, yearsKept := 0, 0, 0, 0, 0
// Compute per-level time-window cutoffs so higher-frequency slots
// cannot absorb backups that belong to lower-frequency levels.
ref := backups[0].CreatedAt
rawHourlyCutoff := ref.Add(-time.Duration(hours) * time.Hour)
rawDailyCutoff := ref.Add(-time.Duration(days) * 24 * time.Hour)
rawWeeklyCutoff := ref.Add(-time.Duration(weeks) * 7 * 24 * time.Hour)
rawMonthlyCutoff := ref.AddDate(0, -months, 0)
rawYearlyCutoff := ref.AddDate(-years, 0, 0)
// Hierarchical capping: each level's window cannot extend further back
// than the nearest active lower-frequency level's window.
yearlyCutoff := rawYearlyCutoff
monthlyCutoff := rawMonthlyCutoff
if years > 0 {
monthlyCutoff = laterOf(monthlyCutoff, yearlyCutoff)
}
weeklyCutoff := rawWeeklyCutoff
if months > 0 {
weeklyCutoff = laterOf(weeklyCutoff, monthlyCutoff)
} else if years > 0 {
weeklyCutoff = laterOf(weeklyCutoff, yearlyCutoff)
}
dailyCutoff := rawDailyCutoff
if weeks > 0 {
dailyCutoff = laterOf(dailyCutoff, weeklyCutoff)
} else if months > 0 {
dailyCutoff = laterOf(dailyCutoff, monthlyCutoff)
} else if years > 0 {
dailyCutoff = laterOf(dailyCutoff, yearlyCutoff)
}
hourlyCutoff := rawHourlyCutoff
if days > 0 {
hourlyCutoff = laterOf(hourlyCutoff, dailyCutoff)
} else if weeks > 0 {
hourlyCutoff = laterOf(hourlyCutoff, weeklyCutoff)
} else if months > 0 {
hourlyCutoff = laterOf(hourlyCutoff, monthlyCutoff)
} else if years > 0 {
hourlyCutoff = laterOf(hourlyCutoff, yearlyCutoff)
}
for _, backup := range backups {
t := backup.CreatedAt
@@ -426,31 +475,31 @@ func buildGFSKeepSet(
monthKey := t.Format("2006-01")
yearKey := t.Format("2006")
if hours > 0 && hoursKept < hours && !hoursSeen[hourKey] {
if hours > 0 && hoursKept < hours && !hoursSeen[hourKey] && t.After(hourlyCutoff) {
keep[backup.ID] = true
hoursSeen[hourKey] = true
hoursKept++
}
if days > 0 && daysKept < days && !daysSeen[dayKey] {
if days > 0 && daysKept < days && !daysSeen[dayKey] && t.After(dailyCutoff) {
keep[backup.ID] = true
daysSeen[dayKey] = true
daysKept++
}
if weeks > 0 && weeksKept < weeks && !weeksSeen[weekKey] {
if weeks > 0 && weeksKept < weeks && !weeksSeen[weekKey] && t.After(weeklyCutoff) {
keep[backup.ID] = true
weeksSeen[weekKey] = true
weeksKept++
}
if months > 0 && monthsKept < months && !monthsSeen[monthKey] {
if months > 0 && monthsKept < months && !monthsSeen[monthKey] && t.After(monthlyCutoff) {
keep[backup.ID] = true
monthsSeen[monthKey] = true
monthsKept++
}
if years > 0 && yearsKept < years && !yearsSeen[yearKey] {
if years > 0 && yearsKept < years && !yearsSeen[yearKey] && t.After(yearlyCutoff) {
keep[backup.ID] = true
yearsSeen[yearKey] = true
yearsKept++
@@ -459,3 +508,11 @@ func buildGFSKeepSet(
return keep
}
func laterOf(a, b time.Time) time.Time {
if a.After(b) {
return a
}
return b
}

File diff suppressed because it is too large Load Diff

View File

@@ -697,160 +697,6 @@ func Test_CleanByCount_DoesNotDeleteInProgressBackups(t *testing.T) {
assert.True(t, inProgressFound, "In-progress backup should not be deleted by count policy")
}
func Test_CleanByGFS_KeepsCorrectBackupsPerSlot(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeGFS,
RetentionGfsDays: 3,
RetentionGfsWeeks: 0,
RetentionGfsMonths: 0,
RetentionGfsYears: 0,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
now := time.Now().UTC()
// Create 5 backups on 5 different days; only the 3 newest days should be kept
var backupIDs []uuid.UUID
for i := 0; i < 5; i++ {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
CreatedAt: now.Add(-time.Duration(4-i) * 24 * time.Hour).Truncate(24 * time.Hour),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backupIDs = append(backupIDs, backup.ID)
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Equal(t, 3, len(remainingBackups))
remainingIDs := make(map[uuid.UUID]bool)
for _, backup := range remainingBackups {
remainingIDs[backup.ID] = true
}
assert.False(t, remainingIDs[backupIDs[0]], "Oldest daily backup should be deleted")
assert.False(t, remainingIDs[backupIDs[1]], "2nd oldest daily backup should be deleted")
assert.True(t, remainingIDs[backupIDs[2]], "3rd backup should remain")
assert.True(t, remainingIDs[backupIDs[3]], "4th backup should remain")
assert.True(t, remainingIDs[backupIDs[4]], "Newest backup should remain")
}
func Test_CleanByGFS_WithWeeklyAndMonthlySlots_KeepsWiderSpread(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeGFS,
RetentionGfsDays: 2,
RetentionGfsWeeks: 2,
RetentionGfsMonths: 1,
RetentionGfsYears: 0,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
now := time.Now().UTC()
// Create one backup per week for 6 weeks (each on Monday of that week)
// GFS should keep: 2 daily (most recent 2 unique days) + 2 weekly + 1 monthly = up to 5 unique
var createdIDs []uuid.UUID
for i := 0; i < 6; i++ {
weekOffset := time.Duration(5-i) * 7 * 24 * time.Hour
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
CreatedAt: now.Add(-weekOffset).Truncate(24 * time.Hour),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
createdIDs = append(createdIDs, backup.ID)
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
// We should have at most 5 backups kept (2 daily + 2 weekly + 1 monthly, but with overlap possible)
// The exact count depends on how many unique periods are covered
assert.LessOrEqual(t, len(remainingBackups), 5)
assert.GreaterOrEqual(t, len(remainingBackups), 1)
// The two most recent backups should always be retained (daily slots)
remainingIDs := make(map[uuid.UUID]bool)
for _, backup := range remainingBackups {
remainingIDs[backup.ID] = true
}
assert.True(t, remainingIDs[createdIDs[4]], "Second newest backup should be retained (daily)")
assert.True(t, remainingIDs[createdIDs[5]], "Newest backup should be retained (daily)")
}
// Test_DeleteBackup_WhenStorageDeleteFails_BackupStillRemovedFromDatabase verifies resilience
// when storage becomes unavailable. Even if storage.DeleteFile fails (e.g., storage is offline,
// credentials changed, or storage was deleted), the backup record should still be removed from
@@ -897,292 +743,6 @@ func Test_DeleteBackup_WhenStorageDeleteFails_BackupStillRemovedFromDatabase(t *
assert.Nil(t, deletedBackup)
}
func Test_CleanByGFS_WithHourlySlots_KeepsCorrectBackups(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
testStorage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, testStorage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(testStorage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
interval := createTestInterval()
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeGFS,
RetentionGfsHours: 3,
StorageID: &testStorage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
now := time.Now().UTC()
// Create 5 backups spaced 1 hour apart; only the 3 newest hours should be kept
var backupIDs []uuid.UUID
for i := 0; i < 5; i++ {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: testStorage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
CreatedAt: now.Add(-time.Duration(4-i) * time.Hour).Truncate(time.Hour),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backupIDs = append(backupIDs, backup.ID)
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Equal(t, 3, len(remainingBackups))
remainingIDs := make(map[uuid.UUID]bool)
for _, backup := range remainingBackups {
remainingIDs[backup.ID] = true
}
assert.False(t, remainingIDs[backupIDs[0]], "Oldest hourly backup should be deleted")
assert.False(t, remainingIDs[backupIDs[1]], "2nd oldest hourly backup should be deleted")
assert.True(t, remainingIDs[backupIDs[2]], "3rd backup should remain")
assert.True(t, remainingIDs[backupIDs[3]], "4th backup should remain")
assert.True(t, remainingIDs[backupIDs[4]], "Newest backup should remain")
}
func Test_BuildGFSKeepSet(t *testing.T) {
// Fixed reference time: a Wednesday mid-month to avoid boundary edge cases in the default tests.
// Use time.Date for determinism across test runs.
ref := time.Date(2025, 6, 18, 12, 0, 0, 0, time.UTC) // Wednesday, 2025-06-18
day := 24 * time.Hour
week := 7 * day
newBackup := func(createdAt time.Time) *backups_core.Backup {
return &backups_core.Backup{ID: uuid.New(), CreatedAt: createdAt}
}
// backupsEveryDay returns n backups, newest-first, each 1 day apart.
backupsEveryDay := func(n int) []*backups_core.Backup {
bs := make([]*backups_core.Backup, n)
for i := 0; i < n; i++ {
bs[i] = newBackup(ref.Add(-time.Duration(i) * day))
}
return bs
}
// backupsEveryWeek returns n backups, newest-first, each 7 days apart.
backupsEveryWeek := func(n int) []*backups_core.Backup {
bs := make([]*backups_core.Backup, n)
for i := 0; i < n; i++ {
bs[i] = newBackup(ref.Add(-time.Duration(i) * week))
}
return bs
}
hour := time.Hour
// backupsEveryHour returns n backups, newest-first, each 1 hour apart.
backupsEveryHour := func(n int) []*backups_core.Backup {
bs := make([]*backups_core.Backup, n)
for i := 0; i < n; i++ {
bs[i] = newBackup(ref.Add(-time.Duration(i) * hour))
}
return bs
}
tests := []struct {
name string
backups []*backups_core.Backup
hours int
days int
weeks int
months int
years int
keptIndices []int // which indices in backups should be kept
deletedRange *[2]int // optional: all indices in [from, to) must be deleted
}{
{
name: "OnlyHourlySlots_KeepsNewest3Of5",
backups: backupsEveryHour(5),
hours: 3,
keptIndices: []int{0, 1, 2},
},
{
name: "SameHourDedup_OnlyNewestKeptForHourlySlot",
backups: []*backups_core.Backup{
newBackup(ref.Truncate(hour).Add(45 * time.Minute)),
newBackup(ref.Truncate(hour).Add(10 * time.Minute)),
},
hours: 1,
keptIndices: []int{0},
},
{
name: "OnlyDailySlots_KeepsNewest3Of5",
backups: backupsEveryDay(5),
days: 3,
keptIndices: []int{0, 1, 2},
},
{
name: "OnlyDailySlots_FewerBackupsThanSlots_KeepsAll",
backups: backupsEveryDay(2),
days: 5,
keptIndices: []int{0, 1},
},
{
name: "OnlyWeeklySlots_KeepsNewest2Weeks",
backups: backupsEveryWeek(4),
weeks: 2,
keptIndices: []int{0, 1},
},
{
name: "OnlyMonthlySlots_KeepsNewest2Months",
backups: []*backups_core.Backup{
newBackup(time.Date(2025, 6, 1, 12, 0, 0, 0, time.UTC)),
newBackup(time.Date(2025, 5, 1, 12, 0, 0, 0, time.UTC)),
newBackup(time.Date(2025, 4, 1, 12, 0, 0, 0, time.UTC)),
},
months: 2,
keptIndices: []int{0, 1},
},
{
name: "OnlyYearlySlots_KeepsNewest2Years",
backups: []*backups_core.Backup{
newBackup(time.Date(2025, 6, 1, 12, 0, 0, 0, time.UTC)),
newBackup(time.Date(2024, 6, 1, 12, 0, 0, 0, time.UTC)),
newBackup(time.Date(2023, 6, 1, 12, 0, 0, 0, time.UTC)),
},
years: 2,
keptIndices: []int{0, 1},
},
{
name: "SameDayDedup_OnlyNewestKeptForDailySlot",
backups: []*backups_core.Backup{
// Two backups on the same day; newest-first order
newBackup(ref.Truncate(day).Add(10 * time.Hour)),
newBackup(ref.Truncate(day).Add(2 * time.Hour)),
},
days: 1,
keptIndices: []int{0},
},
{
name: "SameWeekDedup_OnlyNewestKeptForWeeklySlot",
backups: []*backups_core.Backup{
// ref is Wednesday; add Thursday of same week
newBackup(ref.Add(1 * day)), // Thursday same week
newBackup(ref), // Wednesday same week
},
weeks: 1,
keptIndices: []int{0},
},
{
name: "AdditiveSlots_NewestFillsDailyAndWeeklyAndMonthly",
// Newest backup fills daily + weekly + monthly simultaneously
backups: []*backups_core.Backup{
newBackup(time.Date(2025, 6, 18, 12, 0, 0, 0, time.UTC)), // newest
newBackup(time.Date(2025, 6, 11, 12, 0, 0, 0, time.UTC)), // 1 week ago
newBackup(time.Date(2025, 5, 18, 12, 0, 0, 0, time.UTC)), // 1 month ago
newBackup(time.Date(2025, 4, 18, 12, 0, 0, 0, time.UTC)), // 2 months ago
},
days: 1,
weeks: 2,
months: 2,
keptIndices: []int{0, 1, 2},
},
{
name: "YearBoundary_CorrectlySplitsAcrossYears",
backups: []*backups_core.Backup{
newBackup(time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)),
newBackup(time.Date(2024, 12, 31, 12, 0, 0, 0, time.UTC)),
newBackup(time.Date(2024, 6, 1, 12, 0, 0, 0, time.UTC)),
newBackup(time.Date(2023, 6, 1, 12, 0, 0, 0, time.UTC)),
},
years: 2,
keptIndices: []int{0, 1}, // 2025 and 2024 kept; 2024-06 and 2023 deleted
},
{
name: "ISOWeekBoundary_Jan1UsesCorrectISOWeek",
// 2025-01-01 is ISO week 1 of 2025; 2024-12-28 is ISO week 52 of 2024
backups: []*backups_core.Backup{
newBackup(time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)), // ISO week 2025-W01
newBackup(time.Date(2024, 12, 28, 12, 0, 0, 0, time.UTC)), // ISO week 2024-W52
},
weeks: 2,
keptIndices: []int{0, 1}, // different ISO weeks → both kept
},
{
name: "EmptyBackups_ReturnsEmptyKeepSet",
backups: []*backups_core.Backup{},
hours: 3,
days: 3,
weeks: 2,
months: 1,
years: 1,
keptIndices: []int{},
},
{
name: "AllZeroSlots_KeepsNothing",
backups: backupsEveryDay(5),
hours: 0,
days: 0,
weeks: 0,
months: 0,
years: 0,
keptIndices: []int{},
},
{
name: "AllSlotsActive_FullCombination",
backups: backupsEveryWeek(12),
days: 2,
weeks: 3,
months: 2,
years: 1,
// 2 daily (indices 0,1) + 3rd weekly slot (index 2) + 2nd monthly slot (index 3 or later).
// Additive slots: newest fills daily+weekly+monthly+yearly; each subsequent week fills another weekly,
// and a backup ~4 weeks later fills the 2nd monthly slot.
keptIndices: []int{0, 1, 2, 3},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
keepSet := buildGFSKeepSet(tc.backups, tc.hours, tc.days, tc.weeks, tc.months, tc.years)
keptIndexSet := make(map[int]bool, len(tc.keptIndices))
for _, idx := range tc.keptIndices {
keptIndexSet[idx] = true
}
for i, backup := range tc.backups {
if keptIndexSet[i] {
assert.True(t, keepSet[backup.ID], "backup at index %d should be kept", i)
} else {
assert.False(t, keepSet[backup.ID], "backup at index %d should be deleted", i)
}
}
})
}
}
func Test_CleanByTimePeriod_SkipsRecentBackup_EvenIfOlderThanRetention(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -1354,114 +914,6 @@ func Test_CleanByCount_SkipsRecentBackup_EvenIfOverLimit(t *testing.T) {
assert.True(t, remainingIDs[newestBackup.ID], "Newest backup should be preserved")
}
func Test_CleanByGFS_SkipsRecentBackup_WhenNotInKeepSet(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
interval := createTestInterval()
// Keep only 1 daily slot. We create 2 old backups plus two recent backups on today.
// Backups are ordered newest-first, so the 15-min-old backup fills the single daily slot.
// The 30-min-old backup is the same day → not in the GFS keep-set, but it is still recent
// (within grace period) and must be preserved.
backupConfig := &backups_config.BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: backups_config.RetentionPolicyTypeGFS,
RetentionGfsDays: 1,
StorageID: &storage.ID,
BackupIntervalID: interval.ID,
BackupInterval: interval,
}
_, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
now := time.Now().UTC()
oldBackup1 := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
CreatedAt: now.Add(-3 * 24 * time.Hour).Truncate(24 * time.Hour),
}
oldBackup2 := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
CreatedAt: now.Add(-2 * 24 * time.Hour).Truncate(24 * time.Hour),
}
// Newest backup today — will fill the single GFS daily slot.
newestTodayBackup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
CreatedAt: now.Add(-15 * time.Minute),
}
// Slightly older backup, also today — NOT in GFS keep-set (duplicate day),
// but within the 60-minute grace period so it must survive.
recentNotInKeepSet := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10,
CreatedAt: now.Add(-30 * time.Minute),
}
for _, b := range []*backups_core.Backup{oldBackup1, oldBackup2, newestTodayBackup, recentNotInKeepSet} {
err = backupRepository.Save(b)
assert.NoError(t, err)
}
cleaner := GetBackupCleaner()
err = cleaner.cleanByRetentionPolicy()
assert.NoError(t, err)
remainingBackups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
remainingIDs := make(map[uuid.UUID]bool)
for _, backup := range remainingBackups {
remainingIDs[backup.ID] = true
}
assert.False(t, remainingIDs[oldBackup1.ID], "Old backup 1 should be deleted by GFS")
assert.False(t, remainingIDs[oldBackup2.ID], "Old backup 2 should be deleted by GFS")
assert.True(
t,
remainingIDs[newestTodayBackup.ID],
"Newest backup fills GFS daily slot and must remain",
)
assert.True(
t,
remainingIDs[recentNotInKeepSet.ID],
"Recent backup not in keep-set must be preserved by grace period",
)
}
func Test_CleanExceededBackups_SkipsRecentBackup_WhenOverTotalSizeLimit(t *testing.T) {
router := CreateTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)

View File

@@ -15,7 +15,6 @@ import (
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
files_utils "databasus-backend/internal/util/files"
)
const (
@@ -171,13 +170,7 @@ func (s *BackupsScheduler) StartBackup(database *databases.Database, isCallNotif
timestamp := time.Now().UTC()
backup := &backups_core.Backup{
ID: backupID,
FileName: fmt.Sprintf(
"%s-%s-%s",
files_utils.SanitizeFilename(database.Name),
timestamp.Format("20060102-150405"),
backupID.String(),
),
ID: backupID,
DatabaseID: backupConfig.DatabaseID,
StorageID: *backupConfig.StorageID,
Status: backups_core.BackupStatusInProgress,
@@ -185,6 +178,8 @@ func (s *BackupsScheduler) StartBackup(database *databases.Database, isCallNotif
CreatedAt: timestamp,
}
backup.GenerateFilename(database.Name)
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error(
"Failed to save backup",

View File

@@ -1,9 +1,11 @@
package backups
package backups_controllers
import (
"context"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
backups_dto "databasus-backend/internal/features/backups/backups/dto"
backups_services "databasus-backend/internal/features/backups/backups/services"
"databasus-backend/internal/features/databases"
users_middleware "databasus-backend/internal/features/users/middleware"
files_utils "databasus-backend/internal/util/files"
@@ -17,7 +19,7 @@ import (
)
type BackupController struct {
backupService *BackupService
backupService *backups_services.BackupService
}
func (c *BackupController) RegisterRoutes(router *gin.RouterGroup) {
@@ -42,7 +44,7 @@ func (c *BackupController) RegisterPublicRoutes(router *gin.RouterGroup) {
// @Param database_id query string true "Database ID"
// @Param limit query int false "Number of items per page" default(10)
// @Param offset query int false "Offset for pagination" default(0)
// @Success 200 {object} GetBackupsResponse
// @Success 200 {object} backups_dto.GetBackupsResponse
// @Failure 400
// @Failure 401
// @Failure 500
@@ -54,7 +56,7 @@ func (c *BackupController) GetBackups(ctx *gin.Context) {
return
}
var request GetBackupsRequest
var request backups_dto.GetBackupsRequest
if err := ctx.ShouldBindQuery(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -81,7 +83,7 @@ func (c *BackupController) GetBackups(ctx *gin.Context) {
// @Tags backups
// @Accept json
// @Produce json
// @Param request body MakeBackupRequest true "Backup creation data"
// @Param request body backups_dto.MakeBackupRequest true "Backup creation data"
// @Success 200 {object} map[string]string
// @Failure 400
// @Failure 401
@@ -94,7 +96,7 @@ func (c *BackupController) MakeBackup(ctx *gin.Context) {
return
}
var request MakeBackupRequest
var request backups_dto.MakeBackupRequest
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -310,10 +312,6 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
c.backupService.WriteAuditLogForDownload(downloadToken.UserID, backup, database)
}
type MakeBackupRequest struct {
DatabaseID uuid.UUID `json:"database_id" binding:"required"`
}
func (c *BackupController) generateBackupFilename(
backup *backups_core.Backup,
database *databases.Database,

View File

@@ -1,4 +1,4 @@
package backups
package backups_controllers
import (
"context"
@@ -24,11 +24,14 @@ import (
backups_common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
backups_dto "databasus-backend/internal/features/backups/backups/dto"
backups_services "databasus-backend/internal/features/backups/backups/services"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/storages"
local_storage "databasus-backend/internal/features/storages/models/local"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
users_dto "databasus-backend/internal/features/users/dto"
users_enums "databasus-backend/internal/features/users/enums"
users_services "databasus-backend/internal/features/users/services"
@@ -119,7 +122,7 @@ func Test_GetBackups_PermissionsEnforced(t *testing.T) {
)
if tt.expectSuccess {
var response GetBackupsResponse
var response backups_dto.GetBackupsResponse
err := json.Unmarshal(testResp.Body, &response)
assert.NoError(t, err)
assert.GreaterOrEqual(t, len(response.Backups), 1)
@@ -214,7 +217,7 @@ func Test_CreateBackup_PermissionsEnforced(t *testing.T) {
testUserToken = nonMember.Token
}
request := MakeBackupRequest{DatabaseID: database.ID}
request := backups_dto.MakeBackupRequest{DatabaseID: database.ID}
testResp := test_utils.MakePostRequest(
t,
router,
@@ -245,7 +248,7 @@ func Test_CreateBackup_AuditLogWritten(t *testing.T) {
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
enableBackupForDatabase(database.ID)
request := MakeBackupRequest{DatabaseID: database.ID}
request := backups_dto.MakeBackupRequest{DatabaseID: database.ID}
test_utils.MakePostRequest(
t,
router,
@@ -373,7 +376,7 @@ func Test_DeleteBackup_PermissionsEnforced(t *testing.T) {
ownerUser, err := userService.GetUserFromToken(owner.Token)
assert.NoError(t, err)
response, err := GetBackupService().GetBackups(ownerUser, database.ID, 10, 0)
response, err := backups_services.GetBackupService().GetBackups(ownerUser, database.ID, 10, 0)
assert.NoError(t, err)
assert.Equal(t, 0, len(response.Backups))
}
@@ -999,7 +1002,7 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
assert.NoError(t, err)
// Register a cancellable context for the backup
GetBackupService().taskCancelManager.RegisterTask(backup.ID, func() {})
task_cancellation.GetTaskCancelManager().RegisterTask(backup.ID, func() {})
resp := test_utils.MakePostRequest(
t,
@@ -1091,7 +1094,7 @@ func Test_ConcurrentDownloadPrevention(t *testing.T) {
time.Sleep(50 * time.Millisecond)
service := GetBackupService()
service := backups_services.GetBackupService()
if !service.IsDownloadInProgress(owner.UserID) {
t.Log("Warning: First download completed before we could test concurrency")
<-downloadComplete
@@ -1192,7 +1195,7 @@ func Test_GenerateDownloadToken_BlockedWhenDownloadInProgress(t *testing.T) {
time.Sleep(50 * time.Millisecond)
service := GetBackupService()
service := backups_services.GetBackupService()
if !service.IsDownloadInProgress(owner.UserID) {
t.Log("Warning: First download completed before we could test token generation blocking")
<-downloadComplete
@@ -1268,7 +1271,7 @@ func Test_MakeBackup_VerifyBackupAndMetadataFilesExistInStorage(t *testing.T) {
initialBackups, err := backupRepo.FindByDatabaseID(database.ID)
assert.NoError(t, err)
request := MakeBackupRequest{DatabaseID: database.ID}
request := backups_dto.MakeBackupRequest{DatabaseID: database.ID}
test_utils.MakePostRequest(
t,
router,
@@ -1502,7 +1505,7 @@ func createTestBackup(
}
func createExpiredDownloadToken(backupID, userID uuid.UUID) string {
tokenService := GetBackupService().downloadTokenService
tokenService := backups_download.GetDownloadTokenService()
token, err := tokenService.Generate(backupID, userID)
if err != nil {
panic(fmt.Sprintf("Failed to generate download token: %v", err))
@@ -1843,7 +1846,7 @@ func Test_DeleteBackup_RemovesBackupAndMetadataFilesFromDisk(t *testing.T) {
initialBackups, err := backupRepo.FindByDatabaseID(database.ID)
assert.NoError(t, err)
request := MakeBackupRequest{DatabaseID: database.ID}
request := backups_dto.MakeBackupRequest{DatabaseID: database.ID}
test_utils.MakePostRequest(
t,
router,

View File

@@ -0,0 +1,23 @@
package backups_controllers
import (
backups_services "databasus-backend/internal/features/backups/backups/services"
"databasus-backend/internal/features/databases"
)
var backupController = &BackupController{
backups_services.GetBackupService(),
}
func GetBackupController() *BackupController {
return backupController
}
var postgresWalBackupController = &PostgreWalBackupController{
databases.GetDatabaseService(),
backups_services.GetWalService(),
}
func GetPostgresWalBackupController() *PostgreWalBackupController {
return postgresWalBackupController
}

View File

@@ -0,0 +1,291 @@
package backups_controllers
import (
"io"
"net/http"
"strconv"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_dto "databasus-backend/internal/features/backups/backups/dto"
backups_services "databasus-backend/internal/features/backups/backups/services"
"databasus-backend/internal/features/databases"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// PostgreWalBackupController handles WAL backup endpoints used by the databasus-cli agent.
// Authentication is via a plain agent token in the Authorization header (no Bearer prefix).
type PostgreWalBackupController struct {
databaseService *databases.DatabaseService
walService *backups_services.PostgreWalBackupService
}
func (c *PostgreWalBackupController) RegisterRoutes(router *gin.RouterGroup) {
walRoutes := router.Group("/backups/postgres/wal")
walRoutes.GET("/next-full-backup-time", c.GetNextFullBackupTime)
walRoutes.POST("/error", c.ReportError)
walRoutes.POST("/upload", c.Upload)
walRoutes.GET("/restore/plan", c.GetRestorePlan)
walRoutes.GET("/restore/download", c.DownloadBackupFile)
}
// GetNextFullBackupTime
// @Summary Get next full backup time
// @Description Returns the next scheduled full basebackup time for the authenticated database
// @Tags backups-wal
// @Produce json
// @Security AgentToken
// @Success 200 {object} backups_dto.GetNextFullBackupTimeResponse
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /backups/postgres/wal/next-full-backup-time [get]
func (c *PostgreWalBackupController) GetNextFullBackupTime(ctx *gin.Context) {
database, err := c.getDatabase(ctx)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
return
}
response, err := c.walService.GetNextFullBackupTime(database)
if err != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, response)
}
// ReportError
// @Summary Report agent error
// @Description Records a fatal error from the agent against the database record and marks it as errored
// @Tags backups-wal
// @Accept json
// @Security AgentToken
// @Param request body backups_dto.ReportErrorRequest true "Error details"
// @Success 200
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /backups/postgres/wal/error [post]
func (c *PostgreWalBackupController) ReportError(ctx *gin.Context) {
database, err := c.getDatabase(ctx)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
return
}
var request backups_dto.ReportErrorRequest
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := c.walService.ReportError(database, request.Error); err != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
ctx.Status(http.StatusOK)
}
// Upload
// @Summary Stream upload a basebackup or WAL segment
// @Description Accepts a zstd-compressed binary stream and stores it in the database's configured storage.
// The server generates the storage filename; agents do not control the destination path.
// For WAL segment uploads the server validates the WAL chain and returns 409 if a gap is detected
// or 400 if no full backup exists yet (agent should trigger a full basebackup in both cases).
// @Tags backups-wal
// @Accept application/octet-stream
// @Produce json
// @Security AgentToken
// @Param X-Upload-Type header string true "Upload type" Enums(basebackup, wal)
// @Param X-Wal-Segment-Name header string false "24-hex WAL segment identifier (required for wal uploads, e.g. 0000000100000001000000AB)"
// @Param X-Wal-Segment-Size header int false "WAL segment size in bytes reported by the PostgreSQL instance (default: 16777216)"
// @Param fullBackupWalStartSegment query string false "First WAL segment needed to make the basebackup consistent (required for basebackup uploads)"
// @Param fullBackupWalStopSegment query string false "Last WAL segment included in the basebackup (required for basebackup uploads)"
// @Success 204
// @Failure 400 {object} backups_dto.UploadGapResponse "No full backup exists (error: no_full_backup)"
// @Failure 401 {object} map[string]string
// @Failure 409 {object} backups_dto.UploadGapResponse "WAL chain gap detected (error: gap_detected)"
// @Failure 500 {object} map[string]string
// @Router /backups/postgres/wal/upload [post]
func (c *PostgreWalBackupController) Upload(ctx *gin.Context) {
database, err := c.getDatabase(ctx)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
return
}
uploadType := backups_core.PgWalUploadType(ctx.GetHeader("X-Upload-Type"))
if uploadType != backups_core.PgWalUploadTypeBasebackup &&
uploadType != backups_core.PgWalUploadTypeWal {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "X-Upload-Type must be 'basebackup' or 'wal'"},
)
return
}
walSegmentName := ""
if uploadType == backups_core.PgWalUploadTypeWal {
walSegmentName = ctx.GetHeader("X-Wal-Segment-Name")
if walSegmentName == "" {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "X-Wal-Segment-Name is required for wal uploads"},
)
return
}
}
if uploadType == backups_core.PgWalUploadTypeBasebackup {
if ctx.Query("fullBackupWalStartSegment") == "" ||
ctx.Query("fullBackupWalStopSegment") == "" {
ctx.JSON(
http.StatusBadRequest,
gin.H{
"error": "fullBackupWalStartSegment and fullBackupWalStopSegment are required for basebackup uploads",
},
)
return
}
}
walSegmentSizeBytes := int64(0)
if raw := ctx.GetHeader("X-Wal-Segment-Size"); raw != "" {
parsed, parseErr := strconv.ParseInt(raw, 10, 64)
if parseErr != nil || parsed <= 0 {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "X-Wal-Segment-Size must be a positive integer"},
)
return
}
walSegmentSizeBytes = parsed
}
gapResp, uploadErr := c.walService.UploadWal(
ctx.Request.Context(),
database,
uploadType,
walSegmentName,
ctx.Query("fullBackupWalStartSegment"),
ctx.Query("fullBackupWalStopSegment"),
walSegmentSizeBytes,
ctx.Request.Body,
)
if uploadErr != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": uploadErr.Error()})
return
}
if gapResp != nil {
if gapResp.Error == "no_full_backup" {
ctx.JSON(http.StatusBadRequest, gapResp)
return
}
ctx.JSON(http.StatusConflict, gapResp)
return
}
ctx.Status(http.StatusNoContent)
}
// GetRestorePlan
// @Summary Get restore plan
// @Description Resolves the full backup and all required WAL segments needed for recovery. Validates the WAL chain is continuous.
// @Tags backups-wal
// @Produce json
// @Security AgentToken
// @Param backupId query string false "UUID of a specific full backup to restore from; defaults to the most recent"
// @Success 200 {object} backups_dto.GetRestorePlanResponse
// @Failure 400 {object} map[string]string "Broken WAL chain or no backups available"
// @Failure 401 {object} map[string]string
// @Failure 500 {object} map[string]string
// @Router /backups/postgres/wal/restore/plan [get]
func (c *PostgreWalBackupController) GetRestorePlan(ctx *gin.Context) {
database, err := c.getDatabase(ctx)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
return
}
var backupID *uuid.UUID
if raw := ctx.Query("backupId"); raw != "" {
parsed, parseErr := uuid.Parse(raw)
if parseErr != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backupId format"})
return
}
backupID = &parsed
}
response, planErr, err := c.walService.GetRestorePlan(database, backupID)
if err != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if planErr != nil {
ctx.JSON(http.StatusBadRequest, planErr)
return
}
ctx.JSON(http.StatusOK, response)
}
// DownloadBackupFile
// @Summary Download a backup or WAL segment file for restore
// @Description Retrieves the backup file by ID (validated against the authenticated database), decrypts it server-side if encrypted, and streams the zstd-compressed result to the agent
// @Tags backups-wal
// @Produce application/octet-stream
// @Security AgentToken
// @Param backupId query string true "Backup ID from the restore plan response"
// @Success 200 {file} file
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Router /backups/postgres/wal/restore/download [get]
func (c *PostgreWalBackupController) DownloadBackupFile(ctx *gin.Context) {
database, err := c.getDatabase(ctx)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid agent token"})
return
}
backupIDRaw := ctx.Query("backupId")
if backupIDRaw == "" {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "backupId is required"})
return
}
backupID, err := uuid.Parse(backupIDRaw)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backupId format"})
return
}
reader, err := c.walService.DownloadBackupFile(database, backupID)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
defer func() { _ = reader.Close() }()
ctx.Header("Content-Type", "application/octet-stream")
ctx.Status(http.StatusOK)
_, _ = io.Copy(ctx.Writer, reader)
}
func (c *PostgreWalBackupController) getDatabase(
ctx *gin.Context,
) (*databases.Database, error) {
token := ctx.GetHeader("Authorization")
return c.databaseService.GetDatabaseByAgentToken(token)
}

View File

@@ -1,4 +1,4 @@
package backups
package backups_controllers
import (
"testing"
@@ -41,7 +41,7 @@ func WaitForBackupCompletion(
deadline := time.Now().UTC().Add(timeout)
for time.Now().UTC().Before(deadline) {
backups, err := backupRepository.FindByDatabaseID(databaseID)
backups, err := backups_core.GetBackupRepository().FindByDatabaseID(databaseID)
if err != nil {
t.Logf("WaitForBackupCompletion: error finding backups: %v", err)
time.Sleep(50 * time.Millisecond)

View File

@@ -0,0 +1,7 @@
package backups_core
var backupRepository = &BackupRepository{}
func GetBackupRepository() *BackupRepository {
return backupRepository
}

View File

@@ -8,3 +8,10 @@ const (
BackupStatusFailed BackupStatus = "FAILED"
BackupStatusCanceled BackupStatus = "CANCELED"
)
type PgWalUploadType string
const (
PgWalUploadTypeBasebackup PgWalUploadType = "basebackup"
PgWalUploadTypeWal PgWalUploadType = "wal"
)

View File

@@ -1,12 +1,22 @@
package backups_core
import (
backups_config "databasus-backend/internal/features/backups/config"
"fmt"
"time"
backups_config "databasus-backend/internal/features/backups/config"
files_utils "databasus-backend/internal/util/files"
"github.com/google/uuid"
)
type PgWalBackupType string
const (
PgWalBackupTypeFullBackup PgWalBackupType = "PG_FULL_BACKUP"
PgWalBackupTypeWalSegment PgWalBackupType = "PG_WAL_SEGMENT"
)
type Backup struct {
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
FileName string `json:"fileName" gorm:"column:file_name;type:text;not null"`
@@ -26,5 +36,23 @@ type Backup struct {
EncryptionIV *string `json:"-" gorm:"column:encryption_iv"`
Encryption backups_config.BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
// Postgres WAL backup specific fields
PgWalBackupType *PgWalBackupType `json:"pgWalBackupType" gorm:"column:pg_wal_backup_type;type:text"`
PgFullBackupWalStartSegmentName *string `json:"pgFullBackupWalStartSegmentName" gorm:"column:pg_wal_start_segment;type:text"`
PgFullBackupWalStopSegmentName *string `json:"pgFullBackupWalStopSegmentName" gorm:"column:pg_wal_stop_segment;type:text"`
PgVersion *string `json:"pgVersion" gorm:"column:pg_version;type:text"`
PgWalSegmentName *string `json:"pgWalSegmentName" gorm:"column:pg_wal_segment_name;type:text"`
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
}
func (b *Backup) GenerateFilename(dbName string) {
timestamp := time.Now().UTC()
b.FileName = fmt.Sprintf(
"%s-%s-%s",
files_utils.SanitizeFilename(dbName),
timestamp.Format("20060102-150405"),
b.ID.String(),
)
}

View File

@@ -245,3 +245,134 @@ func (r *BackupRepository) FindOldestByDatabaseExcludingInProgress(
return backups, nil
}
func (r *BackupRepository) FindCompletedFullWalBackupByID(
databaseID uuid.UUID,
backupID uuid.UUID,
) (*Backup, error) {
var backup Backup
err := storage.
GetDb().
Where(
"database_id = ? AND id = ? AND pg_wal_backup_type = ? AND status = ?",
databaseID,
backupID,
PgWalBackupTypeFullBackup,
BackupStatusCompleted,
).
First(&backup).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &backup, nil
}
func (r *BackupRepository) FindCompletedWalSegmentsAfter(
databaseID uuid.UUID,
afterSegmentName string,
) ([]*Backup, error) {
var backups []*Backup
err := storage.
GetDb().
Where(
"database_id = ? AND pg_wal_backup_type = ? AND pg_wal_segment_name >= ? AND status = ?",
databaseID,
PgWalBackupTypeWalSegment,
afterSegmentName,
BackupStatusCompleted,
).
Order("pg_wal_segment_name ASC").
Find(&backups).Error
if err != nil {
return nil, err
}
return backups, nil
}
func (r *BackupRepository) FindLastCompletedFullWalBackupByDatabaseID(
databaseID uuid.UUID,
) (*Backup, error) {
var backup Backup
err := storage.
GetDb().
Where(
"database_id = ? AND pg_wal_backup_type = ? AND status = ?",
databaseID,
PgWalBackupTypeFullBackup,
BackupStatusCompleted,
).
Order("created_at DESC").
First(&backup).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &backup, nil
}
func (r *BackupRepository) FindWalSegmentByName(
databaseID uuid.UUID,
segmentName string,
) (*Backup, error) {
var backup Backup
err := storage.
GetDb().
Where(
"database_id = ? AND pg_wal_backup_type = ? AND pg_wal_segment_name = ?",
databaseID,
PgWalBackupTypeWalSegment,
segmentName,
).
First(&backup).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &backup, nil
}
func (r *BackupRepository) FindLastWalSegmentAfter(
databaseID uuid.UUID,
afterSegmentName string,
) (*Backup, error) {
var backup Backup
err := storage.
GetDb().
Where(
"database_id = ? AND pg_wal_backup_type = ? AND pg_wal_segment_name > ? AND status = ?",
databaseID,
PgWalBackupTypeWalSegment,
afterSegmentName,
BackupStatusCompleted,
).
Order("pg_wal_segment_name DESC").
First(&backup).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &backup, nil
}

View File

@@ -1,29 +0,0 @@
package backups
import (
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/encryption"
"io"
)
type GetBackupsRequest struct {
DatabaseID string `form:"database_id" binding:"required"`
Limit int `form:"limit"`
Offset int `form:"offset"`
}
type GetBackupsResponse struct {
Backups []*backups_core.Backup `json:"backups"`
Total int64 `json:"total"`
Limit int `json:"limit"`
Offset int `json:"offset"`
}
type DecryptionReaderCloser struct {
*encryption.DecryptionReader
BaseReader io.ReadCloser
}
func (r *DecryptionReaderCloser) Close() error {
return r.BaseReader.Close()
}

View File

@@ -0,0 +1,78 @@
package backups_dto
import (
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/encryption"
"io"
"time"
"github.com/google/uuid"
)
type GetBackupsRequest struct {
DatabaseID string `form:"database_id" binding:"required"`
Limit int `form:"limit"`
Offset int `form:"offset"`
}
type GetBackupsResponse struct {
Backups []*backups_core.Backup `json:"backups"`
Total int64 `json:"total"`
Limit int `json:"limit"`
Offset int `json:"offset"`
}
type DecryptionReaderCloser struct {
*encryption.DecryptionReader
BaseReader io.ReadCloser
}
func (r *DecryptionReaderCloser) Close() error {
return r.BaseReader.Close()
}
type MakeBackupRequest struct {
DatabaseID uuid.UUID `json:"database_id" binding:"required"`
}
type GetNextFullBackupTimeResponse struct {
NextFullBackupTime *time.Time `json:"nextFullBackupTime"`
}
type ReportErrorRequest struct {
Error string `json:"error" binding:"required"`
}
type UploadGapResponse struct {
Error string `json:"error"`
ExpectedSegmentName string `json:"expectedSegmentName"`
ReceivedSegmentName string `json:"receivedSegmentName"`
}
type RestorePlanFullBackup struct {
BackupID uuid.UUID `json:"id"`
FullBackupWalStartSegment string `json:"fullBackupWalStartSegment"`
FullBackupWalStopSegment string `json:"fullBackupWalStopSegment"`
PgVersion string `json:"pgVersion"`
CreatedAt time.Time `json:"createdAt"`
SizeBytes int64 `json:"sizeBytes"`
}
type RestorePlanWalSegment struct {
BackupID uuid.UUID `json:"backupId"`
SegmentName string `json:"segmentName"`
SizeBytes int64 `json:"sizeBytes"`
}
type GetRestorePlanErrorResponse struct {
Error string `json:"error"`
Message string `json:"message"`
LastContiguousSegment string `json:"lastContiguousSegment,omitempty"`
}
type GetRestorePlanResponse struct {
FullBackup RestorePlanFullBackup `json:"fullBackup"`
WalSegments []RestorePlanWalSegment `json:"walSegments"`
TotalSizeBytes int64 `json:"totalSizeBytes"`
LatestAvailableSegment string `json:"latestAvailableSegment"`
}

View File

@@ -0,0 +1,45 @@
package encryption
import (
"encoding/base64"
"fmt"
"io"
"github.com/google/uuid"
)
// EncryptionSetup holds the result of setting up encryption for a backup stream.
type EncryptionSetup struct {
Writer *EncryptionWriter
SaltBase64 string
NonceBase64 string
}
// SetupEncryptionWriter generates salt/nonce, creates an EncryptionWriter, and
// returns the base64-encoded salt and nonce for storage on the backup record.
func SetupEncryptionWriter(
baseWriter io.Writer,
masterKey string,
backupID uuid.UUID,
) (*EncryptionSetup, error) {
salt, err := GenerateSalt()
if err != nil {
return nil, fmt.Errorf("failed to generate salt: %w", err)
}
nonce, err := GenerateNonce()
if err != nil {
return nil, fmt.Errorf("failed to generate nonce: %w", err)
}
encWriter, err := NewEncryptionWriter(baseWriter, masterKey, backupID, salt, nonce)
if err != nil {
return nil, fmt.Errorf("failed to create encryption writer: %w", err)
}
return &EncryptionSetup{
Writer: encWriter,
SaltBase64: base64.StdEncoding.EncodeToString(salt),
NonceBase64: base64.StdEncoding.EncodeToString(nonce),
}, nil
}

View File

@@ -1,9 +1,6 @@
package backups
package backups_services
import (
"sync"
"sync/atomic"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/backuping"
backups_core "databasus-backend/internal/features/backups/backups/core"
@@ -18,16 +15,16 @@ import (
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
"sync"
"sync/atomic"
)
var backupRepository = &backups_core.BackupRepository{}
var taskCancelManager = task_cancellation.GetTaskCancelManager()
var backupService = &BackupService{
databases.GetDatabaseService(),
storages.GetStorageService(),
backupRepository,
backups_core.GetBackupRepository(),
notifiers.GetNotifierService(),
notifiers.GetNotifierService(),
backups_config.GetBackupConfigService(),
@@ -44,16 +41,21 @@ var backupService = &BackupService{
backuping.GetBackupCleaner(),
}
var backupController = &BackupController{
backupService: backupService,
}
func GetBackupService() *BackupService {
return backupService
}
func GetBackupController() *BackupController {
return backupController
var walService = &PostgreWalBackupService{
backups_config.GetBackupConfigService(),
backups_core.GetBackupRepository(),
encryption.GetFieldEncryptor(),
encryption_secrets.GetSecretKeyService(),
logger.GetLogger(),
backupService,
}
func GetWalService() *PostgreWalBackupService {
return walService
}
var (

View File

@@ -0,0 +1,613 @@
package backups_services
import (
"context"
"fmt"
"io"
"log/slog"
"time"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_dto "databasus-backend/internal/features/backups/backups/dto"
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/postgresql"
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
util_encryption "databasus-backend/internal/util/encryption"
util_wal "databasus-backend/internal/util/wal"
"github.com/google/uuid"
)
// PostgreWalBackupService handles WAL segment and basebackup uploads from the databasus-cli agent.
type PostgreWalBackupService struct {
backupConfigService *backups_config.BackupConfigService
backupRepository *backups_core.BackupRepository
fieldEncryptor util_encryption.FieldEncryptor
secretKeyService *encryption_secrets.SecretKeyService
logger *slog.Logger
backupService *BackupService
}
// UploadWal accepts a streaming WAL segment or basebackup upload from the agent.
// For WAL segments it validates the WAL chain before accepting. Returns an UploadGapResponse
// (409) when the chain is broken so the agent knows to trigger a full basebackup.
func (s *PostgreWalBackupService) UploadWal(
ctx context.Context,
database *databases.Database,
uploadType backups_core.PgWalUploadType,
walSegmentName string,
fullBackupWalStartSegment string,
fullBackupWalStopSegment string,
walSegmentSizeBytes int64,
body io.Reader,
) (*backups_dto.UploadGapResponse, error) {
if err := s.validateWalBackupType(database); err != nil {
return nil, err
}
if uploadType == backups_core.PgWalUploadTypeBasebackup {
if fullBackupWalStartSegment == "" || fullBackupWalStopSegment == "" {
return nil, fmt.Errorf(
"fullBackupWalStartSegment and fullBackupWalStopSegment are required for basebackup uploads",
)
}
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
if err != nil {
return nil, fmt.Errorf("failed to get backup config: %w", err)
}
if backupConfig.Storage == nil {
return nil, fmt.Errorf("no storage configured for database %s", database.ID)
}
if uploadType == backups_core.PgWalUploadTypeWal {
// Idempotency: check before chain validation so a successful re-upload is
// not misidentified as a gap.
existing, err := s.backupRepository.FindWalSegmentByName(database.ID, walSegmentName)
if err != nil {
return nil, fmt.Errorf("failed to check for duplicate WAL segment: %w", err)
}
if existing != nil {
return nil, nil
}
gapResp, err := s.validateWalChain(database.ID, walSegmentName, walSegmentSizeBytes)
if err != nil {
return nil, err
}
if gapResp != nil {
return gapResp, nil
}
}
backup := s.createBackupRecord(
database.ID,
backupConfig.Storage.ID,
uploadType,
database.Name,
walSegmentName,
fullBackupWalStartSegment,
fullBackupWalStopSegment,
backupConfig.Encryption,
)
if err := s.backupRepository.Save(backup); err != nil {
return nil, fmt.Errorf("failed to create backup record: %w", err)
}
sizeBytes, streamErr := s.streamToStorage(ctx, backup, backupConfig, body)
if streamErr != nil {
errMsg := streamErr.Error()
s.markFailed(backup, errMsg)
return nil, fmt.Errorf("upload failed: %w", streamErr)
}
s.markCompleted(backup, sizeBytes)
return nil, nil
}
func (s *PostgreWalBackupService) GetRestorePlan(
database *databases.Database,
backupID *uuid.UUID,
) (*backups_dto.GetRestorePlanResponse, *backups_dto.GetRestorePlanErrorResponse, error) {
if err := s.validateWalBackupType(database); err != nil {
return nil, nil, err
}
fullBackup, err := s.resolveFullBackup(database.ID, backupID)
if err != nil {
return nil, nil, err
}
if fullBackup == nil {
msg := "no full backups available for this database"
if backupID != nil {
msg = fmt.Sprintf("full backup %s not found or not completed", backupID)
}
return nil, &backups_dto.GetRestorePlanErrorResponse{
Error: "no_backups",
Message: msg,
}, nil
}
startSegment := ""
if fullBackup.PgFullBackupWalStartSegmentName != nil {
startSegment = *fullBackup.PgFullBackupWalStartSegmentName
}
walSegments, err := s.backupRepository.FindCompletedWalSegmentsAfter(database.ID, startSegment)
if err != nil {
return nil, nil, fmt.Errorf("failed to query WAL segments: %w", err)
}
chainErr := s.validateRestoreWalChain(fullBackup, walSegments)
if chainErr != nil {
return nil, chainErr, nil
}
fullBackupSizeBytes := int64(fullBackup.BackupSizeMb * 1024 * 1024)
pgVersion := ""
if fullBackup.PgVersion != nil {
pgVersion = *fullBackup.PgVersion
}
stopSegment := ""
if fullBackup.PgFullBackupWalStopSegmentName != nil {
stopSegment = *fullBackup.PgFullBackupWalStopSegmentName
}
response := &backups_dto.GetRestorePlanResponse{
FullBackup: backups_dto.RestorePlanFullBackup{
BackupID: fullBackup.ID,
FullBackupWalStartSegment: startSegment,
FullBackupWalStopSegment: stopSegment,
PgVersion: pgVersion,
CreatedAt: fullBackup.CreatedAt,
SizeBytes: fullBackupSizeBytes,
},
TotalSizeBytes: fullBackupSizeBytes,
}
for _, seg := range walSegments {
segName := ""
if seg.PgWalSegmentName != nil {
segName = *seg.PgWalSegmentName
}
segSizeBytes := int64(seg.BackupSizeMb * 1024 * 1024)
response.WalSegments = append(response.WalSegments, backups_dto.RestorePlanWalSegment{
BackupID: seg.ID,
SegmentName: segName,
SizeBytes: segSizeBytes,
})
response.TotalSizeBytes += segSizeBytes
response.LatestAvailableSegment = segName
}
return response, nil, nil
}
// DownloadBackupFile returns a reader for a backup file belonging to the given database.
// Decryption is handled transparently if the backup is encrypted.
func (s *PostgreWalBackupService) DownloadBackupFile(
database *databases.Database,
backupID uuid.UUID,
) (io.ReadCloser, error) {
if err := s.validateWalBackupType(database); err != nil {
return nil, err
}
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
return nil, fmt.Errorf("backup not found: %w", err)
}
if backup.DatabaseID != database.ID {
return nil, fmt.Errorf("backup does not belong to this database")
}
if backup.Status != backups_core.BackupStatusCompleted {
return nil, fmt.Errorf("backup is not completed")
}
return s.backupService.GetBackupReader(backupID)
}
func (s *PostgreWalBackupService) validateWalChain(
databaseID uuid.UUID,
incomingSegment string,
walSegmentSizeBytes int64,
) (*backups_dto.UploadGapResponse, error) {
fullBackup, err := s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(databaseID)
if err != nil {
return nil, fmt.Errorf("failed to query full backup: %w", err)
}
// No full backup exists yet: cannot accept WAL segments without a chain anchor.
if fullBackup == nil || fullBackup.PgFullBackupWalStopSegmentName == nil {
return &backups_dto.UploadGapResponse{
Error: "no_full_backup",
ExpectedSegmentName: "",
ReceivedSegmentName: incomingSegment,
}, nil
}
stopSegment := *fullBackup.PgFullBackupWalStopSegmentName
lastWal, err := s.backupRepository.FindLastWalSegmentAfter(databaseID, stopSegment)
if err != nil {
return nil, fmt.Errorf("failed to query last WAL segment: %w", err)
}
walCalculator := util_wal.NewWalCalculator(walSegmentSizeBytes)
var chainTail string
if lastWal != nil && lastWal.PgWalSegmentName != nil {
chainTail = *lastWal.PgWalSegmentName
} else {
chainTail = stopSegment
}
expectedNext, err := walCalculator.NextSegment(chainTail)
if err != nil {
return nil, fmt.Errorf("WAL arithmetic failed for %q: %w", chainTail, err)
}
if incomingSegment != expectedNext {
return &backups_dto.UploadGapResponse{
Error: "gap_detected",
ExpectedSegmentName: expectedNext,
ReceivedSegmentName: incomingSegment,
}, nil
}
return nil, nil
}
func (s *PostgreWalBackupService) createBackupRecord(
databaseID uuid.UUID,
storageID uuid.UUID,
uploadType backups_core.PgWalUploadType,
dbName string,
walSegmentName string,
fullBackupWalStartSegment string,
fullBackupWalStopSegment string,
encryption backups_config.BackupEncryption,
) *backups_core.Backup {
now := time.Now().UTC()
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: databaseID,
StorageID: storageID,
Status: backups_core.BackupStatusInProgress,
Encryption: encryption,
CreatedAt: now,
}
backup.GenerateFilename(dbName)
if uploadType == backups_core.PgWalUploadTypeBasebackup {
walBackupType := backups_core.PgWalBackupTypeFullBackup
backup.PgWalBackupType = &walBackupType
if fullBackupWalStartSegment != "" {
backup.PgFullBackupWalStartSegmentName = &fullBackupWalStartSegment
}
if fullBackupWalStopSegment != "" {
backup.PgFullBackupWalStopSegmentName = &fullBackupWalStopSegment
}
} else {
walBackupType := backups_core.PgWalBackupTypeWalSegment
backup.PgWalBackupType = &walBackupType
backup.PgWalSegmentName = &walSegmentName
}
return backup
}
func (s *PostgreWalBackupService) streamToStorage(
ctx context.Context,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
body io.Reader,
) (int64, error) {
if backupConfig.Encryption == backups_config.BackupEncryptionEncrypted {
return s.streamEncrypted(ctx, backup, backupConfig, body, backup.FileName)
}
return s.streamDirect(ctx, backupConfig, body, backup.FileName)
}
func (s *PostgreWalBackupService) streamDirect(
ctx context.Context,
backupConfig *backups_config.BackupConfig,
body io.Reader,
fileName string,
) (int64, error) {
cr := &countingReader{r: body}
if err := backupConfig.Storage.SaveFile(ctx, s.fieldEncryptor, s.logger, fileName, cr); err != nil {
return 0, err
}
return cr.n, nil
}
func (s *PostgreWalBackupService) streamEncrypted(
ctx context.Context,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
body io.Reader,
fileName string,
) (int64, error) {
masterKey, err := s.secretKeyService.GetSecretKey()
if err != nil {
return 0, fmt.Errorf("failed to get master encryption key: %w", err)
}
pipeReader, pipeWriter := io.Pipe()
encryptionSetup, err := backup_encryption.SetupEncryptionWriter(
pipeWriter,
masterKey,
backup.ID,
)
if err != nil {
_ = pipeWriter.Close()
return 0, err
}
copyErrCh := make(chan error, 1)
go func() {
_, copyErr := io.Copy(encryptionSetup.Writer, body)
if copyErr != nil {
_ = encryptionSetup.Writer.Close()
_ = pipeWriter.CloseWithError(copyErr)
copyErrCh <- copyErr
return
}
if closeErr := encryptionSetup.Writer.Close(); closeErr != nil {
_ = pipeWriter.CloseWithError(closeErr)
copyErrCh <- closeErr
return
}
copyErrCh <- pipeWriter.Close()
}()
cr := &countingReader{r: pipeReader}
saveErr := backupConfig.Storage.SaveFile(ctx, s.fieldEncryptor, s.logger, fileName, cr)
copyErr := <-copyErrCh
if copyErr != nil {
return 0, copyErr
}
if saveErr != nil {
return 0, saveErr
}
backup.EncryptionSalt = &encryptionSetup.SaltBase64
backup.EncryptionIV = &encryptionSetup.NonceBase64
return cr.n, nil
}
func (s *PostgreWalBackupService) markCompleted(backup *backups_core.Backup, sizeBytes int64) {
backup.Status = backups_core.BackupStatusCompleted
backup.BackupSizeMb = float64(sizeBytes) / (1024 * 1024)
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error(
"failed to mark WAL backup as completed",
"backupId",
backup.ID,
"error",
err,
)
}
}
func (s *PostgreWalBackupService) markFailed(backup *backups_core.Backup, errMsg string) {
backup.Status = backups_core.BackupStatusFailed
backup.FailMessage = &errMsg
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("failed to mark WAL backup as failed", "backupId", backup.ID, "error", err)
}
}
func (s *PostgreWalBackupService) GetNextFullBackupTime(
database *databases.Database,
) (*backups_dto.GetNextFullBackupTimeResponse, error) {
if err := s.validateWalBackupType(database); err != nil {
return nil, err
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
if err != nil {
return nil, fmt.Errorf("failed to get backup config: %w", err)
}
if backupConfig.BackupInterval == nil {
return nil, fmt.Errorf("no backup interval configured for database %s", database.ID)
}
lastFullBackup, err := s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(
database.ID,
)
if err != nil {
return nil, fmt.Errorf("failed to query last full backup: %w", err)
}
var lastBackupTime *time.Time
if lastFullBackup != nil {
lastBackupTime = &lastFullBackup.CreatedAt
}
now := time.Now().UTC()
nextTime := backupConfig.BackupInterval.NextTriggerTime(now, lastBackupTime)
return &backups_dto.GetNextFullBackupTimeResponse{
NextFullBackupTime: nextTime,
}, nil
}
// ReportError creates a FAILED backup record with the agent's error message.
func (s *PostgreWalBackupService) ReportError(
database *databases.Database,
errorMsg string,
) error {
if err := s.validateWalBackupType(database); err != nil {
return err
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
if err != nil {
return fmt.Errorf("failed to get backup config: %w", err)
}
if backupConfig.Storage == nil {
return fmt.Errorf("no storage configured for database %s", database.ID)
}
now := time.Now().UTC()
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: backupConfig.Storage.ID,
Status: backups_core.BackupStatusFailed,
FailMessage: &errorMsg,
Encryption: backupConfig.Encryption,
CreatedAt: now,
}
backup.GenerateFilename(database.Name)
if err := s.backupRepository.Save(backup); err != nil {
return fmt.Errorf("failed to save error backup record: %w", err)
}
return nil
}
func (s *PostgreWalBackupService) resolveFullBackup(
databaseID uuid.UUID,
backupID *uuid.UUID,
) (*backups_core.Backup, error) {
if backupID != nil {
return s.backupRepository.FindCompletedFullWalBackupByID(databaseID, *backupID)
}
return s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(databaseID)
}
func (s *PostgreWalBackupService) validateRestoreWalChain(
fullBackup *backups_core.Backup,
walSegments []*backups_core.Backup,
) *backups_dto.GetRestorePlanErrorResponse {
if len(walSegments) == 0 {
return nil
}
stopSegment := ""
if fullBackup.PgFullBackupWalStopSegmentName != nil {
stopSegment = *fullBackup.PgFullBackupWalStopSegmentName
}
walCalculator := util_wal.NewWalCalculator(0)
expectedNext, err := walCalculator.NextSegment(stopSegment)
if err != nil {
return nil
}
for _, seg := range walSegments {
segName := ""
if seg.PgWalSegmentName != nil {
segName = *seg.PgWalSegmentName
}
cmp, cmpErr := walCalculator.Compare(segName, stopSegment)
if cmpErr != nil {
return nil
}
// Skip segments that are <= stopSegment (they are part of the basebackup range)
if cmp <= 0 {
continue
}
if segName != expectedNext {
lastContiguous := stopSegment
// Walk back to find the segment before the gap
for _, prev := range walSegments {
prevName := ""
if prev.PgWalSegmentName != nil {
prevName = *prev.PgWalSegmentName
}
prevCmp, _ := walCalculator.Compare(prevName, stopSegment)
if prevCmp <= 0 {
continue
}
if prevName == segName {
break
}
lastContiguous = prevName
}
return &backups_dto.GetRestorePlanErrorResponse{
Error: "wal_chain_broken",
Message: fmt.Sprintf(
"WAL chain has a gap after segment %s. Recovery is only possible up to this segment.",
lastContiguous,
),
LastContiguousSegment: lastContiguous,
}
}
expectedNext, err = walCalculator.NextSegment(segName)
if err != nil {
return nil
}
}
return nil
}
func (s *PostgreWalBackupService) validateWalBackupType(database *databases.Database) error {
if database.Postgresql == nil ||
database.Postgresql.BackupType != postgresql.PostgresBackupTypeWalV1 {
return fmt.Errorf("database %s is not configured for WAL backups", database.ID)
}
return nil
}
type countingReader struct {
r io.Reader
n int64
}
func (cr *countingReader) Read(p []byte) (n int, err error) {
n, err = cr.r.Read(p)
cr.n += int64(n)
return
}

View File

@@ -1,4 +1,4 @@
package backups
package backups_services
import (
"encoding/base64"
@@ -11,6 +11,7 @@ import (
"databasus-backend/internal/features/backups/backups/backuping"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
backups_dto "databasus-backend/internal/features/backups/backups/dto"
"databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -108,7 +109,7 @@ func (s *BackupService) GetBackups(
user *users_models.User,
databaseID uuid.UUID,
limit, offset int,
) (*GetBackupsResponse, error) {
) (*backups_dto.GetBackupsResponse, error) {
database, err := s.databaseService.GetDatabaseByID(databaseID)
if err != nil {
return nil, err
@@ -143,7 +144,7 @@ func (s *BackupService) GetBackups(
return nil, err
}
return &GetBackupsResponse{
return &backups_dto.GetBackupsResponse{
Backups: backups,
Total: total,
Limit: limit,
@@ -274,7 +275,7 @@ func (s *BackupService) GetBackupFile(
database.WorkspaceID,
)
reader, err := s.getBackupReader(backupID)
reader, err := s.GetBackupReader(backupID)
if err != nil {
return nil, nil, nil, err
}
@@ -282,39 +283,9 @@ func (s *BackupService) GetBackupFile(
return reader, backup, database, nil
}
func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
dbBackupsInProgress, err := s.backupRepository.FindByDatabaseIdAndStatus(
databaseID,
backups_core.BackupStatusInProgress,
)
if err != nil {
return err
}
if len(dbBackupsInProgress) > 0 {
return errors.New("backup is in progress, storage cannot be removed")
}
dbBackups, err := s.backupRepository.FindByDatabaseID(
databaseID,
)
if err != nil {
return err
}
for _, dbBackup := range dbBackups {
err := s.backupCleaner.DeleteBackup(dbBackup)
if err != nil {
return err
}
}
return nil
}
// GetBackupReader returns a reader for the backup file
// If encrypted, wraps with DecryptionReader
func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, error) {
// GetBackupReader returns a reader for the backup file.
// If encrypted, wraps with DecryptionReader.
func (s *BackupService) GetBackupReader(backupID uuid.UUID) (io.ReadCloser, error) {
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
return nil, fmt.Errorf("failed to find backup: %w", err)
@@ -394,7 +365,7 @@ func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, erro
s.logger.Info("Returning encrypted backup with decryption", "backupId", backupID)
return &DecryptionReaderCloser{
return &backups_dto.DecryptionReaderCloser{
DecryptionReader: decryptionReader,
BaseReader: fileReader,
}, nil
@@ -465,7 +436,7 @@ func (s *BackupService) GetBackupFileWithoutAuth(
return nil, nil, nil, err
}
reader, err := s.getBackupReader(backupID)
reader, err := s.GetBackupReader(backupID)
if err != nil {
return nil, nil, nil, err
}
@@ -501,6 +472,36 @@ func (s *BackupService) UnregisterDownload(userID uuid.UUID) {
s.downloadTokenService.UnregisterDownload(userID)
}
func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
dbBackupsInProgress, err := s.backupRepository.FindByDatabaseIdAndStatus(
databaseID,
backups_core.BackupStatusInProgress,
)
if err != nil {
return err
}
if len(dbBackupsInProgress) > 0 {
return errors.New("backup is in progress, storage cannot be removed")
}
dbBackups, err := s.backupRepository.FindByDatabaseID(
databaseID,
)
if err != nil {
return err
}
for _, dbBackup := range dbBackups {
err := s.backupCleaner.DeleteBackup(dbBackup)
if err != nil {
return err
}
}
return nil
}
func (s *BackupService) generateBackupFilename(
backup *backups_core.Backup,
database *databases.Database,

View File

@@ -2,7 +2,6 @@ package usecases_mariadb
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
@@ -123,6 +122,10 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
args = append(args, "--compress")
if !config.GetEnv().IsCloud {
args = append(args, "--max-allowed-packet=1G")
}
if mdb.IsHttps {
args = append(args, "--ssl")
args = append(args, "--skip-ssl-verify-server-cert")
@@ -437,40 +440,22 @@ func (uc *CreateMariadbBackupUsecase) setupBackupEncryption(
return storageWriter, nil, metadata, nil
}
salt, err := backup_encryption.GenerateSalt()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to generate salt: %w", err)
}
nonce, err := backup_encryption.GenerateNonce()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to generate nonce: %w", err)
}
masterKey, err := uc.secretKeyService.GetSecretKey()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to get master key: %w", err)
}
encWriter, err := backup_encryption.NewEncryptionWriter(
storageWriter,
masterKey,
backupID,
salt,
nonce,
)
encSetup, err := backup_encryption.SetupEncryptionWriter(storageWriter, masterKey, backupID)
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to create encrypting writer: %w", err)
return nil, nil, metadata, err
}
saltBase64 := base64.StdEncoding.EncodeToString(salt)
nonceBase64 := base64.StdEncoding.EncodeToString(nonce)
metadata.EncryptionSalt = &saltBase64
metadata.EncryptionIV = &nonceBase64
metadata.EncryptionSalt = &encSetup.SaltBase64
metadata.EncryptionIV = &encSetup.NonceBase64
metadata.Encryption = backups_config.BackupEncryptionEncrypted
uc.logger.Info("Encryption enabled for backup", "backupId", backupID)
return encWriter, encWriter, metadata, nil
return encSetup.Writer, encSetup.Writer, metadata, nil
}
func (uc *CreateMariadbBackupUsecase) cleanupOnCancellation(

View File

@@ -2,7 +2,6 @@ package usecases_mongodb
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
@@ -277,41 +276,21 @@ func (uc *CreateMongodbBackupUsecase) setupBackupEncryption(
return storageWriter, nil, backupMetadata, nil
}
salt, err := backup_encryption.GenerateSalt()
if err != nil {
return nil, nil, backupMetadata, fmt.Errorf("failed to generate salt: %w", err)
}
nonce, err := backup_encryption.GenerateNonce()
if err != nil {
return nil, nil, backupMetadata, fmt.Errorf("failed to generate nonce: %w", err)
}
masterKey, err := uc.secretKeyService.GetSecretKey()
if err != nil {
return nil, nil, backupMetadata, fmt.Errorf("failed to get master key: %w", err)
}
encryptionWriter, err := backup_encryption.NewEncryptionWriter(
storageWriter,
masterKey,
backupID,
salt,
nonce,
)
encSetup, err := backup_encryption.SetupEncryptionWriter(storageWriter, masterKey, backupID)
if err != nil {
return nil, nil, backupMetadata, fmt.Errorf("failed to create encryption writer: %w", err)
return nil, nil, backupMetadata, err
}
saltBase64 := base64.StdEncoding.EncodeToString(salt)
nonceBase64 := base64.StdEncoding.EncodeToString(nonce)
backupMetadata.BackupID = backupID
backupMetadata.Encryption = backups_config.BackupEncryptionEncrypted
backupMetadata.EncryptionSalt = &saltBase64
backupMetadata.EncryptionIV = &nonceBase64
backupMetadata.EncryptionSalt = &encSetup.SaltBase64
backupMetadata.EncryptionIV = &encSetup.NonceBase64
return encryptionWriter, encryptionWriter, backupMetadata, nil
return encSetup.Writer, encSetup.Writer, backupMetadata, nil
}
func (uc *CreateMongodbBackupUsecase) copyWithShutdownCheck(

View File

@@ -2,7 +2,6 @@ package usecases_mysql
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
@@ -119,7 +118,11 @@ func (uc *CreateMysqlBackupUsecase) buildMysqldumpArgs(my *mysqltypes.MysqlDatab
args = append(args, "--events")
}
args = append(args, uc.getNetworkCompressionArgs(my.Version)...)
args = append(args, uc.getNetworkCompressionArgs(my)...)
if !config.GetEnv().IsCloud {
args = append(args, "--max-allowed-packet=1G")
}
if my.IsHttps {
args = append(args, "--ssl-mode=REQUIRED")
@@ -132,15 +135,21 @@ func (uc *CreateMysqlBackupUsecase) buildMysqldumpArgs(my *mysqltypes.MysqlDatab
return args
}
func (uc *CreateMysqlBackupUsecase) getNetworkCompressionArgs(version tools.MysqlVersion) []string {
func (uc *CreateMysqlBackupUsecase) getNetworkCompressionArgs(
my *mysqltypes.MysqlDatabase,
) []string {
const zstdCompressionLevel = 5
switch version {
switch my.Version {
case tools.MysqlVersion80, tools.MysqlVersion84, tools.MysqlVersion9:
return []string{
"--compression-algorithms=zstd",
fmt.Sprintf("--zstd-compression-level=%d", zstdCompressionLevel),
if my.IsZstdSupported {
return []string{
"--compression-algorithms=zstd",
fmt.Sprintf("--zstd-compression-level=%d", zstdCompressionLevel),
}
}
return []string{"--compress"}
case tools.MysqlVersion57:
return []string{"--compress"}
default:
@@ -448,40 +457,22 @@ func (uc *CreateMysqlBackupUsecase) setupBackupEncryption(
return storageWriter, nil, metadata, nil
}
salt, err := backup_encryption.GenerateSalt()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to generate salt: %w", err)
}
nonce, err := backup_encryption.GenerateNonce()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to generate nonce: %w", err)
}
masterKey, err := uc.secretKeyService.GetSecretKey()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to get master key: %w", err)
}
encWriter, err := backup_encryption.NewEncryptionWriter(
storageWriter,
masterKey,
backupID,
salt,
nonce,
)
encSetup, err := backup_encryption.SetupEncryptionWriter(storageWriter, masterKey, backupID)
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to create encrypting writer: %w", err)
return nil, nil, metadata, err
}
saltBase64 := base64.StdEncoding.EncodeToString(salt)
nonceBase64 := base64.StdEncoding.EncodeToString(nonce)
metadata.EncryptionSalt = &saltBase64
metadata.EncryptionIV = &nonceBase64
metadata.EncryptionSalt = &encSetup.SaltBase64
metadata.EncryptionIV = &encSetup.NonceBase64
metadata.Encryption = backups_config.BackupEncryptionEncrypted
uc.logger.Info("Encryption enabled for backup", "backupId", backupID)
return encWriter, encWriter, metadata, nil
return encSetup.Writer, encSetup.Writer, metadata, nil
}
func (uc *CreateMysqlBackupUsecase) cleanupOnCancellation(
@@ -604,6 +595,15 @@ func (uc *CreateMysqlBackupUsecase) handleConnectionErrors(stderrStr string) err
)
}
if containsIgnoreCase(stderrStr, "compression algorithm") ||
containsIgnoreCase(stderrStr, "2066") {
return fmt.Errorf(
"MySQL connection failed due to unsupported compression algorithm. "+
"Try re-saving the database connection to re-detect compression support. stderr: %s",
stderrStr,
)
}
if containsIgnoreCase(stderrStr, "unknown database") {
return fmt.Errorf(
"MySQL database does not exist. stderr: %s",

View File

@@ -2,7 +2,6 @@ package usecases_postgresql
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
@@ -492,40 +491,22 @@ func (uc *CreatePostgresqlBackupUsecase) setupBackupEncryption(
return storageWriter, nil, metadata, nil
}
salt, err := backup_encryption.GenerateSalt()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to generate salt: %w", err)
}
nonce, err := backup_encryption.GenerateNonce()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to generate nonce: %w", err)
}
masterKey, err := uc.secretKeyService.GetSecretKey()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to get master key: %w", err)
}
encWriter, err := backup_encryption.NewEncryptionWriter(
storageWriter,
masterKey,
backupID,
salt,
nonce,
)
encSetup, err := backup_encryption.SetupEncryptionWriter(storageWriter, masterKey, backupID)
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to create encrypting writer: %w", err)
return nil, nil, metadata, err
}
saltBase64 := base64.StdEncoding.EncodeToString(salt)
nonceBase64 := base64.StdEncoding.EncodeToString(nonce)
metadata.EncryptionSalt = &saltBase64
metadata.EncryptionIV = &nonceBase64
metadata.EncryptionSalt = &encSetup.SaltBase64
metadata.EncryptionIV = &encSetup.NonceBase64
metadata.Encryption = backups_config.BackupEncryptionEncrypted
uc.logger.Info("Encryption enabled for backup", "backupId", backupID)
return encWriter, encWriter, metadata, nil
return encSetup.Writer, encSetup.Writer, metadata, nil
}
func (uc *CreatePostgresqlBackupUsecase) cleanupOnCancellation(

View File

@@ -29,6 +29,11 @@ func (c *DatabaseController) RegisterRoutes(router *gin.RouterGroup) {
router.GET("/databases/notifier/:id/databases-count", c.CountDatabasesByNotifier)
router.POST("/databases/is-readonly", c.IsUserReadOnly)
router.POST("/databases/create-readonly-user", c.CreateReadOnlyUser)
router.POST("/databases/:id/regenerate-token", c.RegenerateAgentToken)
}
func (c *DatabaseController) RegisterPublicRoutes(router *gin.RouterGroup) {
router.POST("/databases/verify-token", c.VerifyAgentToken)
}
// CreateDatabase
@@ -438,3 +443,61 @@ func (c *DatabaseController) CreateReadOnlyUser(ctx *gin.Context) {
Password: password,
})
}
// RegenerateAgentToken
// @Summary Regenerate agent token for a database
// @Description Generate a new agent token for the database. The token is returned once and stored as a hash.
// @Tags databases
// @Produce json
// @Security BearerAuth
// @Param id path string true "Database ID"
// @Success 200 {object} map[string]string
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Router /databases/{id}/regenerate-token [post]
func (c *DatabaseController) RegenerateAgentToken(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
id, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
return
}
token, err := c.databaseService.RegenerateAgentToken(user, id)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, gin.H{"token": token})
}
// VerifyAgentToken
// @Summary Verify agent token
// @Description Verify that a given agent token is valid for any database
// @Tags databases
// @Accept json
// @Produce json
// @Param request body VerifyAgentTokenRequest true "Token to verify"
// @Success 200 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Router /databases/verify-token [post]
func (c *DatabaseController) VerifyAgentToken(ctx *gin.Context) {
var request VerifyAgentTokenRequest
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := c.databaseService.VerifyAgentToken(request.Token); err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
ctx.JSON(http.StatusOK, gin.H{"message": "token is valid"})
}

View File

@@ -13,10 +13,13 @@ import (
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
"databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/databases/databases/mariadb"
"databasus-backend/internal/features/databases/databases/mongodb"
"databasus-backend/internal/features/databases/databases/postgresql"
users_enums "databasus-backend/internal/features/users/enums"
users_middleware "databasus-backend/internal/features/users/middleware"
users_services "databasus-backend/internal/features/users/services"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
@@ -144,6 +147,66 @@ func Test_CreateDatabase_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testin
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
func Test_CreateDatabase_WalV1Type_NoConnectionFieldsRequired(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
request := Database{
Name: "Test WAL Database",
WorkspaceID: &workspace.ID,
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
BackupType: postgresql.PostgresBackupTypeWalV1,
CpuCount: 1,
},
}
var response Database
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/databases/create",
"Bearer "+owner.Token,
request,
http.StatusCreated,
&response,
)
defer RemoveTestDatabase(&response)
assert.Equal(t, "Test WAL Database", response.Name)
assert.NotEqual(t, uuid.Nil, response.ID)
}
func Test_CreateDatabase_PgDumpType_ConnectionFieldsRequired(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
request := Database{
Name: "Test PG_DUMP Database",
WorkspaceID: &workspace.ID,
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
BackupType: postgresql.PostgresBackupTypePgDump,
CpuCount: 1,
},
}
testResp := test_utils.MakePostRequest(
t,
router,
"/api/v1/databases/create",
"Bearer "+owner.Token,
request,
http.StatusBadRequest,
)
assert.Contains(t, string(testResp.Body), "host is required")
}
func Test_UpdateDatabase_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
@@ -256,6 +319,52 @@ func Test_UpdateDatabase_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testin
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
func Test_UpdateDatabase_WhenDatabaseTypeChanged_ReturnsBadRequest(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(database)
database.Type = DatabaseTypeMysql
testResp := test_utils.MakePostRequest(
t,
router,
"/api/v1/databases/update",
"Bearer "+owner.Token,
database,
http.StatusBadRequest,
)
assert.Contains(t, string(testResp.Body), "database type cannot be changed")
}
func Test_UpdateDatabase_WhenBackupTypeChanged_ReturnsBadRequest(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(database)
database.Postgresql.BackupType = postgresql.PostgresBackupTypeWalV1
testResp := test_utils.MakePostRequest(
t,
router,
"/api/v1/databases/update",
"Bearer "+owner.Token,
database,
http.StatusBadRequest,
)
assert.Contains(t, string(testResp.Body), "backup type cannot be changed")
}
func Test_DeleteDatabase_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
@@ -1050,6 +1159,87 @@ func Test_TestConnection_PermissionsEnforced(t *testing.T) {
}
}
func Test_RegenerateAgentToken_ReturnsToken(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(database)
var response map[string]string
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/databases/"+database.ID.String()+"/regenerate-token",
"Bearer "+owner.Token,
nil,
http.StatusOK,
&response,
)
assert.NotEmpty(t, response["token"])
assert.Len(t, response["token"], 32)
var updatedDatabase Database
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/databases/"+database.ID.String(),
"Bearer "+owner.Token,
http.StatusOK,
&updatedDatabase,
)
assert.True(t, updatedDatabase.IsAgentTokenGenerated)
}
func Test_VerifyAgentToken_WithValidToken_Succeeds(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(database)
var regenerateResponse map[string]string
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/databases/"+database.ID.String()+"/regenerate-token",
"Bearer "+owner.Token,
nil,
http.StatusOK,
&regenerateResponse,
)
token := regenerateResponse["token"]
assert.NotEmpty(t, token)
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/verify-token",
"",
VerifyAgentTokenRequest{Token: token},
)
assert.Equal(t, http.StatusOK, w.Code)
}
func Test_VerifyAgentToken_WithInvalidToken_Returns401(t *testing.T) {
router := createTestRouter()
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/verify-token",
"",
VerifyAgentTokenRequest{Token: "invalidtoken00000000000000000000"},
)
assert.Equal(t, http.StatusUnauthorized, w.Code)
}
func createTestDatabaseViaAPI(
name string,
workspaceID uuid.UUID,
@@ -1101,11 +1291,20 @@ func createTestDatabaseViaAPI(
}
func createTestRouter() *gin.Engine {
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
GetDatabaseController(),
)
gin.SetMode(gin.TestMode)
router := gin.New()
v1 := router.Group("/api/v1")
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
workspaces_controllers.GetWorkspaceController().RegisterRoutes(protected.(*gin.RouterGroup))
workspaces_controllers.GetMembershipController().RegisterRoutes(protected.(*gin.RouterGroup))
GetDatabaseController().RegisterRoutes(protected.(*gin.RouterGroup))
GetDatabaseController().RegisterPublicRoutes(v1)
audit_logs.SetupDependencies()
return router
}
@@ -1118,13 +1317,14 @@ func getTestPostgresConfig() *postgresql.PostgresqlDatabase {
testDbName := "testdb"
return &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
BackupType: postgresql.PostgresBackupTypePgDump,
Version: tools.PostgresqlVersion16,
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
}
}

View File

@@ -25,13 +25,14 @@ type MysqlDatabase struct {
Version tools.MysqlVersion `json:"version" gorm:"type:text;not null"`
Host string `json:"host" gorm:"type:text;not null"`
Port int `json:"port" gorm:"type:int;not null"`
Username string `json:"username" gorm:"type:text;not null"`
Password string `json:"password" gorm:"type:text;not null"`
Database *string `json:"database" gorm:"type:text"`
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
Host string `json:"host" gorm:"type:text;not null"`
Port int `json:"port" gorm:"type:int;not null"`
Username string `json:"username" gorm:"type:text;not null"`
Password string `json:"password" gorm:"type:text;not null"`
Database *string `json:"database" gorm:"type:text"`
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
IsZstdSupported bool `json:"isZstdSupported" gorm:"column:is_zstd_supported;type:boolean;not null;default:true"`
}
func (m *MysqlDatabase) TableName() string {
@@ -102,6 +103,7 @@ func (m *MysqlDatabase) TestConnection(
return err
}
m.Privileges = privileges
m.IsZstdSupported = detectZstdSupport(ctx, db)
if err := checkBackupPermissions(m.Privileges); err != nil {
return err
@@ -125,6 +127,7 @@ func (m *MysqlDatabase) Update(incoming *MysqlDatabase) {
m.Database = incoming.Database
m.IsHttps = incoming.IsHttps
m.Privileges = incoming.Privileges
m.IsZstdSupported = incoming.IsZstdSupported
if incoming.Password != "" {
m.Password = incoming.Password
@@ -185,6 +188,7 @@ func (m *MysqlDatabase) PopulateDbData(
return err
}
m.Privileges = privileges
m.IsZstdSupported = detectZstdSupport(ctx, db)
return nil
}
@@ -223,6 +227,7 @@ func (m *MysqlDatabase) PopulateVersion(
return err
}
m.Version = detectedVersion
m.IsZstdSupported = detectZstdSupport(ctx, db)
return nil
}
@@ -575,6 +580,22 @@ func checkBackupPermissions(privileges string) error {
return nil
}
// detectZstdSupport checks if the MySQL server supports zstd network compression.
// The protocol_compression_algorithms variable was introduced in MySQL 8.0.18.
// Managed MySQL providers (e.g. PlanetScale) may not support zstd even on 8.0+.
func detectZstdSupport(ctx context.Context, db *sql.DB) bool {
var varName, value string
err := db.QueryRowContext(ctx,
"SHOW VARIABLES LIKE 'protocol_compression_algorithms'",
).Scan(&varName, &value)
if err != nil {
return false
}
return strings.Contains(strings.ToLower(value), "zstd")
}
func decryptPasswordIfNeeded(
password string,
encryptor encryption.FieldEncryptor,

View File

@@ -177,6 +177,38 @@ func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
}
}
func Test_TestConnection_DetectsZstdSupport(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MysqlVersion
port string
isExpectZstd bool
}{
{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port, false},
{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port, true},
{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port, true},
{"MySQL 9", tools.MysqlVersion9, env.TestMysql90Port, true},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMysqlContainer(t, tc.port, tc.version)
defer container.DB.Close()
mysqlModel := createMysqlModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err := mysqlModel.TestConnection(logger, nil, uuid.New())
assert.NoError(t, err)
assert.Equal(t, tc.isExpectZstd, mysqlModel.IsZstdSupported,
"IsZstdSupported mismatch for %s", tc.name)
})
}
}
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
env := config.GetEnv()
cases := []struct {

View File

@@ -2,6 +2,7 @@ package postgresql
import (
"context"
"databasus-backend/internal/config"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/tools"
"errors"
@@ -17,6 +18,13 @@ import (
"gorm.io/gorm"
)
type PostgresBackupType string
const (
PostgresBackupTypePgDump PostgresBackupType = "PG_DUMP"
PostgresBackupTypeWalV1 PostgresBackupType = "WAL_V1"
)
type PostgresqlDatabase struct {
ID uuid.UUID `json:"id" gorm:"primaryKey;type:uuid;default:gen_random_uuid()"`
@@ -24,11 +32,13 @@ type PostgresqlDatabase struct {
Version tools.PostgresqlVersion `json:"version" gorm:"type:text;not null"`
// connection data
Host string `json:"host" gorm:"type:text;not null"`
Port int `json:"port" gorm:"type:int;not null"`
Username string `json:"username" gorm:"type:text;not null"`
Password string `json:"password" gorm:"type:text;not null"`
BackupType PostgresBackupType `json:"backupType" gorm:"column:backup_type;type:text;not null;default:'PG_DUMP'"`
// connection data — required for PG_DUMP, optional for WAL_V1
Host string `json:"host" gorm:"type:text"`
Port int `json:"port" gorm:"type:int"`
Username string `json:"username" gorm:"type:text"`
Password string `json:"password" gorm:"type:text"`
Database *string `json:"database" gorm:"type:text"`
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
@@ -66,20 +76,30 @@ func (p *PostgresqlDatabase) AfterFind(_ *gorm.DB) error {
}
func (p *PostgresqlDatabase) Validate() error {
if p.Host == "" {
return errors.New("host is required")
if p.BackupType == "" {
p.BackupType = PostgresBackupTypePgDump
}
if p.Port == 0 {
return errors.New("port is required")
if p.BackupType == PostgresBackupTypePgDump && config.GetEnv().IsCloud {
return errors.New("PG_DUMP backup type is not supported in cloud mode")
}
if p.Username == "" {
return errors.New("username is required")
}
if p.BackupType == PostgresBackupTypePgDump {
if p.Host == "" {
return errors.New("host is required")
}
if p.Password == "" {
return errors.New("password is required")
if p.Port == 0 {
return errors.New("port is required")
}
if p.Username == "" {
return errors.New("username is required")
}
if p.Password == "" {
return errors.New("password is required")
}
}
if p.CpuCount <= 0 {
@@ -90,7 +110,7 @@ func (p *PostgresqlDatabase) Validate() error {
// Databasus runs an internal PostgreSQL instance that should not be backed up through the UI
// because it would expose internal metadata to non-system administrators.
// To properly backup Databasus, see: https://databasus.com/faq#backup-databasus
if p.Database != nil && *p.Database != "" {
if p.BackupType == PostgresBackupTypePgDump && p.Database != nil && *p.Database != "" {
localhostHosts := []string{
"localhost",
"127.0.0.1",
@@ -130,6 +150,10 @@ func (p *PostgresqlDatabase) TestConnection(
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error {
if p.BackupType == PostgresBackupTypeWalV1 {
return errors.New("test connection is not supported for WAL backup type")
}
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
@@ -144,7 +168,21 @@ func (p *PostgresqlDatabase) HideSensitiveData() {
p.Password = ""
}
func (p *PostgresqlDatabase) ValidateUpdate(old *PostgresqlDatabase) error {
// BackupType cannot be changed after creation — the full backup structure
// (WAL hierarchy, storage files, cleanup logic) is built around
// the type chosen at creation time. Automatically migrating this state is
// error-prone; it is safer for the user to create a new database and
// remove the old one.
if old.BackupType != p.BackupType {
return errors.New("backup type cannot be changed; create a new database instead")
}
return nil
}
func (p *PostgresqlDatabase) Update(incoming *PostgresqlDatabase) {
p.BackupType = incoming.BackupType
p.Version = incoming.Version
p.Host = incoming.Host
p.Port = incoming.Port
@@ -181,6 +219,10 @@ func (p *PostgresqlDatabase) PopulateDbData(
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error {
if p.BackupType == PostgresBackupTypeWalV1 {
return nil
}
return p.PopulateVersion(logger, encryptor, databaseID)
}
@@ -243,6 +285,10 @@ func (p *PostgresqlDatabase) IsUserReadOnly(
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) (bool, []string, error) {
if p.BackupType == PostgresBackupTypeWalV1 {
return false, nil, errors.New("read-only check is not supported for WAL backup type")
}
password, err := decryptPasswordIfNeeded(p.Password, encryptor, databaseID)
if err != nil {
return false, nil, fmt.Errorf("failed to decrypt password: %w", err)
@@ -415,6 +461,10 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) (string, string, error) {
if p.BackupType == PostgresBackupTypeWalV1 {
return "", "", errors.New("read-only user creation is not supported for WAL backup type")
}
password, err := decryptPasswordIfNeeded(p.Password, encryptor, databaseID)
if err != nil {
return "", "", fmt.Errorf("failed to decrypt password: %w", err)

View File

@@ -9,3 +9,7 @@ type IsReadOnlyResponse struct {
IsReadOnly bool `json:"isReadOnly"`
Privileges []string `json:"privileges"`
}
type VerifyAgentTokenRequest struct {
Token string `json:"token" binding:"required"`
}

View File

@@ -37,6 +37,9 @@ type Database struct {
LastBackupErrorMessage *string `json:"lastBackupErrorMessage,omitempty" gorm:"column:last_backup_error_message;type:text"`
HealthStatus *HealthStatus `json:"healthStatus" gorm:"column:health_status;type:text;not null"`
AgentToken *string `json:"-" gorm:"column:agent_token;type:text"`
IsAgentTokenGenerated bool `json:"isAgentTokenGenerated" gorm:"column:is_agent_token_generated;not null;default:false"`
}
func (d *Database) Validate() error {
@@ -71,8 +74,19 @@ func (d *Database) Validate() error {
}
func (d *Database) ValidateUpdate(old, new Database) error {
// Database type cannot be changed after creation — the entire backup
// structure (storage files, schedulers, WAL hierarchy, etc.) is tied to
// the type at creation time. Recreating that state automatically is
// error-prone; it is safer for the user to create a new database and
// remove the old one.
if old.Type != new.Type {
return errors.New("database type is not allowed to change")
return errors.New("database type cannot be changed; create a new database instead")
}
if old.Type == DatabaseTypePostgres && old.Postgresql != nil && new.Postgresql != nil {
if err := new.Postgresql.ValidateUpdate(old.Postgresql); err != nil {
return err
}
}
return nil

View File

@@ -244,6 +244,18 @@ func (r *DatabaseRepository) GetAllDatabases() ([]*Database, error) {
return databases, nil
}
func (r *DatabaseRepository) FindByAgentTokenHash(hash string) (*Database, error) {
var database Database
if err := storage.GetDb().
Where("agent_token = ?", hash).
First(&database).Error; err != nil {
return nil, err
}
return &database, nil
}
func (r *DatabaseRepository) GetDatabasesIDsByNotifierID(
notifierID uuid.UUID,
) ([]uuid.UUID, error) {

View File

@@ -2,9 +2,11 @@ package databases
import (
"context"
"crypto/sha256"
"errors"
"fmt"
"log/slog"
"strings"
"time"
"databasus-backend/internal/config"
@@ -87,21 +89,8 @@ func (s *DatabaseService) CreateDatabase(
return nil, fmt.Errorf("failed to auto-detect database data: %w", err)
}
if config.GetEnv().IsCloud {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
isReadOnly, permissions, err := database.IsUserReadOnly(ctx, s.logger, s.fieldEncryptor)
if err != nil {
return nil, fmt.Errorf("failed to verify user permissions: %w", err)
}
if !isReadOnly {
return nil, fmt.Errorf(
"in cloud mode, only read-only database users are allowed (user has permissions: %v)",
permissions,
)
}
if err := s.verifyReadOnlyUserIfNeeded(database); err != nil {
return nil, err
}
if err := database.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
@@ -171,25 +160,8 @@ func (s *DatabaseService) UpdateDatabase(
return fmt.Errorf("failed to auto-detect database data: %w", err)
}
if config.GetEnv().IsCloud {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
isReadOnly, permissions, err := existingDatabase.IsUserReadOnly(
ctx,
s.logger,
s.fieldEncryptor,
)
if err != nil {
return fmt.Errorf("failed to verify user permissions: %w", err)
}
if !isReadOnly {
return fmt.Errorf(
"in cloud mode, only read-only database users are allowed (user has permissions: %v)",
permissions,
)
}
if err := s.verifyReadOnlyUserIfNeeded(existingDatabase); err != nil {
return err
}
oldName := existingDatabase.Name
@@ -485,6 +457,7 @@ func (s *DatabaseService) CopyDatabase(
newDatabase.Postgresql = &postgresql.PostgresqlDatabase{
ID: uuid.Nil,
DatabaseID: nil,
BackupType: existingDatabase.Postgresql.BackupType,
Version: existingDatabase.Postgresql.Version,
Host: existingDatabase.Postgresql.Host,
Port: existingDatabase.Postgresql.Port,
@@ -638,6 +611,71 @@ func (s *DatabaseService) SetHealthStatus(
return nil
}
func (s *DatabaseService) RegenerateAgentToken(
user *users_models.User,
databaseID uuid.UUID,
) (string, error) {
database, err := s.dbRepository.FindByID(databaseID)
if err != nil {
return "", err
}
if database.WorkspaceID == nil {
return "", errors.New("cannot regenerate token for database without workspace")
}
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
if err != nil {
return "", err
}
if !canManage {
return "", errors.New(
"insufficient permissions to regenerate agent token for this database",
)
}
plainToken := strings.ReplaceAll(uuid.New().String(), "-", "")
tokenHash := hashAgentToken(plainToken)
database.AgentToken = &tokenHash
database.IsAgentTokenGenerated = true
_, err = s.dbRepository.Save(database)
if err != nil {
return "", err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Agent token regenerated for database: %s", database.Name),
&user.ID,
database.WorkspaceID,
)
return plainToken, nil
}
func (s *DatabaseService) VerifyAgentToken(token string) error {
hash := hashAgentToken(token)
_, err := s.dbRepository.FindByAgentTokenHash(hash)
if err != nil {
return errors.New("invalid token")
}
return nil
}
func (s *DatabaseService) GetDatabaseByAgentToken(token string) (*Database, error) {
hash := hashAgentToken(token)
partial, err := s.dbRepository.FindByAgentTokenHash(hash)
if err != nil {
return nil, errors.New("invalid agent token")
}
return s.dbRepository.FindByID(partial.ID)
}
func (s *DatabaseService) OnBeforeWorkspaceDeletion(workspaceID uuid.UUID) error {
databases, err := s.dbRepository.FindByWorkspaceID(workspaceID)
if err != nil {
@@ -809,3 +847,36 @@ func (s *DatabaseService) CreateReadOnlyUser(
return username, password, nil
}
func (s *DatabaseService) verifyReadOnlyUserIfNeeded(database *Database) error {
if !config.GetEnv().IsCloud {
return nil
}
if database.Postgresql != nil &&
database.Postgresql.BackupType == postgresql.PostgresBackupTypeWalV1 {
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
isReadOnly, permissions, err := database.IsUserReadOnly(ctx, s.logger, s.fieldEncryptor)
if err != nil {
return fmt.Errorf("failed to verify user permissions: %w", err)
}
if !isReadOnly {
return fmt.Errorf(
"in cloud mode, only read-only database users are allowed (user has permissions: %v)",
permissions,
)
}
return nil
}
func hashAgentToken(token string) string {
hash := sha256.Sum256([]byte(token))
return fmt.Sprintf("%x", hash)
}

View File

@@ -79,6 +79,38 @@ func (i *Interval) ShouldTriggerBackup(now time.Time, lastBackupTime *time.Time)
}
}
// NextTriggerTime computes the next time a backup should trigger based on the interval and last backup time.
// Returns nil when a backup is due immediately (no previous backup exists).
func (i *Interval) NextTriggerTime(now time.Time, lastBackupTime *time.Time) *time.Time {
if lastBackupTime == nil {
return nil
}
switch i.Interval {
case IntervalHourly:
next := lastBackupTime.Add(time.Hour)
return &next
case IntervalDaily:
next := i.nextDailyTrigger(now)
return &next
case IntervalWeekly:
next := i.nextWeeklyTrigger(now)
return &next
case IntervalMonthly:
next := i.nextMonthlyTrigger(now)
return &next
case IntervalCron:
return i.nextCronTrigger(*lastBackupTime)
default:
return nil
}
}
func (i *Interval) Copy() *Interval {
return &Interval{
ID: uuid.Nil,
@@ -240,6 +272,99 @@ func (i *Interval) shouldTriggerCron(now, lastBackup time.Time) bool {
return now.After(nextAfterLastBackup) || now.Equal(nextAfterLastBackup)
}
func (i *Interval) nextDailyTrigger(now time.Time) time.Time {
t, err := time.Parse("15:04", *i.TimeOfDay)
if err != nil {
return now
}
todaySlot := time.Date(
now.Year(), now.Month(), now.Day(),
t.Hour(), t.Minute(), 0, 0, now.Location(),
)
if now.Before(todaySlot) {
return todaySlot
}
return todaySlot.AddDate(0, 0, 1)
}
func (i *Interval) nextWeeklyTrigger(now time.Time) time.Time {
targetWd := time.Weekday(0)
if i.Weekday != nil {
targetWd = time.Weekday(*i.Weekday)
}
startOfWeek := getStartOfWeek(now)
var daysFromMonday int
if targetWd == time.Sunday {
daysFromMonday = 6
} else {
daysFromMonday = int(targetWd) - 1
}
targetThisWeek := startOfWeek.AddDate(0, 0, daysFromMonday)
if i.TimeOfDay != nil {
t, err := time.Parse("15:04", *i.TimeOfDay)
if err == nil {
targetThisWeek = time.Date(
targetThisWeek.Year(), targetThisWeek.Month(), targetThisWeek.Day(),
t.Hour(), t.Minute(), 0, 0, targetThisWeek.Location(),
)
}
}
if now.Before(targetThisWeek) {
return targetThisWeek
}
return targetThisWeek.AddDate(0, 0, 7)
}
func (i *Interval) nextMonthlyTrigger(now time.Time) time.Time {
day := 1
if i.DayOfMonth != nil {
day = *i.DayOfMonth
}
targetThisMonth := time.Date(now.Year(), now.Month(), day, 0, 0, 0, 0, now.Location())
if i.TimeOfDay != nil {
t, err := time.Parse("15:04", *i.TimeOfDay)
if err == nil {
targetThisMonth = time.Date(
targetThisMonth.Year(), targetThisMonth.Month(), targetThisMonth.Day(),
t.Hour(), t.Minute(), 0, 0, targetThisMonth.Location(),
)
}
}
if now.Before(targetThisMonth) {
return targetThisMonth
}
return targetThisMonth.AddDate(0, 1, 0)
}
func (i *Interval) nextCronTrigger(lastBackup time.Time) *time.Time {
if i.CronExpression == nil || *i.CronExpression == "" {
return nil
}
parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
schedule, err := parser.Parse(*i.CronExpression)
if err != nil {
return nil
}
next := schedule.Next(lastBackup)
return &next
}
func (i *Interval) validateCronExpression(expr string) error {
parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
_, err := parser.Parse(expr)

View File

@@ -721,3 +721,265 @@ func TestInterval_Validate(t *testing.T) {
assert.NoError(t, err)
})
}
func TestInterval_NextTriggerTime_NilLastBackup(t *testing.T) {
now := time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC)
t.Run("Hourly with nil lastBackup returns nil", func(t *testing.T) {
interval := &Interval{ID: uuid.New(), Interval: IntervalHourly}
result := interval.NextTriggerTime(now, nil)
assert.Nil(t, result)
})
t.Run("Daily with nil lastBackup returns nil", func(t *testing.T) {
timeOfDay := "09:00"
interval := &Interval{ID: uuid.New(), Interval: IntervalDaily, TimeOfDay: &timeOfDay}
result := interval.NextTriggerTime(now, nil)
assert.Nil(t, result)
})
t.Run("Weekly with nil lastBackup returns nil", func(t *testing.T) {
timeOfDay := "15:00"
weekday := 3
interval := &Interval{
ID: uuid.New(),
Interval: IntervalWeekly,
TimeOfDay: &timeOfDay,
Weekday: &weekday,
}
result := interval.NextTriggerTime(now, nil)
assert.Nil(t, result)
})
t.Run("Monthly with nil lastBackup returns nil", func(t *testing.T) {
timeOfDay := "08:00"
dayOfMonth := 10
interval := &Interval{
ID: uuid.New(),
Interval: IntervalMonthly,
TimeOfDay: &timeOfDay,
DayOfMonth: &dayOfMonth,
}
result := interval.NextTriggerTime(now, nil)
assert.Nil(t, result)
})
t.Run("Cron with nil lastBackup returns nil", func(t *testing.T) {
cronExpr := "0 2 * * *"
interval := &Interval{ID: uuid.New(), Interval: IntervalCron, CronExpression: &cronExpr}
result := interval.NextTriggerTime(now, nil)
assert.Nil(t, result)
})
}
func TestInterval_NextTriggerTime_Hourly(t *testing.T) {
interval := &Interval{ID: uuid.New(), Interval: IntervalHourly}
now := time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC)
t.Run("Returns lastBackup + 1 hour", func(t *testing.T) {
lastBackup := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.NotNil(t, result)
assert.Equal(t, time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC), *result)
})
t.Run("Returns future time when last backup was recent", func(t *testing.T) {
lastBackup := time.Date(2024, 1, 15, 11, 30, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.NotNil(t, result)
assert.Equal(t, time.Date(2024, 1, 15, 12, 30, 0, 0, time.UTC), *result)
})
}
func TestInterval_NextTriggerTime_Daily(t *testing.T) {
timeOfDay := "09:00"
interval := &Interval{ID: uuid.New(), Interval: IntervalDaily, TimeOfDay: &timeOfDay}
t.Run("Before today's slot: returns today's slot", func(t *testing.T) {
now := time.Date(2024, 1, 15, 8, 0, 0, 0, time.UTC)
lastBackup := time.Date(2024, 1, 14, 9, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.NotNil(t, result)
assert.Equal(t, time.Date(2024, 1, 15, 9, 0, 0, 0, time.UTC), *result)
})
t.Run("After today's slot: returns tomorrow's slot", func(t *testing.T) {
now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC)
lastBackup := time.Date(2024, 1, 15, 9, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.NotNil(t, result)
assert.Equal(t, time.Date(2024, 1, 16, 9, 0, 0, 0, time.UTC), *result)
})
t.Run("Exactly at today's slot: returns tomorrow's slot", func(t *testing.T) {
now := time.Date(2024, 1, 15, 9, 0, 0, 0, time.UTC)
lastBackup := time.Date(2024, 1, 14, 9, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.NotNil(t, result)
assert.Equal(t, time.Date(2024, 1, 16, 9, 0, 0, 0, time.UTC), *result)
})
}
func TestInterval_NextTriggerTime_Weekly(t *testing.T) {
timeOfDay := "15:00"
weekday := 3 // Wednesday
interval := &Interval{
ID: uuid.New(),
Interval: IntervalWeekly,
TimeOfDay: &timeOfDay,
Weekday: &weekday,
}
t.Run("Before this week's target: returns this week's target", func(t *testing.T) {
// Tuesday Jan 16, 2024
now := time.Date(2024, 1, 16, 10, 0, 0, 0, time.UTC)
lastBackup := time.Date(2024, 1, 10, 15, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.NotNil(t, result)
// Wednesday Jan 17 at 15:00
assert.Equal(t, time.Date(2024, 1, 17, 15, 0, 0, 0, time.UTC), *result)
})
t.Run("After this week's target: returns next week's target", func(t *testing.T) {
// Thursday Jan 18, 2024
now := time.Date(2024, 1, 18, 10, 0, 0, 0, time.UTC)
lastBackup := time.Date(2024, 1, 17, 15, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.NotNil(t, result)
// Next Wednesday Jan 24 at 15:00
assert.Equal(t, time.Date(2024, 1, 24, 15, 0, 0, 0, time.UTC), *result)
})
t.Run("Friday interval: returns correct target", func(t *testing.T) {
fridayTimeOfDay := "00:00"
fridayWeekday := 5 // Friday
fridayInterval := &Interval{
ID: uuid.New(),
Interval: IntervalWeekly,
TimeOfDay: &fridayTimeOfDay,
Weekday: &fridayWeekday,
}
// Wednesday Jan 17, 2024
now := time.Date(2024, 1, 17, 10, 0, 0, 0, time.UTC)
lastBackup := time.Date(2024, 1, 12, 0, 0, 0, 0, time.UTC)
result := fridayInterval.NextTriggerTime(now, &lastBackup)
assert.NotNil(t, result)
// Friday Jan 19 at 00:00
assert.Equal(t, time.Date(2024, 1, 19, 0, 0, 0, 0, time.UTC), *result)
})
}
func TestInterval_NextTriggerTime_Monthly(t *testing.T) {
timeOfDay := "08:00"
dayOfMonth := 10
interval := &Interval{
ID: uuid.New(),
Interval: IntervalMonthly,
TimeOfDay: &timeOfDay,
DayOfMonth: &dayOfMonth,
}
t.Run("Before this month's target: returns this month's target", func(t *testing.T) {
now := time.Date(2024, 1, 5, 10, 0, 0, 0, time.UTC)
lastBackup := time.Date(2023, 12, 10, 8, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.NotNil(t, result)
assert.Equal(t, time.Date(2024, 1, 10, 8, 0, 0, 0, time.UTC), *result)
})
t.Run("After this month's target: returns next month's target", func(t *testing.T) {
now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC)
lastBackup := time.Date(2024, 1, 10, 8, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.NotNil(t, result)
assert.Equal(t, time.Date(2024, 2, 10, 8, 0, 0, 0, time.UTC), *result)
})
t.Run("Exactly at this month's target: returns next month's target", func(t *testing.T) {
now := time.Date(2024, 1, 10, 8, 0, 0, 0, time.UTC)
lastBackup := time.Date(2023, 12, 10, 8, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.NotNil(t, result)
assert.Equal(t, time.Date(2024, 2, 10, 8, 0, 0, 0, time.UTC), *result)
})
}
func TestInterval_NextTriggerTime_Cron(t *testing.T) {
t.Run("Daily cron: returns next trigger after lastBackup", func(t *testing.T) {
cronExpr := "0 2 * * *" // Daily at 2:00 AM
interval := &Interval{ID: uuid.New(), Interval: IntervalCron, CronExpression: &cronExpr}
now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC)
lastBackup := time.Date(2024, 1, 14, 2, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.NotNil(t, result)
assert.Equal(t, time.Date(2024, 1, 15, 2, 0, 0, 0, time.UTC), *result)
})
t.Run("Complex cron: 1st and 15th at 4:30", func(t *testing.T) {
cronExpr := "30 4 1,15 * *"
interval := &Interval{ID: uuid.New(), Interval: IntervalCron, CronExpression: &cronExpr}
now := time.Date(2024, 1, 10, 10, 0, 0, 0, time.UTC)
lastBackup := time.Date(2024, 1, 1, 4, 30, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.NotNil(t, result)
assert.Equal(t, time.Date(2024, 1, 15, 4, 30, 0, 0, time.UTC), *result)
})
t.Run("Invalid cron expression returns nil", func(t *testing.T) {
invalidCron := "invalid cron"
interval := &Interval{ID: uuid.New(), Interval: IntervalCron, CronExpression: &invalidCron}
now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC)
lastBackup := time.Date(2024, 1, 14, 10, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.Nil(t, result)
})
t.Run("Empty cron expression returns nil", func(t *testing.T) {
emptyCron := ""
interval := &Interval{ID: uuid.New(), Interval: IntervalCron, CronExpression: &emptyCron}
now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC)
lastBackup := time.Date(2024, 1, 14, 10, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.Nil(t, result)
})
t.Run("Nil cron expression returns nil", func(t *testing.T) {
interval := &Interval{ID: uuid.New(), Interval: IntervalCron, CronExpression: nil}
now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC)
lastBackup := time.Date(2024, 1, 14, 10, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.Nil(t, result)
})
}
func TestInterval_NextTriggerTime_UnknownInterval(t *testing.T) {
interval := &Interval{ID: uuid.New(), Interval: IntervalType("UNKNOWN")}
now := time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC)
lastBackup := time.Date(2024, 1, 14, 12, 0, 0, 0, time.UTC)
result := interval.NextTriggerTime(now, &lastBackup)
assert.Nil(t, result)
}

View File

@@ -23,13 +23,14 @@ const (
)
type EmailNotifier struct {
NotifierID uuid.UUID `json:"notifierId" gorm:"primaryKey;type:uuid;column:notifier_id"`
TargetEmail string `json:"targetEmail" gorm:"not null;type:varchar(255);column:target_email"`
SMTPHost string `json:"smtpHost" gorm:"not null;type:varchar(255);column:smtp_host"`
SMTPPort int `json:"smtpPort" gorm:"not null;column:smtp_port"`
SMTPUser string `json:"smtpUser" gorm:"type:varchar(255);column:smtp_user"`
SMTPPassword string `json:"smtpPassword" gorm:"type:varchar(255);column:smtp_password"`
From string `json:"from" gorm:"type:varchar(255);column:from_email"`
NotifierID uuid.UUID `json:"notifierId" gorm:"primaryKey;type:uuid;column:notifier_id"`
TargetEmail string `json:"targetEmail" gorm:"not null;type:varchar(255);column:target_email"`
SMTPHost string `json:"smtpHost" gorm:"not null;type:varchar(255);column:smtp_host"`
SMTPPort int `json:"smtpPort" gorm:"not null;column:smtp_port"`
SMTPUser string `json:"smtpUser" gorm:"type:varchar(255);column:smtp_user"`
SMTPPassword string `json:"smtpPassword" gorm:"type:varchar(255);column:smtp_password"`
From string `json:"from" gorm:"type:varchar(255);column:from_email"`
IsInsecureSkipVerify bool `json:"isInsecureSkipVerify" gorm:"default:false;column:is_insecure_skip_verify"`
}
func (e *EmailNotifier) TableName() string {
@@ -99,6 +100,7 @@ func (e *EmailNotifier) Update(incoming *EmailNotifier) {
e.SMTPPort = incoming.SMTPPort
e.SMTPUser = incoming.SMTPUser
e.From = incoming.From
e.IsInsecureSkipVerify = incoming.IsInsecureSkipVerify
if incoming.SMTPPassword != "" {
e.SMTPPassword = incoming.SMTPPassword
@@ -198,7 +200,10 @@ func (e *EmailNotifier) sendStartTLS(
func (e *EmailNotifier) createImplicitTLSClient() (*smtp.Client, func(), error) {
addr := net.JoinHostPort(e.SMTPHost, fmt.Sprintf("%d", e.SMTPPort))
tlsConfig := &tls.Config{ServerName: e.SMTPHost}
tlsConfig := &tls.Config{
ServerName: e.SMTPHost,
InsecureSkipVerify: e.IsInsecureSkipVerify,
}
dialer := &net.Dialer{Timeout: DefaultTimeout}
conn, err := tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
@@ -237,7 +242,10 @@ func (e *EmailNotifier) createStartTLSClient() (*smtp.Client, func(), error) {
}
if ok, _ := client.Extension("STARTTLS"); ok {
if err := client.StartTLS(&tls.Config{ServerName: e.SMTPHost}); err != nil {
if err := client.StartTLS(&tls.Config{
ServerName: e.SMTPHost,
InsecureSkipVerify: e.IsInsecureSkipVerify,
}); err != nil {
_ = client.Quit()
_ = conn.Close()
return nil, nil, fmt.Errorf("STARTTLS failed: %w", err)

View File

@@ -18,7 +18,7 @@ import (
env_config "databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
backups_controllers "databasus-backend/internal/features/backups/backups/controllers"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -440,7 +440,7 @@ func Test_CancelRestore_InProgressRestore_SuccessfullyCancelled(t *testing.T) {
}()
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backup := backups.CreateTestBackup(database.ID, storage.ID)
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
mockUsecase := &restoring.MockBlockingRestoreUsecase{
StartedChan: make(chan bool, 1),

View File

@@ -5,8 +5,8 @@ import (
"sync/atomic"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/features/backups/backups/backuping"
backups_services "databasus-backend/internal/features/backups/backups/services"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/disk"
@@ -21,7 +21,7 @@ import (
var restoreRepository = &restores_core.RestoreRepository{}
var restoreService = &RestoreService{
backups.GetBackupService(),
backups_services.GetBackupService(),
restoreRepository,
storages.GetStorageService(),
backups_config.GetBackupConfigService(),
@@ -51,7 +51,7 @@ func SetupDependencies() {
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
backups.GetBackupService().AddBackupRemoveListener(restoreService)
backups_services.GetBackupService().AddBackupRemoveListener(restoreService)
backuping.GetBackupCleaner().AddBackupRemoveListener(restoreService)
isSetup.Store(true)

View File

@@ -7,7 +7,7 @@ import (
"github.com/google/uuid"
"databasus-backend/internal/features/backups/backups"
backups_services "databasus-backend/internal/features/backups/backups/services"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
restores_core "databasus-backend/internal/features/restores/core"
@@ -39,37 +39,37 @@ var restoreDatabaseCache = cache_utils.NewCacheUtil[RestoreDatabaseCache](
var restoreCancelManager = tasks_cancellation.GetTaskCancelManager()
var restorerNode = &RestorerNode{
nodeID: uuid.New(),
databaseService: databases.GetDatabaseService(),
backupService: backups.GetBackupService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
restoreRepository: restoreRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
restoreNodesRegistry: restoreNodesRegistry,
logger: logger.GetLogger(),
restoreBackupUsecase: usecases.GetRestoreBackupUsecase(),
cacheUtil: restoreDatabaseCache,
restoreCancelManager: restoreCancelManager,
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
uuid.New(),
databases.GetDatabaseService(),
backups_services.GetBackupService(),
encryption.GetFieldEncryptor(),
restoreRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
restoreNodesRegistry,
logger.GetLogger(),
usecases.GetRestoreBackupUsecase(),
restoreDatabaseCache,
restoreCancelManager,
time.Time{},
sync.Once{},
atomic.Bool{},
}
var restoresScheduler = &RestoresScheduler{
restoreRepository: restoreRepository,
backupService: backups.GetBackupService(),
storageService: storages.GetStorageService(),
backupConfigService: backups_config.GetBackupConfigService(),
restoreNodesRegistry: restoreNodesRegistry,
lastCheckTime: time.Now().UTC(),
logger: logger.GetLogger(),
restoreToNodeRelations: make(map[uuid.UUID]RestoreToNodeRelation),
restorerNode: restorerNode,
cacheUtil: restoreDatabaseCache,
completionSubscriptionID: uuid.Nil,
runOnce: sync.Once{},
hasRun: atomic.Bool{},
restoreRepository,
backups_services.GetBackupService(),
storages.GetStorageService(),
backups_config.GetBackupConfigService(),
restoreNodesRegistry,
time.Now().UTC(),
logger.GetLogger(),
make(map[uuid.UUID]RestoreToNodeRelation),
restorerNode,
restoreDatabaseCache,
uuid.Nil,
sync.Once{},
atomic.Bool{},
}
func GetRestoresScheduler() *RestoresScheduler {

View File

@@ -13,7 +13,7 @@ import (
"github.com/google/uuid"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_services "databasus-backend/internal/features/backups/backups/services"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
restores_core "databasus-backend/internal/features/restores/core"
@@ -32,7 +32,7 @@ type RestorerNode struct {
nodeID uuid.UUID
databaseService *databases.DatabaseService
backupService *backups.BackupService
backupService *backups_services.BackupService
fieldEncryptor util_encryption.FieldEncryptor
restoreRepository *restores_core.RestoreRepository
backupConfigService *backups_config.BackupConfigService

View File

@@ -7,7 +7,7 @@ import (
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_controllers "databasus-backend/internal/features/backups/backups/controllers"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -58,7 +58,7 @@ func Test_MakeRestore_WhenCacheMissed_RestoreFails(t *testing.T) {
cache_utils.ClearAllCache()
}()
backup := backups.CreateTestBackup(database.ID, storage.ID)
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
// Create restore but DON'T cache DB credentials
// Also don't set embedded DB fields to avoid schema issues
@@ -126,7 +126,7 @@ func Test_MakeRestore_WhenTaskStarts_CacheDeletedImmediately(t *testing.T) {
cache_utils.ClearAllCache()
}()
backup := backups.CreateTestBackup(database.ID, storage.ID)
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
// Create restore with cached DB credentials
// Don't set embedded DB fields in the restore model itself

View File

@@ -11,7 +11,7 @@ import (
"github.com/google/uuid"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_services "databasus-backend/internal/features/backups/backups/services"
backups_config "databasus-backend/internal/features/backups/config"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/storages"
@@ -26,7 +26,7 @@ const (
type RestoresScheduler struct {
restoreRepository *restores_core.RestoreRepository
backupService *backups.BackupService
backupService *backups_services.BackupService
storageService *storages.StorageService
backupConfigService *backups_config.BackupConfigService
restoreNodesRegistry *RestoreNodesRegistry

View File

@@ -5,7 +5,7 @@ import (
"time"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_controllers "databasus-backend/internal/features/backups/backups/controllers"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -68,7 +68,7 @@ func Test_CheckDeadNodesAndFailRestores_NodeDies_FailsRestoreAndCleansUpRegistry
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
// Create a test backup
backup := backups.CreateTestBackup(database.ID, storage.ID)
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
var err error
// Register mock node without subscribing to restores (simulates node crash after registration)
@@ -171,7 +171,7 @@ func Test_OnRestoreCompleted_TaskIsNotRestore_SkipsProcessing(t *testing.T) {
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
// Create a test backup
backup := backups.CreateTestBackup(database.ID, storage.ID)
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
// Register mock node
mockNodeID = uuid.New()
@@ -357,7 +357,7 @@ func Test_FailRestoresInProgress_SchedulerStarts_UpdatesStatus(t *testing.T) {
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
// Create a test backup
backup := backups.CreateTestBackup(database.ID, storage.ID)
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
// Create two in-progress restores that should be failed on scheduler restart
restore1 := &restores_core.Restore{
@@ -465,7 +465,7 @@ func Test_StartRestore_RestoreCompletes_DecrementsActiveTaskCount(t *testing.T)
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
// Create a test backup
backup := backups.CreateTestBackup(database.ID, storage.ID)
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
// Get initial active task count
stats, err := restoreNodesRegistry.GetRestoreNodesStats()
@@ -566,7 +566,7 @@ func Test_StartRestore_RestoreFails_DecrementsActiveTaskCount(t *testing.T) {
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
// Create a test backup
backup := backups.CreateTestBackup(database.ID, storage.ID)
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
// Get initial active task count
stats, err := restoreNodesRegistry.GetRestoreNodesStats()
@@ -664,7 +664,7 @@ func Test_StartRestore_CredentialsStoredEncryptedInCache(t *testing.T) {
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
// Create a test backup
backup := backups.CreateTestBackup(database.ID, storage.ID)
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
// Register mock node so scheduler can assign restore to it
mockNodeID = uuid.New()
@@ -779,7 +779,7 @@ func Test_StartRestore_CredentialsRemovedAfterRestoreStarts(t *testing.T) {
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
// Create a test backup
backup := backups.CreateTestBackup(database.ID, storage.ID)
backup := backups_controllers.CreateTestBackup(database.ID, storage.ID)
// Create restore with credentials
plaintextPassword := "test_password_456"

View File

@@ -12,8 +12,8 @@ import (
"github.com/google/uuid"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_services "databasus-backend/internal/features/backups/backups/services"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/postgresql"
@@ -40,48 +40,48 @@ func CreateTestRouter() *gin.Engine {
func CreateTestRestorerNode() *RestorerNode {
return &RestorerNode{
nodeID: uuid.New(),
databaseService: databases.GetDatabaseService(),
backupService: backups.GetBackupService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
restoreRepository: restoreRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
restoreNodesRegistry: restoreNodesRegistry,
logger: logger.GetLogger(),
restoreBackupUsecase: usecases.GetRestoreBackupUsecase(),
cacheUtil: restoreDatabaseCache,
restoreCancelManager: tasks_cancellation.GetTaskCancelManager(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
uuid.New(),
databases.GetDatabaseService(),
backups_services.GetBackupService(),
encryption.GetFieldEncryptor(),
restoreRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
restoreNodesRegistry,
logger.GetLogger(),
usecases.GetRestoreBackupUsecase(),
restoreDatabaseCache,
tasks_cancellation.GetTaskCancelManager(),
time.Time{},
sync.Once{},
atomic.Bool{},
}
}
func CreateTestRestorerNodeWithUsecase(usecase restores_core.RestoreBackupUsecase) *RestorerNode {
return &RestorerNode{
nodeID: uuid.New(),
databaseService: databases.GetDatabaseService(),
backupService: backups.GetBackupService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
restoreRepository: restoreRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
restoreNodesRegistry: restoreNodesRegistry,
logger: logger.GetLogger(),
restoreBackupUsecase: usecase,
cacheUtil: restoreDatabaseCache,
restoreCancelManager: tasks_cancellation.GetTaskCancelManager(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
uuid.New(),
databases.GetDatabaseService(),
backups_services.GetBackupService(),
encryption.GetFieldEncryptor(),
restoreRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
restoreNodesRegistry,
logger.GetLogger(),
usecase,
restoreDatabaseCache,
tasks_cancellation.GetTaskCancelManager(),
time.Time{},
sync.Once{},
atomic.Bool{},
}
}
func CreateTestRestoresScheduler() *RestoresScheduler {
return &RestoresScheduler{
restoreRepository,
backups.GetBackupService(),
backups_services.GetBackupService(),
storages.GetStorageService(),
backups_config.GetBackupConfigService(),
restoreNodesRegistry,

View File

@@ -3,8 +3,8 @@ package restores
import (
"databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_services "databasus-backend/internal/features/backups/backups/services"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/disk"
@@ -26,7 +26,7 @@ import (
)
type RestoreService struct {
backupService *backups.BackupService
backupService *backups_services.BackupService
restoreRepository *restores_core.RestoreRepository
storageService *storages.StorageService
backupConfigService *backups_config.BackupConfigService

View File

@@ -8,7 +8,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"databasus-backend/internal/features/backups/backups"
backups_controllers "databasus-backend/internal/features/backups/backups/controllers"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/restores/restoring"
@@ -22,12 +22,12 @@ func CreateTestRouter() *gin.Engine {
workspaces_controllers.GetMembershipController(),
databases.GetDatabaseController(),
backups_config.GetBackupConfigController(),
backups.GetBackupController(),
backups_controllers.GetBackupController(),
GetRestoreController(),
)
v1 := router.Group("/api/v1")
backups.GetBackupController().RegisterPublicRoutes(v1)
backups_controllers.GetBackupController().RegisterPublicRoutes(v1)
return router
}

View File

@@ -70,6 +70,10 @@ func (uc *RestoreMariadbBackupUsecase) Execute(
"--verbose",
}
if !config.GetEnv().IsCloud {
args = append(args, "--max-allowed-packet=1G")
}
if mdb.IsHttps {
args = append(args, "--ssl")
args = append(args, "--skip-ssl-verify-server-cert")

View File

@@ -70,6 +70,10 @@ func (uc *RestoreMysqlBackupUsecase) Execute(
"--verbose",
}
if !config.GetEnv().IsCloud {
args = append(args, "--max-allowed-packet=1G")
}
if my.IsHttps {
args = append(args, "--ssl-mode=REQUIRED")
}

View File

@@ -47,14 +47,15 @@ func (l *LocalStorage) SaveFile(
logger.Info("Starting to save file to local storage", "fileName", fileName)
tempFilePath := filepath.Join(config.GetEnv().TempFolder, fileName)
err := files_utils.EnsureDirectories([]string{
config.GetEnv().TempFolder,
filepath.Dir(tempFilePath),
})
if err != nil {
return fmt.Errorf("failed to ensure directories: %w", err)
}
tempFilePath := filepath.Join(config.GetEnv().TempFolder, fileName)
logger.Debug("Creating temp file", "fileName", fileName, "tempPath", tempFilePath)
tempFile, err := os.Create(tempFilePath)
@@ -101,6 +102,10 @@ func (l *LocalStorage) SaveFile(
finalPath,
)
if err = files_utils.EnsureDirectories([]string{filepath.Dir(finalPath)}); err != nil {
return fmt.Errorf("failed to ensure final directory: %w", err)
}
// Move the file from temp to backups directory
if err = os.Rename(tempFilePath, finalPath); err != nil {
logger.Error(

View File

@@ -0,0 +1,30 @@
package system_version
import (
"net/http"
"os"
"github.com/gin-gonic/gin"
)
type VersionController struct{}
func (c *VersionController) RegisterRoutes(router *gin.RouterGroup) {
router.GET("/system/version", c.GetVersion)
}
// GetVersion
// @Summary Get application version
// @Description Returns the current application version
// @Tags system/version
// @Produce json
// @Success 200 {object} VersionResponse
// @Router /system/version [get]
func (c *VersionController) GetVersion(ctx *gin.Context) {
version := os.Getenv("APP_VERSION")
if version == "" {
version = "dev"
}
ctx.JSON(http.StatusOK, VersionResponse{Version: version})
}

View File

@@ -0,0 +1,7 @@
package system_version
var versionController = &VersionController{}
func GetVersionController() *VersionController {
return versionController
}

View File

@@ -0,0 +1,5 @@
package system_version
type VersionResponse struct {
Version string `json:"version"`
}

View File

@@ -8,8 +8,8 @@ import (
"time"
"databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/features/backups/backups/backuping"
backups_services "databasus-backend/internal/features/backups/backups/services"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
@@ -26,8 +26,8 @@ func Test_SetupDependencies_CalledTwice_LogsWarning(t *testing.T) {
audit_logs.SetupDependencies()
audit_logs.SetupDependencies()
backups.SetupDependencies()
backups.SetupDependencies()
backups_services.SetupDependencies()
backups_services.SetupDependencies()
backups_config.SetupDependencies()
backups_config.SetupDependencies()

View File

@@ -17,8 +17,9 @@ import (
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
backups_controllers "databasus-backend/internal/features/backups/backups/controllers"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_dto "databasus-backend/internal/features/backups/backups/dto"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
pgtypes "databasus-backend/internal/features/databases/databases/postgresql"
@@ -1234,7 +1235,7 @@ func createTestRouter() *gin.Engine {
workspaces_controllers.GetMembershipController(),
databases.GetDatabaseController(),
backups_config.GetBackupConfigController(),
backups.GetBackupController(),
backups_controllers.GetBackupController(),
restores.GetRestoreController(),
)
return router
@@ -1255,7 +1256,7 @@ func waitForBackupCompletion(
t.Fatalf("Timeout waiting for backup completion after %v", timeout)
}
var response backups.GetBackupsResponse
var response backups_dto.GetBackupsResponse
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
@@ -1431,7 +1432,7 @@ func createBackupViaAPI(
databaseID uuid.UUID,
token string,
) {
request := backups.MakeBackupRequest{DatabaseID: databaseID}
request := backups_dto.MakeBackupRequest{DatabaseID: databaseID}
test_utils.MakePostRequest(
t,
router,

View File

@@ -0,0 +1,125 @@
package wal
import (
"encoding/hex"
"errors"
"fmt"
"strings"
)
const (
segmentNameLen = 24
timelineLen = 8
logLen = 8
segLen = 8
defaultSegmentSizeBytes = 16 * 1024 * 1024 // 16 MB
)
// WalCalculator performs WAL segment name arithmetic for a given WAL segment size.
//
// WAL segment name format: TTTTTTTTLLLLLLLLSSSSSSSS (24 hex chars)
// - TT: timeline (8 hex digits)
// - LL: log file / XLogId (8 hex digits)
// - SS: segment within log file (8 hex digits)
//
// segmentsPerXLogId = 0x100000000 / segmentSizeBytes
// Increment SS; if SS >= segmentsPerXLogId → SS = 0, LL++
type WalCalculator struct {
segmentSizeBytes int64
segmentsPerXLogId uint64
}
// NewWalCalculator creates a WalCalculator for the given WAL segment size in bytes.
// Pass 0 or a negative value to use the PostgreSQL default of 16 MB.
func NewWalCalculator(segmentSizeBytes int64) *WalCalculator {
if segmentSizeBytes <= 0 {
segmentSizeBytes = defaultSegmentSizeBytes
}
return &WalCalculator{
segmentSizeBytes: segmentSizeBytes,
segmentsPerXLogId: uint64(0x100000000) / uint64(segmentSizeBytes),
}
}
// NextSegment computes the next WAL segment name after current.
// Returns an error if current is not a valid 24-character hex WAL segment name.
func (c *WalCalculator) NextSegment(current string) (string, error) {
if !c.IsValidSegmentName(current) {
return "", fmt.Errorf("invalid WAL segment name: %q", current)
}
timeline := current[:timelineLen]
logHex := current[timelineLen : timelineLen+logLen]
segHex := current[timelineLen+logLen:]
logVal, err := parseHex32(logHex)
if err != nil {
return "", fmt.Errorf("parse log part of %q: %w", current, err)
}
segVal, err := parseHex32(segHex)
if err != nil {
return "", fmt.Errorf("parse seg part of %q: %w", current, err)
}
segVal++
if uint64(segVal) >= c.segmentsPerXLogId {
segVal = 0
logVal++
}
return fmt.Sprintf("%s%08X%08X", strings.ToUpper(timeline), logVal, segVal), nil
}
// IsValidSegmentName returns true if name is a 24-character uppercase (or lowercase) hex string
// representing a valid WAL segment name.
func (c *WalCalculator) IsValidSegmentName(name string) bool {
if len(name) != segmentNameLen {
return false
}
_, err := hex.DecodeString(name)
return err == nil
}
// Compare compares two WAL segment names a and b by their numeric value.
// Returns -1 if a < b, 0 if a == b, 1 if a > b.
// Both names must be valid; returns an error otherwise.
// Timeline is compared first, then log file, then segment number.
func (c *WalCalculator) Compare(a, b string) (int, error) {
if !c.IsValidSegmentName(a) {
return 0, fmt.Errorf("invalid WAL segment name: %q", a)
}
if !c.IsValidSegmentName(b) {
return 0, fmt.Errorf("invalid WAL segment name: %q", b)
}
// Fixed-width uppercase hex: lexicographic order equals numeric order.
aUpper := strings.ToUpper(a)
bUpper := strings.ToUpper(b)
if aUpper < bUpper {
return -1, nil
}
if aUpper > bUpper {
return 1, nil
}
return 0, nil
}
func parseHex32(s string) (uint32, error) {
if len(s) != 8 {
return 0, errors.New("expected 8 hex characters")
}
b, err := hex.DecodeString(s)
if err != nil {
return 0, err
}
return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3]), nil
}

View File

@@ -0,0 +1,221 @@
package wal
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
mb1 = 1 * 1024 * 1024
mb16 = 16 * 1024 * 1024
mb64 = 64 * 1024 * 1024
)
// NextSegment — no wrap
func Test_WalCalculator_NextSegment_DefaultSegmentSize_NoWrap(t *testing.T) {
calc := NewWalCalculator(mb16)
cases := []struct {
input string
expected string
}{
{"000000010000000100000000", "000000010000000100000001"},
{"000000010000000100000027", "000000010000000100000028"},
{"0000000100000001000000AB", "0000000100000001000000AC"},
{"0000000100000001000000FE", "0000000100000001000000FF"},
}
for _, tc := range cases {
result, err := calc.NextSegment(tc.input)
require.NoError(t, err)
assert.Equal(t, tc.expected, result, "input=%s", tc.input)
}
}
// NextSegment — wrap at 0x100 (256) for 16 MB segments
func Test_WalCalculator_NextSegment_DefaultSegmentSize_WrapsAt256(t *testing.T) {
calc := NewWalCalculator(mb16)
// SS=0xFF → SS=0x00, LL++
result, err := calc.NextSegment("0000000100000001000000FF")
require.NoError(t, err)
assert.Equal(t, "000000010000000200000000", result)
result, err = calc.NextSegment("0000000200000005000000FF")
require.NoError(t, err)
assert.Equal(t, "000000020000000600000000", result)
}
// NextSegment — wrap at 0x1000 (4096) for 1 MB segments
func Test_WalCalculator_NextSegment_1MbSegmentSize_WrapsAt4096(t *testing.T) {
calc := NewWalCalculator(mb1)
// segmentsPerXLogId = 0x100000000 / (1*1024*1024) = 4096 = 0x1000
// Last valid SS = 0x00000FFF
result, err := calc.NextSegment("000000010000000100000FFE")
require.NoError(t, err)
assert.Equal(t, "000000010000000100000FFF", result)
// wrap: SS=0x0FFF → SS=0, LL++
result, err = calc.NextSegment("000000010000000100000FFF")
require.NoError(t, err)
assert.Equal(t, "000000010000000200000000", result)
}
// NextSegment — wrap at 0x40 (64) for 64 MB segments
func Test_WalCalculator_NextSegment_64MbSegmentSize_WrapsAt64(t *testing.T) {
calc := NewWalCalculator(mb64)
// segmentsPerXLogId = 0x100000000 / (64*1024*1024) = 64 = 0x40
// Last valid SS = 0x0000003F
result, err := calc.NextSegment("00000001000000010000003E")
require.NoError(t, err)
assert.Equal(t, "00000001000000010000003F", result)
// wrap: SS=0x3F → SS=0, LL++
result, err = calc.NextSegment("00000001000000010000003F")
require.NoError(t, err)
assert.Equal(t, "000000010000000200000000", result)
}
// NextSegment — log file increment on wrap
func Test_WalCalculator_NextSegment_IncrementsLog_OnSegmentWrap(t *testing.T) {
calc := NewWalCalculator(mb16)
// LL=0x00000001, wraps to LL=0x00000002
result, err := calc.NextSegment("0000000100000001000000FF")
require.NoError(t, err)
assert.Equal(t, "000000010000000200000000", result)
// LL=0x0000000F, wraps to LL=0x00000010
result, err = calc.NextSegment("00000001000000FF000000FF")
require.NoError(t, err)
assert.Equal(t, "000000010000010000000000", result)
}
// NextSegment — timeline is preserved
func Test_WalCalculator_NextSegment_TimelinePreserved(t *testing.T) {
calc := NewWalCalculator(mb16)
result, err := calc.NextSegment("000000030000000100000005")
require.NoError(t, err)
assert.Equal(t, "000000030000000100000006", result)
}
// NextSegment — invalid input
func Test_WalCalculator_NextSegment_InvalidName_ReturnsError(t *testing.T) {
calc := NewWalCalculator(mb16)
cases := []string{
"",
"00000001000000010000000", // 23 chars
"0000000100000001000000001", // 25 chars
"00000001000000010000000G", // non-hex char
"short",
}
for _, name := range cases {
_, err := calc.NextSegment(name)
assert.Error(t, err, "expected error for input %q", name)
}
}
// IsValidSegmentName
func Test_WalCalculator_IsValidSegmentName_ValidName_ReturnsTrue(t *testing.T) {
calc := NewWalCalculator(mb16)
valid := []string{
"000000010000000100000000",
"0000000100000001000000FF",
"FFFFFFFFFFFFFFFFFFFFFFFF",
"000000010000000100000027",
"0000000200000005000000AB",
}
for _, name := range valid {
assert.True(t, calc.IsValidSegmentName(name), "expected valid for %q", name)
}
}
func Test_WalCalculator_IsValidSegmentName_TooShort_ReturnsFalse(t *testing.T) {
calc := NewWalCalculator(mb16)
assert.False(t, calc.IsValidSegmentName("00000001000000010000000")) // 23 chars
assert.False(t, calc.IsValidSegmentName(""))
}
func Test_WalCalculator_IsValidSegmentName_TooLong_ReturnsFalse(t *testing.T) {
calc := NewWalCalculator(mb16)
assert.False(t, calc.IsValidSegmentName("0000000100000001000000001")) // 25 chars
}
func Test_WalCalculator_IsValidSegmentName_NonHex_ReturnsFalse(t *testing.T) {
calc := NewWalCalculator(mb16)
assert.False(t, calc.IsValidSegmentName("00000001000000010000000G"))
assert.False(t, calc.IsValidSegmentName("00000001000000010000000Z"))
assert.False(t, calc.IsValidSegmentName("000000010000000100000 00"))
}
// Compare
func Test_WalCalculator_Compare_ReturnsCorrectOrdering(t *testing.T) {
calc := NewWalCalculator(mb16)
cases := []struct {
a string
b string
expected int
}{
// equal
{"000000010000000100000001", "000000010000000100000001", 0},
// segment ordering
{"000000010000000100000001", "000000010000000100000002", -1},
{"000000010000000100000002", "000000010000000100000001", 1},
// log ordering
{"000000010000000100000000", "000000010000000200000000", -1},
{"000000010000000200000000", "000000010000000100000000", 1},
// timeline ordering
{"000000010000000100000000", "000000020000000100000000", -1},
{"000000020000000100000000", "000000010000000100000000", 1},
// across wrap boundary: log 1, seg 255 < log 2, seg 0
{"0000000100000001000000FF", "000000010000000200000000", -1},
}
for _, tc := range cases {
result, err := calc.Compare(tc.a, tc.b)
require.NoError(t, err)
assert.Equal(t, tc.expected, result, "Compare(%s, %s)", tc.a, tc.b)
}
}
func Test_WalCalculator_Compare_InvalidInput_ReturnsError(t *testing.T) {
calc := NewWalCalculator(mb16)
_, err := calc.Compare("invalid", "000000010000000100000001")
assert.Error(t, err)
_, err = calc.Compare("000000010000000100000001", "invalid")
assert.Error(t, err)
}
// Default segment size via zero input
func Test_WalCalculator_NewWalCalculator_ZeroSize_UsesDefault16MB(t *testing.T) {
calc := NewWalCalculator(0)
assert.Equal(t, int64(mb16), calc.segmentSizeBytes)
assert.Equal(t, uint64(256), calc.segmentsPerXLogId)
}

View File

@@ -0,0 +1,72 @@
-- +goose Up
-- +goose StatementBegin
ALTER TABLE databases
ADD COLUMN agent_token TEXT,
ADD COLUMN is_agent_token_generated BOOLEAN NOT NULL DEFAULT FALSE;
CREATE UNIQUE INDEX idx_databases_agent_token ON databases (agent_token) WHERE agent_token IS NOT NULL;
ALTER TABLE postgresql_databases
ADD COLUMN backup_type TEXT NOT NULL DEFAULT 'PG_DUMP';
ALTER TABLE postgresql_databases
ALTER COLUMN host DROP NOT NULL,
ALTER COLUMN port DROP NOT NULL,
ALTER COLUMN username DROP NOT NULL,
ALTER COLUMN password DROP NOT NULL;
ALTER TABLE backups
ADD COLUMN pg_wal_backup_type TEXT,
ADD COLUMN pg_wal_start_segment TEXT,
ADD COLUMN pg_wal_stop_segment TEXT,
ADD COLUMN pg_version TEXT,
ADD COLUMN pg_wal_segment_name TEXT;
CREATE INDEX idx_backups_pg_wal_segment_name
ON backups (database_id, pg_wal_segment_name)
WHERE pg_wal_segment_name IS NOT NULL;
CREATE INDEX idx_backups_pg_wal_backup_type_created
ON backups (database_id, pg_wal_backup_type, created_at DESC)
WHERE pg_wal_backup_type IS NOT NULL;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
DROP INDEX IF EXISTS idx_backups_pg_wal_segment_name;
DROP INDEX IF EXISTS idx_backups_pg_wal_backup_type_created;
ALTER TABLE backups
DROP COLUMN pg_wal_backup_type,
DROP COLUMN pg_wal_start_segment,
DROP COLUMN pg_wal_stop_segment,
DROP COLUMN pg_version,
DROP COLUMN pg_wal_segment_name;
UPDATE postgresql_databases
SET host = 'localhost' WHERE host IS NULL OR host = '';
UPDATE postgresql_databases
SET port = 5432 WHERE port IS NULL OR port = 0;
UPDATE postgresql_databases
SET username = 'postgres' WHERE username IS NULL OR username = '';
UPDATE postgresql_databases
SET password = 'stubpassword' WHERE password IS NULL OR password = '';
ALTER TABLE postgresql_databases
DROP COLUMN backup_type;
ALTER TABLE postgresql_databases
ALTER COLUMN host SET NOT NULL,
ALTER COLUMN port SET NOT NULL,
ALTER COLUMN username SET NOT NULL,
ALTER COLUMN password SET NOT NULL;
DROP INDEX IF EXISTS idx_databases_agent_token;
ALTER TABLE databases
DROP COLUMN agent_token,
DROP COLUMN is_agent_token_generated;
-- +goose StatementEnd

View File

@@ -0,0 +1,11 @@
-- +goose Up
-- +goose StatementBegin
ALTER TABLE email_notifiers
ADD COLUMN is_insecure_skip_verify BOOLEAN NOT NULL DEFAULT FALSE;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
ALTER TABLE email_notifiers
DROP COLUMN is_insecure_skip_verify;
-- +goose StatementEnd

View File

@@ -0,0 +1,11 @@
-- +goose Up
-- +goose StatementBegin
ALTER TABLE mysql_databases
ADD COLUMN is_zstd_supported BOOLEAN NOT NULL DEFAULT TRUE;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
ALTER TABLE mysql_databases
DROP COLUMN is_zstd_supported;
-- +goose StatementEnd

View File

@@ -3,6 +3,8 @@ import { useEffect, useState } from 'react';
import { BrowserRouter, Route } from 'react-router';
import { Routes } from 'react-router';
import { useVersionCheck } from './shared/hooks/useVersionCheck';
import { userApi } from './entity/users';
import { AuthPageComponent } from './pages/AuthPageComponent';
import { OAuthCallbackPage } from './pages/OAuthCallbackPage';
@@ -14,6 +16,8 @@ function AppContent() {
const [isAuthorized, setIsAuthorized] = useState(false);
const { resolvedTheme } = useTheme();
useVersionCheck();
useEffect(() => {
const isAuthorized = userApi.isAuthorized();
setIsAuthorized(isAuthorized);

View File

@@ -4,6 +4,7 @@ interface RuntimeConfig {
GOOGLE_CLIENT_ID?: string;
IS_EMAIL_CONFIGURED?: string;
CLOUDFLARE_TURNSTILE_SITE_KEY?: string;
CONTAINER_ARCH?: string;
}
declare global {
@@ -45,6 +46,10 @@ export const CLOUDFLARE_TURNSTILE_SITE_KEY =
import.meta.env.VITE_CLOUDFLARE_TURNSTILE_SITE_KEY ||
'';
const archMap: Record<string, string> = { amd64: 'x64', arm64: 'arm64' };
const rawArch = window.__RUNTIME_CONFIG__?.CONTAINER_ARCH || 'unknown';
export const CONTAINER_ARCH = archMap[rawArch] || rawArch;
export function getOAuthRedirectUri(): string {
return `${window.location.origin}/auth/callback`;
}

View File

@@ -5,4 +5,5 @@ export interface EmailNotifier {
smtpUser: string;
smtpPassword: string;
from: string;
isInsecureSkipVerify: boolean;
}

View File

@@ -0,0 +1,14 @@
import { getApplicationServer } from '../../../constants';
import type { VersionResponse } from '../model/VersionResponse';
export const systemApi = {
async getVersion(): Promise<VersionResponse> {
const response = await fetch(`${getApplicationServer()}/api/v1/system/version`);
if (!response.ok) {
throw new Error(`Failed to fetch version: ${response.status}`);
}
return response.json();
},
};

View File

@@ -0,0 +1,2 @@
export { systemApi } from './api/systemApi';
export { type VersionResponse } from './model/VersionResponse';

View File

@@ -0,0 +1,3 @@
export type VersionResponse = {
version: string;
};

View File

@@ -112,6 +112,7 @@ export function EditNotifierComponent({
smtpUser: '',
smtpPassword: '',
from: '',
isInsecureSkipVerify: false,
};
}

View File

@@ -1,5 +1,6 @@
import { InfoCircleOutlined } from '@ant-design/icons';
import { Input, Tooltip } from 'antd';
import { DownOutlined, InfoCircleOutlined, UpOutlined } from '@ant-design/icons';
import { Checkbox, Input, Tooltip } from 'antd';
import { useState } from 'react';
import type { Notifier } from '../../../../../entity/notifiers';
@@ -10,6 +11,9 @@ interface Props {
}
export function EditEmailNotifierComponent({ notifier, setNotifier, setUnsaved }: Props) {
const hasAdvancedValues = !!notifier?.emailNotifier?.isInsecureSkipVerify;
const [showAdvanced, setShowAdvanced] = useState(hasAdvancedValues);
return (
<>
<div className="mb-1 flex w-full flex-col items-start sm:flex-row sm:items-center">
@@ -163,6 +167,53 @@ export function EditEmailNotifierComponent({ notifier, setNotifier, setUnsaved }
</Tooltip>
</div>
</div>
<div className="mt-4 mb-3 flex items-center">
<div
className="flex cursor-pointer items-center text-sm text-blue-600 hover:text-blue-800"
onClick={() => setShowAdvanced(!showAdvanced)}
>
<span className="mr-2">Advanced settings</span>
{showAdvanced ? (
<UpOutlined style={{ fontSize: '12px' }} />
) : (
<DownOutlined style={{ fontSize: '12px' }} />
)}
</div>
</div>
{showAdvanced && (
<div className="mb-1 flex w-full flex-col items-start sm:flex-row sm:items-center">
<div className="mb-1 min-w-[150px] sm:mb-0">Skip TLS verify</div>
<div className="flex items-center">
<Checkbox
checked={notifier?.emailNotifier?.isInsecureSkipVerify || false}
onChange={(e) => {
if (!notifier?.emailNotifier) return;
setNotifier({
...notifier,
emailNotifier: {
...notifier.emailNotifier,
isInsecureSkipVerify: e.target.checked,
},
});
setUnsaved();
}}
>
Skip TLS
</Checkbox>
<Tooltip
className="cursor-pointer"
title="Skip TLS certificate verification. Enable this if your SMTP server uses a self-signed certificate. Warning: this reduces security."
>
<InfoCircleOutlined className="ml-2" style={{ color: 'gray' }} />
</Tooltip>
</div>
</div>
)}
</>
);
}

View File

@@ -36,6 +36,13 @@ export function ShowEmailNotifierComponent({ notifier }: Props) {
<div className="min-w-[110px]">From</div>
{notifier?.emailNotifier?.from || '(auto)'}
</div>
{notifier?.emailNotifier?.isInsecureSkipVerify && (
<div className="mb-1 flex items-center">
<div className="min-w-[110px]">Skip TLS</div>
Enabled
</div>
)}
</>
);
}

View File

@@ -0,0 +1,40 @@
import { useEffect } from 'react';
import { APP_VERSION } from '../../constants';
import { systemApi } from '../../entity/system';
const VERSION_CHECK_INTERVAL_MS = 5 * 60 * 1000;
const RELOAD_COOLDOWN_MS = 10 * 1000;
const LAST_RELOAD_KEY = 'lastVersionReload';
export function useVersionCheck() {
useEffect(() => {
if (APP_VERSION === 'dev') {
return;
}
const checkVersion = async () => {
try {
const { version } = await systemApi.getVersion();
if (version && version !== APP_VERSION) {
const lastReload = Number(localStorage.getItem(LAST_RELOAD_KEY) || '0');
if (Date.now() - lastReload < RELOAD_COOLDOWN_MS) {
return;
}
localStorage.setItem(LAST_RELOAD_KEY, String(Date.now()));
window.location.reload();
}
} catch {
// Silently ignore errors — network issues shouldn't break the app
}
};
checkVersion();
const interval = setInterval(checkVersion, VERSION_CHECK_INTERVAL_MS);
return () => clearInterval(interval);
}, []);
}

View File

@@ -2,7 +2,7 @@ import { LoadingOutlined, MenuOutlined } from '@ant-design/icons';
import { App, Button, Spin, Tooltip } from 'antd';
import { useEffect, useState } from 'react';
import { APP_VERSION } from '../../constants';
import { APP_VERSION, CONTAINER_ARCH } from '../../constants';
import { type DiskUsage, diskApi } from '../../entity/disk';
import {
type UserProfile,
@@ -365,6 +365,8 @@ export const MainScreenComponent = () => {
<div className="absolute bottom-1 left-2 mb-[0px] hidden text-sm text-gray-400 md:block">
v{APP_VERSION}
<br />
{CONTAINER_ARCH}
</div>
</div>
)}